/*==========================================================================;
 *
 *  (c) 2007-08 JSI.  All rights reserved.
 *
 *  File:          Models.cs
 *  Version:       1.0
 *  Desc:		   Machine learning models
 *  Author:        Miha Grcar
 *  Created on:    Aug-2007
 *  Last modified: Oct-2008
 *  Revision:      Oct-2008
 *
 ***************************************************************************/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
//using Latino.TextGarden;

namespace Latino.Model
{
    /* .-----------------------------------------------------------------------
       |
       |  Interface IModel<LblT>
       |
       '-----------------------------------------------------------------------
    */
    /// <summary>
    /// Classifier interface in which examples are passed as <c>object</c> and class
    /// labels are of type <typeparamref name="LblT"/>.
    /// </summary>
    public interface IModel<LblT> : ISerializable
    {
        /// <summary>True once the model holds a trained (or loaded) model.</summary>
        bool IsTrained { get; }
        /// <summary>Runtime type of the example objects this model expects in Classify(object).</summary>
        Type RequiredExampleType { get; }
        /// <summary>Trains the model on a labeled dataset.</summary>
        /// <param name="dataset">Labeled training examples.</param>
        void Train(IExampleCollection<LblT> dataset);
        /// <summary>Classifies a single example.</summary>
        /// <param name="example">The example; expected to be of type <see cref="RequiredExampleType"/>.</param>
        /// <returns>The (score, label) pairs produced for the example.</returns>
        Classification<LblT> Classify(object example);
    }

    /* .-----------------------------------------------------------------------
       |
       |  Interface IModel<LblT, ExT>
       |
       '-----------------------------------------------------------------------
    */
    /// <summary>
    /// Strongly typed variant of IModel&lt;LblT&gt;: examples are of type
    /// <typeparamref name="ExT"/> instead of <c>object</c>.
    /// </summary>
    public interface IModel<LblT, ExT> : IModel<LblT>
    {
        /// <summary>Classifies a single example of type <typeparamref name="ExT"/>.</summary>
        /// <param name="example">The example to classify.</param>
        /// <returns>The (score, label) pairs produced for the example.</returns>
        Classification<LblT> Classify(ExT example);
        /// <summary>Trains the model on a dataset whose examples are of type <typeparamref name="ExT"/>.</summary>
        /// <param name="dataset">Labeled training examples.</param>
        void Train(IExampleCollection<LblT, ExT> dataset);
    }

    /* .-----------------------------------------------------------------------
       |
       |  Class Classification<LblT>
       |
       '-----------------------------------------------------------------------
    */
    /// <summary>
    /// Result of classifying one example: a list of (score, label) pairs kept in the
    /// order defined by DescSort, so that index 0 is reported as the best entry.
    /// </summary>
    public class Classification<LblT> : IEnumerableList<KeyDat2<double, LblT>>
    {
        // Comparer shared by all instances. Judging by its name and by the Best* members
        // reading index 0, it presumably orders pairs by descending score -- confirm in DescSort.
        private static DescSort<KeyDat2<double, LblT>> m_desc_sort
            = new DescSort<KeyDat2<double, LblT>>();
        // The (score, label) pairs, always kept sorted with m_desc_sort.
        private ArrayList<KeyDat2<double, LblT>> m_classes
            = new ArrayList<KeyDat2<double, LblT>>();

        /// <summary>Creates an empty result; models fill it through Add/AddRange.</summary>
        internal Classification()
        {
        }

        /// <summary>Creates a result from an arbitrary sequence of (score, label) pairs.</summary>
        /// <param name="classes">Pairs to store; they are sorted after being added.</param>
        /// <remarks>Throws ArgumentNullException if <paramref name="classes"/> is null.</remarks>
        public Classification(IEnumerable<KeyDat2<double, LblT>> classes)
        {
            Utils.ThrowException(classes == null ? new ArgumentNullException("classes") : null);
            AddRange(classes);
        }

        /// <summary>Inserts a single pair at its sorted position.</summary>
        internal void Add(KeyDat2<double, LblT> @class)
        {
            m_classes.InsertSorted(@class, m_desc_sort);
        }

        /// <summary>Appends every pair, then sorts the whole list once.</summary>
        internal void AddRange(IEnumerable<KeyDat2<double, LblT>> classes)
        {
            foreach (KeyDat2<double, LblT> item in classes)
            {
                m_classes.Add(item);
            }
            m_classes.Sort(m_desc_sort);
        }

        /// <summary>Score of the pair at position <paramref name="idx"/>.</summary>
        public double GetScoreAt(int idx)
        {
            return this[idx].Key; // throws ArgumentOutOfRangeException
        }

        /// <summary>Label of the pair at position <paramref name="idx"/>.</summary>
        public LblT GetClassLabelAt(int idx)
        {
            return this[idx].Dat; // throws ArgumentOutOfRangeException
        }

        // The first (best) pair; throws InvalidOperationException when the result is empty.
        private KeyDat2<double, LblT> Best
        {
            get
            {
                Utils.ThrowException(m_classes.Count == 0 ? new InvalidOperationException() : null);
                return m_classes[0];
            }
        }

        /// <summary>Score of the first pair; throws InvalidOperationException if there are none.</summary>
        public double BestScore
        {
            get { return Best.Key; }
        }

        /// <summary>Label of the first pair; throws InvalidOperationException if there are none.</summary>
        public LblT BestClassLabel
        {
            get { return Best.Dat; }
        }

        // *** IEnumerableList<KeyDat2<double, LblT>> interface implementation ***

        /// <summary>Number of (score, label) pairs.</summary>
        public int Count
        {
            get { return m_classes.Count; }
        }

        /// <summary>Pair at position <paramref name="idx"/>; throws ArgumentOutOfRangeException when out of range.</summary>
        public KeyDat2<double, LblT> this[int idx]
        {
            get
            {
                bool in_range = idx >= 0 && idx < m_classes.Count;
                Utils.ThrowException(!in_range ? new ArgumentOutOfRangeException("idx") : null);
                return m_classes[idx];
            }
        }

        object IEnumerableList.this[int idx]
        {
            get { return this[idx]; } // throws ArgumentOutOfRangeException
        }

        public IEnumerator<KeyDat2<double, LblT>> GetEnumerator()
        {
            return new ListEnum<KeyDat2<double, LblT>>(this);
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return new ListEnum(this);
        }
    }

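    // NOTE(review): the class below is commented out. It appears to depend on
    // Latino.TextGarden (e.g. TextGardenUtils.ComputeCentroid), whose using
    // directive near the top of this file is also commented out.
    ///* .-----------------------------------------------------------------------
    //   |
    //   |  Class CentroidClassifier<LblT>
    //   |
    //   '-----------------------------------------------------------------------
    //*/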
    //public class CentroidClassifier<LblT> : IModel<LblT, BowSpV>
    //{
    //    private ArrayList<Pair2<LblT, BowSpV>> m_centroids
    //        = null;
    //    private CosSim m_cos_sim
    //        = new CosSim();
    //    private bool m_normalize
    //        = true;
    //    private IEqualityComparer<LblT> m_lbl_cmp
    //        = null;
    //    public bool NormalizeCentroids
    //    {
    //        get { return m_normalize; }
    //        set { m_normalize = value; }
    //    }
    //    public IEqualityComparer<LblT> LabelEqualityComparer
    //    {
    //        get { return m_lbl_cmp; }
    //        set { m_lbl_cmp = value; }
    //    }
    //    // *** IModel<LblT, BowSpV> interface implementation ***
    //    public Type RequiredExampleType
    //    {
    //        get { return typeof(BowSpV); }
    //    }
    //    public bool IsTrained
    //    {
    //        get { return m_centroids != null; }
    //    }
    //    public void Train(IExampleCollection<LblT, BowSpV> dataset)
    //    {
    //        Utils.ThrowException(dataset == null ? new ArgumentNullException("dataset") : null);
    //        Utils.ThrowException(dataset.Count == 0 ? new InvalidArgumentValueException("dataset") : null);
    //        m_centroids = new ArrayList<Pair2<LblT, BowSpV>>();
    //        Dictionary<LblT, ArrayList<BowSpV>> tmp = new Dictionary<LblT, ArrayList<BowSpV>>(m_lbl_cmp);
    //        foreach (Pair2<LblT, BowSpV> labeled_example in dataset)
    //        {
    //            if (!tmp.ContainsKey(labeled_example.First))
    //            {
    //                tmp.Add(labeled_example.First, new ArrayList<BowSpV>(new BowSpV[] { labeled_example.Second }));
    //            }
    //            else
    //            {
    //                tmp[labeled_example.First].Add(labeled_example.Second);
    //            }
    //        }
    //        foreach (KeyValuePair<LblT, ArrayList<BowSpV>> centroid_data in tmp)
    //        {
    //            BowSpV centroid = TextGardenUtils.ComputeCentroid(centroid_data.Value, m_normalize);
    //            m_centroids.Add(new Pair2<LblT, BowSpV>(centroid_data.Key, centroid));                
    //        }
    //    }
    //    void IModel<LblT>.Train(IExampleCollection<LblT> dataset)
    //    {
    //        Utils.ThrowException(dataset == null ? new ArgumentNullException("dataset") : null);
    //        Utils.ThrowException(!(dataset is IExampleCollection<LblT, BowSpV>) ? new ArgumentTypeException("dataset") : null);
    //        Train((IExampleCollection<LblT, BowSpV>)dataset); // throws InvalidArgumentValueException
    //    }
    //    public Classification<LblT> Classify(BowSpV example)
    //    {
    //        Utils.ThrowException(m_centroids == null ? new InvalidOperationException() : null);
    //        Utils.ThrowException(example == null ? new ArgumentNullException("example") : null);
    //        Classification<LblT> result = new Classification<LblT>();
    //        foreach (Pair2<LblT, BowSpV> labeled_centroid in m_centroids)
    //        {
    //            double sim = m_cos_sim.GetSim(labeled_centroid.Second, example);
    //            result.Add(new KeyDat2<double, LblT>(sim, labeled_centroid.First));
    //        }
    //        return result;
    //    }
    //    Classification<LblT> IModel<LblT>.Classify(object example)
    //    {
    //        Utils.ThrowException(example == null ? new ArgumentNullException("example") : null);
    //        Utils.ThrowException(!(example is BowSpV) ? new ArgumentTypeException("example") : null);
    //        return Classify((BowSpV)example); // throws InvalidOperationException
    //    }
    //    // *** ISerializable interface implementation ***
    //    public void Save(BinarySerializer writer)
    //    {
    //        throw new NotImplementedException();
    //    }
    //}

    /* .-----------------------------------------------------------------------
       |
       |  Class SvmMulticlassClassifier<LblT>
       |
       '-----------------------------------------------------------------------
    */
    /// <summary>
    /// Multiclass classifier backed by the native SvmLightLib (SVM^multiclass).
    /// Each distinct label of type <typeparamref name="LblT"/> is mapped to a 0-based
    /// class id; the trained native model is referenced through an integer handle that
    /// Dispose (and the finalizer) release.
    /// </summary>
    public class SvmMulticlassClassifier<LblT> : IModel<LblT, SvmFeatureVector>, IDisposable
    {
        // label -> 0-based class id (the native library is given id + 1, see Train)
        private Dictionary<LblT, int> m_lbl_to_id
            = new Dictionary<LblT, int>();
        // 0-based class id -> label (inverse of m_lbl_to_id)
        private ArrayList<LblT> m_id_to_lbl
            = new ArrayList<LblT>();
        // identifier of this classifier; persisted by Save and restored by Load
        private Guid m_guid
            = Guid.NewGuid();
        // native model handle; any value <= 0 means "no model trained or loaded"
        private int m_id
            = -1;
        private int m_v // verbosity level [0..3] 
            = 1;
        private double m_c // c: trade-off between training error and margin (0..inf)
            = 0.01;
        private double m_e // eps: allow that error for termination criterion (0..inf)
            = 0.1;
        private string m_args // additional arguments; no error checking (see http://svmlight.joachims.org/svm_multiclass.html)
            = "";

        /// <summary>Creates an untrained classifier using the default label equality comparer.</summary>
        public SvmMulticlassClassifier()
        {
        }

        /// <summary>Creates an untrained classifier that compares labels with <paramref name="lbl_cmp"/>.</summary>
        /// <param name="lbl_cmp">Label equality comparer; null selects the default comparer.</param>
        public SvmMulticlassClassifier(IEqualityComparer<LblT> lbl_cmp)
        {
            m_lbl_to_id = new Dictionary<LblT, int>(lbl_cmp);
        }

        /// <summary>Creates a classifier from data written by Save (see Load).</summary>
        public SvmMulticlassClassifier(BinarySerializer reader)
        {
            Load(reader); // throws ArgumentNullException, serialization-related exceptions
        }

        /// <summary>Creates a classifier from data written by Save, using <paramref name="lbl_cmp"/> for labels.</summary>
        public SvmMulticlassClassifier(BinarySerializer reader, IEqualityComparer<LblT> lbl_cmp)
        {
            Load(reader, lbl_cmp); // throws ArgumentNullException, serialization-related exceptions
        }

        /// <summary>Releases the native model if the owner never called Dispose.</summary>
        ~SvmMulticlassClassifier()
        {
            Dispose();
        }

        /// <summary>Verbosity level passed to the trainer as -v; must be in [0, 3].</summary>
        public int VerbosityLevel
        {
            get { return m_v; }
            set 
            {
                Utils.ThrowException((value < 0 || value > 3) ? new ArgumentOutOfRangeException("VerbosityLevel setter value") : null);
                m_v = value; 
            }
        }

        /// <summary>Trade-off between training error and margin (-c); must be positive.</summary>
        public double C
        {
            get { return m_c; }
            set 
            {
                Utils.ThrowException(value <= 0 ? new ArgumentOutOfRangeException("C setter value") : null);
                m_c = value; 
            }
        }

        /// <summary>Error allowed by the termination criterion (-e); must be positive.</summary>
        public double Eps
        {
            get { return m_e; }
            set 
            {
                Utils.ThrowException(value <= 0 ? new ArgumentOutOfRangeException("Eps setter value") : null);
                m_e = value; 
            }
        }

        /// <summary>Extra command-line arguments appended verbatim (not validated); must not be null.</summary>
        public string Args
        {
            get { return m_args; }
            set 
            {
                Utils.ThrowException(value == null ? new ArgumentNullException("Args setter value") : null);
                m_args = value;
            }
        }

        // *** IModel<LblT, SvmFeatureVector> interface implementation ***

        /// <summary>Examples must be SvmFeatureVector instances.</summary>
        public Type RequiredExampleType
        {
            get { return typeof(SvmFeatureVector); }
        }

        /// <summary>True while a native model handle is held (after Train or ReadSvmModel).</summary>
        public bool IsTrained
        {
            get { return m_id > 0; }
        }

        /// <summary>Identifier of this classifier, persisted by Save/Load.</summary>
        public Guid Guid
        {
            get { return m_guid; }
        }

        /// <summary>
        /// Trains a new native model on <paramref name="dataset"/>, replacing (and releasing)
        /// any previously trained or loaded model and rebuilding the label maps.
        /// </summary>
        /// <remarks>
        /// Throws ArgumentNullException if <paramref name="dataset"/> is null and
        /// InvalidArgumentValueException if it is empty.
        /// </remarks>
        public void Train(IExampleCollection<LblT, SvmFeatureVector> dataset)
        {
            Utils.ThrowException(dataset == null ? new ArgumentNullException("dataset") : null);
            Utils.ThrowException(dataset.Count == 0 ? new InvalidArgumentValueException("dataset") : null);
            Dispose(); // release the previous native model, if any
            m_lbl_to_id.Clear();
            m_id_to_lbl.Clear();
            foreach (Pair2<LblT, SvmFeatureVector> example in dataset)
            {
                int label_id;
                if (!m_lbl_to_id.TryGetValue(example.First, out label_id))
                {
                    // first occurrence of this label: assign the next free 0-based id
                    label_id = m_lbl_to_id.Count;
                    m_lbl_to_id.Add(example.First, label_id);
                    m_id_to_lbl.Add(example.First);
                }
                // the native feature vector is labelled with the 1-based id (label_id + 1)
                SvmLightLib.SetFeatureVectorLabel(example.Second.Id, label_id + 1);
            }
            int[] id_list = ModelUtils.GetIdList(dataset); // remove this?! ***********************************
            string args = string.Format(CultureInfo.InvariantCulture, "-v {0} -c {1} -e {2} {3}",
                m_v, m_c.ToString(CultureInfo.InvariantCulture), m_e.ToString(CultureInfo.InvariantCulture), m_args);
            m_id = SvmLightLib.TrainMulticlassModel(args, id_list.Length, id_list);
            // keep the dataset (and the feature vectors whose ids were handed to native code)
            // reachable until the native training call has returned
            GC.KeepAlive(dataset);
        }

        /// <summary>Untyped Train: <paramref name="dataset"/> must be an IExampleCollection&lt;LblT, SvmFeatureVector&gt;.</summary>
        void IModel<LblT>.Train(IExampleCollection<LblT> dataset)
        {
            Utils.ThrowException(dataset == null ? new ArgumentNullException("dataset") : null);
            // must test against the same type the cast below uses (LblT, not a hard-coded double)
            Utils.ThrowException(!(dataset is IExampleCollection<LblT, SvmFeatureVector>) ? new ArgumentTypeException("dataset") : null);
            Train((IExampleCollection<LblT, SvmFeatureVector>)dataset); // throws InvalidArgumentValueException
        }

        /// <summary>
        /// Classifies <paramref name="example"/> with the native model; score i is paired with
        /// the label whose class id is i.
        /// </summary>
        /// <remarks>
        /// Throws InvalidOperationException if no model is trained/loaded and
        /// ArgumentNullException if <paramref name="example"/> is null.
        /// </remarks>
        public Classification<LblT> Classify(SvmFeatureVector example)
        {
            Utils.ThrowException(m_id <= 0 ? new InvalidOperationException() : null);
            Utils.ThrowException(example == null ? new ArgumentNullException("example") : null);
            Classification<LblT> result = new Classification<LblT>();
            SvmLightLib.MulticlassClassify(m_id, /*feature_vector_count=*/1, new int[] { example.Id });
            int classif_score_count = SvmLightLib.GetFeatureVectorClassifScoreCount(example.Id);
            for (int i = 0; i < classif_score_count; i++)
            {
                result.Add(new KeyDat2<double, LblT>(SvmLightLib.GetFeatureVectorClassifScore(example.Id, i), m_id_to_lbl[i])); 
            }
            // as in Train: keep the managed wrapper reachable until the last native call
            // that uses its id has returned
            GC.KeepAlive(example);
            return result;
        }

        /// <summary>Untyped Classify: <paramref name="example"/> must be an SvmFeatureVector.</summary>
        Classification<LblT> IModel<LblT>.Classify(object example)
        {
            Utils.ThrowException(example == null ? new ArgumentNullException("example") : null);
            Utils.ThrowException(!(example is SvmFeatureVector) ? new ArgumentTypeException("example") : null);
            return Classify((SvmFeatureVector)example); // throws InvalidOperationException
        }

        // *** IDisposable interface implementation ***

        /// <summary>
        /// Releases the native model, if one is held. Safe to call repeatedly; the
        /// classifier can be trained or loaded again afterwards.
        /// </summary>
        public void Dispose()
        {
            if (m_id > 0)
            {
                SvmLightLib.DeleteMulticlassModel(m_id);
                m_id = -1;
            }
        }

        // *** ISerializable interface implementation ***

        /// <summary>
        /// Writes the label maps, training parameters and GUID. The native model itself is
        /// not written here; use WriteSvmModel for that.
        /// </summary>
        /// <remarks>Throws InvalidOperationException if no model is held, ArgumentNullException if <paramref name="writer"/> is null.</remarks>
        public void Save(BinarySerializer writer)
        {
            Utils.ThrowException(m_id <= 0 ? new InvalidOperationException() : null);
            Utils.ThrowException(writer == null ? new ArgumentNullException("writer") : null);                 
            // the following functions throw serialization-related exceptions
            Utils.SaveDictionary(m_lbl_to_id, writer);
            m_id_to_lbl.Save(writer);
            writer.WriteInt(m_v);
            writer.WriteDouble(m_c);
            writer.WriteDouble(m_e);
            writer.WriteString(m_args);
            writer.WriteString(m_guid.ToString("N"));
        }

        /// <summary>Load using the default label equality comparer.</summary>
        public void Load(BinarySerializer reader)
        {
            Load(reader, /*lbl_cmp=*/null); // throws ArgumentNullException, InvalidArgumentValueException, serialization-related exceptions
        }

        /// <summary>
        /// Restores what Save wrote, releasing any currently held native model first. The native
        /// model is not restored; call ReadSvmModel afterwards to make the classifier usable.
        /// </summary>
        public void Load(BinarySerializer reader, IEqualityComparer<LblT> lbl_cmp)
        {
            Utils.ThrowException(reader == null ? new ArgumentNullException("reader") : null);         
            Dispose();
            // the following functions throw serialization-related exceptions
            m_lbl_to_id = Utils.LoadDictionary<LblT, int>(reader, lbl_cmp);
            m_id_to_lbl = new ArrayList<LblT>(reader);
            m_v = reader.ReadInt();
            m_c = reader.ReadDouble();
            m_e = reader.ReadDouble();
            m_args = reader.ReadString();
            m_guid = new Guid(reader.ReadString());
        }

        /// <summary>Writes the native model to <paramref name="file_name"/> via SvmLightLib.</summary>
        /// <remarks>Requires a held model (InvalidOperationException otherwise).</remarks>
        public void WriteSvmModel(string file_name)
        {
            Utils.ThrowException(m_id <= 0 ? new InvalidOperationException() : null);
            Utils.ThrowException(file_name == null ? new ArgumentNullException("file_name") : null);
            Utils.ThrowException(!Utils.VerifyFileName(file_name, /*must_exist=*/false) ? new InvalidArgumentValueException("file_name") : null);
            SvmLightLib.SaveMulticlassModelBin(m_id, file_name);
        }

        /// <summary>
        /// Loads a native model from <paramref name="file_name"/>. Requires that no model is
        /// currently held and that the label maps are populated (e.g. by Load);
        /// otherwise throws InvalidOperationException.
        /// </summary>
        public void ReadSvmModel(string file_name)
        {
            Utils.ThrowException((m_id > 0 || m_id_to_lbl.Count == 0) ? new InvalidOperationException() : null);
            Utils.ThrowException(file_name == null ? new ArgumentNullException("file_name") : null);
            Utils.ThrowException(!Utils.VerifyFileName(file_name, /*must_exist=*/true) ? new InvalidArgumentValueException("file_name") : null);
            m_id = SvmLightLib.LoadMulticlassModelBin(file_name);
        }
    }
}