/*==========================================================================;
 *
 *  (c) 2008 JSI.  All rights reserved.
 *
 *  File:          ModelUtils.cs
 *  Version:       1.0
 *  Desc:		   Common ML-related routines
 *  Author:        Miha Grcar
 *  Created on:    Aug-2008
 *  Last modified: Aug-2008
 *  Revision:      Aug-2008
 *
 ***************************************************************************/

using System;
using System.Collections.Generic;
//using Latino.TextGarden;

namespace Latino.Model
{
    /* .-----------------------------------------------------------------------
       |
       |  Enum GroupClassifyMethod
       |
       '-----------------------------------------------------------------------
    */
    // Selects how ModelUtils.ClassifyGroup merges the label scores produced for the
    // individual examples of a group into a single group-level score per label.
    // The explicit values match the implicit numbering (0, 1, 2) of the original declaration.
    public enum GroupClassifyMethod
    {
        Sum = 0,  // group score = sum of the label's scores over all examples
        Max = 1,  // group score = highest score the label received from any example
        Vote = 2  // group score = number of examples whose classification lists the label first
    }

    /* .-----------------------------------------------------------------------
       |
       |  Static class ModelUtils
       |
       '-----------------------------------------------------------------------
    */
    public static class ModelUtils
    {
        // Converts an example from one feature-vector representation to another.
        //
        // Supported InExT / OutExT types: SparseVector2<double> and SvmFeatureVector.
        // (BowSpV support is commented out because the Latino.TextGarden reference is disabled.)
        //
        // Throws ArgumentNullException if in_ex is null, and ArgumentTypeException if InExT
        // or OutExT is not one of the supported types.
        //
        // Notes:
        //   * If both InExT and OutExT are SparseVector2<double>, in_ex itself is returned
        //     (no copy is made), so the caller's vector and the result are the same object.
        //   * Converting to SvmFeatureVector narrows every value from double to float, so
        //     precision may be lost.
        public static OutExT ConvertExample<InExT, OutExT>(InExT in_ex)
        {
            Utils.ThrowException(in_ex == null ? new ArgumentNullException("in_ex") : null);
            // Step 1: bring the input into the common intermediate form, SparseVector2<double>.
            SparseVector2<double> tmp;
            if (typeof(InExT) == typeof(SparseVector2<double>))
            {
                tmp = (SparseVector2<double>)(object)in_ex;
            }
            //else if (typeof(InExT) == typeof(BowSpV))
            //{
            //    tmp = new SparseVector2<double>(((BowSpV)(object)in_ex).Count);
            //    tmp.Inner.AddRange((BowSpV)(object)in_ex);
            //    tmp.Inner.Sort(); // *** this is neccessary
            //}
            else if (typeof(InExT) == typeof(SvmFeatureVector))
            {
                SvmFeatureVector svm_in = (SvmFeatureVector)(object)in_ex;
                tmp = new SparseVector2<double>(svm_in.Count);
                foreach (IdxDat2<float> item in svm_in)
                {
                    // Items are appended in enumeration order; assumes SvmFeatureVector
                    // enumerates in increasing index order, as the original code did.
                    tmp.Inner.Add(new IdxDat2<double>(item.Idx, item.Dat));
                }
            }
            else
            {
                throw new ArgumentTypeException("InExT");
            }
            // Step 2: produce the requested output type from the intermediate form.
            if (typeof(OutExT) == typeof(SparseVector2<double>))
            {
                return (OutExT)(object)tmp;
            }
            //else if (typeof(OutExT) == typeof(BowSpV))
            //{
            //    return (OutExT)(object)new BowSpV(tmp);
            //}
            else if (typeof(OutExT) == typeof(SvmFeatureVector))
            {
                SparseVector2<float> tmp_2 = new SparseVector2<float>(tmp.Count);
                foreach (IdxDat2<double> item in tmp)
                {
                    tmp_2.Inner.Add(new IdxDat2<float>(item.Idx, (float)item.Dat)); // *** casting double to float
                }
                return (OutExT)(object)new SvmFeatureVector(tmp_2);
            }
            else
            {
                throw new ArgumentTypeException("OutExT");
            }
        }

        // Classifies a group of examples and sums the per-example label scores
        // (GroupClassifyMethod.Sum, default label equality comparer).
        public static ArrayList<KeyDat2<double, LblT>> ClassifyGroup<LblT, ExT>(IEnumerable<ExT> examples, IModel<LblT, ExT> model)
        {
            return ClassifyGroup<LblT, ExT>(examples, model, GroupClassifyMethod.Sum, /*lbl_cmp=*/null); // throws InvalidOperationException, ArgumentNullException, ArgumentOutOfRangeException
        }

        // Classifies a group of examples and aggregates the per-example label scores with
        // the given method (default label equality comparer).
        public static ArrayList<KeyDat2<double, LblT>> ClassifyGroup<LblT, ExT>(IEnumerable<ExT> examples, IModel<LblT, ExT> model, GroupClassifyMethod method)
        {
            return ClassifyGroup<LblT, ExT>(examples, model, method, /*lbl_cmp=*/null); // throws InvalidOperationException, ArgumentNullException, ArgumentOutOfRangeException
        }

        // Classifies every example in a group with the given model and merges the
        // per-example (score, label) results into one ranking for the whole group.
        //
        //   method == Sum  - a label's group score is the sum of its scores over all examples
        //   method == Max  - a label's group score is its maximum score over all examples
        //   method == Vote - each example adds one vote to the FIRST label of its Classification
        //                    result (presumably the top-ranked label; depends on how
        //                    Classification orders its entries - verify)
        //
        // lbl_cmp compares labels. When it is null, the default equality comparer for LblT is
        // used (standard Dictionary<TKey, TValue> behaviour).
        //
        // Returns (group score, label) pairs sorted with DescSort (presumably descending by score).
        //
        // Throws ArgumentNullException if examples or model is null, and
        // ArgumentOutOfRangeException if method is not a defined GroupClassifyMethod value.
        // Exceptions thrown by model.Classify propagate to the caller.
        public static ArrayList<KeyDat2<double, LblT>> ClassifyGroup<LblT, ExT>(IEnumerable<ExT> examples, IModel<LblT, ExT> model, GroupClassifyMethod method, IEqualityComparer<LblT> lbl_cmp)
        {
            Utils.ThrowException(examples == null ? new ArgumentNullException("examples") : null);
            Utils.ThrowException(model == null ? new ArgumentNullException("model") : null);
            // Reject undefined enum values (e.g. casts from int). Otherwise they would silently
            // keep only the first score seen for each label.
            Utils.ThrowException(method != GroupClassifyMethod.Sum && method != GroupClassifyMethod.Max && method != GroupClassifyMethod.Vote
                ? new ArgumentOutOfRangeException("method") : null);
            Dictionary<LblT, double> scores = new Dictionary<LblT, double>(lbl_cmp);
            foreach (ExT example in examples)
            {
                Classification<LblT> result = model.Classify(example); // throws InvalidOperationException, ArgumentNullException
                foreach (KeyDat2<double, LblT> lbl_info in result)
                {
                    LblT lbl = lbl_info.Dat;
                    double score;
                    bool seen = scores.TryGetValue(lbl, out score); // single lookup instead of ContainsKey + indexer
                    if (method == GroupClassifyMethod.Vote)
                    {
                        scores[lbl] = seen ? score + 1 : 1;
                        break; // only the first label of each example's result receives a vote
                    }
                    if (!seen)
                    {
                        scores[lbl] = lbl_info.Key;
                    }
                    else if (method == GroupClassifyMethod.Max)
                    {
                        scores[lbl] = Math.Max(lbl_info.Key, score);
                    }
                    else // GroupClassifyMethod.Sum
                    {
                        scores[lbl] = score + lbl_info.Key;
                    }
                }
            }
            ArrayList<KeyDat2<double, LblT>> final_result = new ArrayList<KeyDat2<double, LblT>>();
            foreach (KeyValuePair<LblT, double> item in scores)
            {
                final_result.Add(new KeyDat2<double, LblT>(item.Value, item.Key));
            }
            final_result.Sort(new DescSort<KeyDat2<double, LblT>>());
            return final_result;
        }

        // Collects the Id of every SvmFeatureVector in the dataset, in the dataset's
        // enumeration order. The returned array has exactly dataset.Count entries.
        internal static int[] GetIdList<LblT>(IExampleCollection<LblT, SvmFeatureVector> dataset)
        {
            int[] id_list = new int[dataset.Count];
            int i = 0;
            foreach (Pair2<LblT, SvmFeatureVector> example in dataset) { id_list[i++] = example.Second.Id; }
            return id_list;
        }
    }
}
