using System;
using System.Collections.Generic;
using Latino;
using Latino.Model;

namespace SvmPosTagger
{
    /* .-----------------------------------------------------------------------
       |
       |  Class PosModel
       |
       '-----------------------------------------------------------------------
    */
    /// <summary>
    /// Part-of-speech model for a single tag class. Training examples (tag + sparse feature
    /// vector) are collected with <see cref="AddExample"/>; <see cref="Train"/> then fits a
    /// multiclass SVM over them. When fewer than two distinct tags were seen, no SVM is trained
    /// and <see cref="Classify"/> answers without one.
    /// </summary>
    public class PosModel : ISerializable
    {
        // examples collected by AddExample; emptied by Train
        private ArrayList<Pair2<string, SparseVector2<double>.ReadOnly>> m_training_set
            = new ArrayList<Pair2<string, SparseVector2<double>.ReadOnly>>();
        // distinct tags seen by AddExample (or restored by Load)
        private Set<string> m_tags
            = new Set<string>();
        // original feature index -> compact SVM feature index (only features that passed the frequency cut-off)
        private Dictionary<int, int> m_feature_mapping
            = new Dictionary<int, int>();
        // Train only creates an SVM when at least two distinct tags were seen; otherwise this stays null
        private SvmMulticlassClassifier<string> m_model
            = null;
        private bool m_is_trained
            = false;

        public PosModel()
        {
        }

        /// <summary>Creates a trained model by deserializing it from <paramref name="reader"/>.</summary>
        public PosModel(BinarySerializer reader)
        {
            Load(reader); // throws ArgumentNullException, serialization-related exceptions
        }

        /// <summary>Adds a training example. Only allowed before <see cref="Train"/> has been called.</summary>
        /// <exception cref="InvalidOperationException">The model is already trained.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="feature_vector"/> or <paramref name="tag"/> is null.</exception>
        public void AddExample(string tag, SparseVector2<double>.ReadOnly feature_vector)
        {
            Utils.ThrowException(m_is_trained ? new InvalidOperationException() : null);
            Utils.ThrowException(feature_vector == null ? new ArgumentNullException("feature_vector") : null);
            Utils.ThrowException(tag == null ? new ArgumentNullException("tag") : null);
            m_training_set.Add(new Pair2<string, SparseVector2<double>.ReadOnly>(tag, feature_vector));
            m_tags.Add(tag);
        }

        /// <summary>
        /// Classifies a feature vector. Returns the SVM's classification when two or more tags were
        /// trained; the single known tag with score 1 when exactly one was; and an empty
        /// classification when the model never received any examples.
        /// </summary>
        /// <exception cref="InvalidOperationException">The model is not trained.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="feature_vector"/> is null.</exception>
        public Classification<string> Classify(SparseVector2<double>.ReadOnly feature_vector)
        {
            Utils.ThrowException(!m_is_trained ? new InvalidOperationException() : null);
            Utils.ThrowException(feature_vector == null ? new ArgumentNullException("feature_vector") : null);
            if (m_tags.Count == 0)
            {
                // no examples were ever added, so there is no tag to predict
                return new Classification<string>(new KeyDat2<double, string>[0]);
            }
            if (m_tags.Count == 1)
            {
                return new Classification<string>(new KeyDat2<double, string>[] {
                    new KeyDat2<double, string>(1, m_tags.Any) // *** classification score set to 1
                });
            }
            return m_model.Classify(PrepareFeatureVector(feature_vector));
        }

        // Projects a feature vector onto the compact SVM feature space; features that were
        // filtered out during training (absent from m_feature_mapping) are dropped.
        private SvmFeatureVector PrepareFeatureVector(SparseVector2<double>.ReadOnly feature_vector)
        {
            SparseVector2<float> tmp = new SparseVector2<float>();
            foreach (IdxDat2<double> item in feature_vector)
            {
                int svm_idx;
                if (m_feature_mapping.TryGetValue(item.Idx, out svm_idx))
                {
                    tmp[svm_idx] = (float)item.Dat;
                }
            }
            return new SvmFeatureVector(tmp);
        }

        /// <summary>
        /// Trains the model on the collected examples. An SVM is fitted only when at least two
        /// distinct tags were seen; in every case the collected examples are released afterwards.
        /// </summary>
        /// <param name="verbose">Sets the SVM verbosity level to 1 when true, 0 otherwise.</param>
        /// <param name="c">Passed to <c>SvmMulticlassClassifier.C</c>; must be positive.</param>
        /// <param name="eps">Passed to <c>SvmMulticlassClassifier.Eps</c>; must be positive.</param>
        /// <param name="min_feat_freq">Features occurring fewer times than this across all examples are discarded; must be positive.</param>
        /// <exception cref="InvalidOperationException">The model is already trained.</exception>
        /// <exception cref="ArgumentOutOfRangeException">A numeric argument is not positive.</exception>
        public void Train(bool verbose, double c, double eps, int min_feat_freq)
        {
            Utils.ThrowException(m_is_trained ? new InvalidOperationException() : null);
            Utils.ThrowException(c <= 0 ? new ArgumentOutOfRangeException("c") : null);
            Utils.ThrowException(eps <= 0 ? new ArgumentOutOfRangeException("eps") : null);
            Utils.ThrowException(min_feat_freq <= 0 ? new ArgumentOutOfRangeException("min_feat_freq") : null);
            // With zero or one tag there is nothing for an SVM to discriminate; Classify handles
            // those cases directly, so we avoid handing a degenerate (possibly empty) dataset to the trainer.
            if (m_tags.Count > 1)
            {
                BuildFeatureMapping(min_feat_freq);
                SvmDataset<string> dataset = CreateDataset();
                m_model = new SvmMulticlassClassifier<string>();
                m_model.VerbosityLevel = verbose ? 1 : 0;
                m_model.C = c;
                m_model.Eps = eps;
                m_model.Train(dataset);
            }
            m_training_set.Clear(); // the examples are not needed after training
            m_is_trained = true;
        }

        // Assigns a compact SVM index to every feature whose total occurrence count across the
        // training examples is at least min_feat_freq.
        private void BuildFeatureMapping(int min_feat_freq)
        {
            Dictionary<int, int> feat_freq = new Dictionary<int, int>();
            foreach (Pair2<string, SparseVector2<double>.ReadOnly> example in m_training_set)
            {
                foreach (IdxDat2<double> item in example.Second)
                {
                    int freq;
                    feat_freq.TryGetValue(item.Idx, out freq); // freq is 0 for a feature not seen yet
                    feat_freq[item.Idx] = freq + 1;
                }
            }
            foreach (KeyValuePair<int, int> item in feat_freq)
            {
                if (item.Value >= min_feat_freq) { m_feature_mapping.Add(item.Key, m_feature_mapping.Count); }
            }
        }

        // Converts the collected examples into an SVM dataset over the compact feature space.
        // Examples are removed from the training set as they are converted (iterating from the end,
        // so the dataset lists them in reverse order) so the original vectors can be released early.
        private SvmDataset<string> CreateDataset()
        {
            SvmDataset<string> dataset = new SvmDataset<string>();
            for (int i = m_training_set.Count - 1; i >= 0; i--)
            {
                dataset.Add(m_training_set[i].First, PrepareFeatureVector(m_training_set[i].Second));
                m_training_set.RemoveAt(i);
            }
            return dataset;
        }

        /// <summary>Whether <see cref="Train"/> or <see cref="Load"/> has completed.</summary>
        public bool IsTrained
        {
            get { return m_is_trained; }
        }

        // Location of the external file holding the native SVM model. Shared by Save and Load so
        // both always agree on it.
        // NOTE(review): the hard-coded '\\' separator is Windows-specific; it is kept unchanged so
        // file names keep matching models saved earlier.
        private static string GetSvmModelFileName(string data_dir, SvmMulticlassClassifier<string> model)
        {
            return string.Format("{0}\\{1}.SvmModel", data_dir.TrimEnd('\\', '/'), model.Guid.ToString("N"));
        }

        // *** ISerializable interface implementation ***

        /// <summary>
        /// Serializes the tags, the feature mapping and the SVM model object. When an SVM exists,
        /// its native model is additionally written to a separate file under the writer's data directory.
        /// </summary>
        /// <exception cref="InvalidOperationException">The model is not trained.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="writer"/> is null.</exception>
        public void Save(BinarySerializer writer)
        {
            Utils.ThrowException(!m_is_trained ? new InvalidOperationException() : null);
            Utils.ThrowException(writer == null ? new ArgumentNullException("writer") : null);
            // the following functions throw serialization-related exceptions
            m_tags.Save(writer);
            Utils.SaveDictionary(m_feature_mapping, writer);
            writer.WriteObject(m_model);
            if (m_model != null)
            {
                m_model.WriteSvmModel(GetSvmModelFileName(writer.DataDir, m_model));
            }
        }

        /// <summary>
        /// Restores a model written by <see cref="Save"/>, including the external SVM model file
        /// when one was saved. The model is trained afterwards.
        /// </summary>
        /// <exception cref="InvalidOperationException">The model is already trained.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="reader"/> is null.</exception>
        public void Load(BinarySerializer reader)
        {
            Utils.ThrowException(m_is_trained ? new InvalidOperationException() : null);
            Utils.ThrowException(reader == null ? new ArgumentNullException("reader") : null);
            // the following functions throw serialization-related exceptions
            m_training_set.Clear();
            m_tags = new Set<string>(reader);
            m_feature_mapping = Utils.LoadDictionary<int, int>(reader);
            m_model = reader.ReadObject<SvmMulticlassClassifier<string>>();
            if (m_model != null)
            {
                m_model.ReadSvmModel(GetSvmModelFileName(reader.DataDir, m_model));
            }
            m_is_trained = true;
        }
    }

    /* .-----------------------------------------------------------------------
       |
       |  Class ModelIndex
       |
       '-----------------------------------------------------------------------
    */
    /// <summary>
    /// Routes tags to <see cref="PosModel"/> instances. The tags of a <c>WordDictionary</c> are
    /// partitioned into disjoint classes (tags that appear together in some word's tag set, directly
    /// or through a chain of words, share a class), and one model is kept per class.
    /// </summary>
    public class ModelIndex : ISerializable
    {
        // tag -> model responsible for that tag's class
        private Dictionary<string, PosModel> m_index
            = new Dictionary<string, PosModel>();
        // one model per tag class
        private ArrayList<PosModel> m_models
            = new ArrayList<PosModel>();

        /// <summary>Creates one untrained model per disjoint tag class of <paramref name="dictionary"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="dictionary"/> is null.</exception>
        /// <exception cref="InvalidArgumentValueException"><paramref name="dictionary"/> has no word with a non-empty tag set (this includes an empty dictionary).</exception>
        public ModelIndex(WordDictionary dictionary)
        {
            Utils.ThrowException(dictionary == null ? new ArgumentNullException("dictionary") : null);
            ArrayList<Set<string>> tag_classes = ComputeDisjunctiveTagClasses(dictionary);
            // at least one model is required (IsTrained inspects the first model)
            Utils.ThrowException(tag_classes.Count == 0 ? new InvalidArgumentValueException("dictionary") : null);
            foreach (Set<string> tag_class in tag_classes)
            {
                PosModel model = new PosModel();
                m_models.Add(model);
                foreach (string tag in tag_class)
                {
                    m_index.Add(tag, model);
                }
            }
        }

        /// <summary>Creates an index by deserializing it from <paramref name="reader"/>.</summary>
        public ModelIndex(BinarySerializer reader)
        {
            Load(reader); // throws ArgumentNullException, serialization-related exceptions
        }

        // Partitions the dictionary's tags into pairwise-disjoint classes such that every word's tag
        // set lies entirely inside one class. Words with an empty tag set contribute nothing, since a
        // class without tags could never receive training examples.
        private static ArrayList<Set<string>> ComputeDisjunctiveTagClasses(WordDictionary dictionary)
        {
            ArrayList<Set<string>> tag_classes = new ArrayList<Set<string>>();
            foreach (string word in dictionary.GetWords())
            {
                Set<string>.ReadOnly tags = dictionary.GetTags(word);
                if (tags.Count > 0) { tag_classes.Add(tags.GetWritableCopy()); }
            }
            // Walk from the back: each class is folded into the nearest earlier class it overlaps and
            // removed. The enlarged class sits at a lower index and is examined later, so merges
            // propagate transitively; once the loop ends, the remaining classes are pairwise disjoint.
            for (int i = tag_classes.Count - 1; i >= 0; i--)
            {
                Set<string> tag_class = tag_classes[i];
                for (int j = i - 1; j >= 0; j--)
                {
                    Set<string> other_class = tag_classes[j];
                    if (Set<string>.Intersection(tag_class, other_class).Count > 0)
                    {
                        other_class.AddRange(tag_class);
                        tag_classes.RemoveAt(i);
                        break;
                    }
                }
            }
            return tag_classes;
        }

        /// <summary>Adds a training example to the model responsible for <paramref name="tag"/>.</summary>
        /// <exception cref="InvalidOperationException">The index is already trained.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="feature_vector"/> or <paramref name="tag"/> is null.</exception>
        /// <exception cref="InvalidArgumentValueException"><paramref name="tag"/> is not a known tag.</exception>
        public void AddExample(string tag, SparseVector2<double>.ReadOnly feature_vector)
        {
            Utils.ThrowException(IsTrained ? new InvalidOperationException() : null);
            Utils.ThrowException(feature_vector == null ? new ArgumentNullException("feature_vector") : null);
            Utils.ThrowException(tag == null ? new ArgumentNullException("tag") : null);
            PosModel model;
            Utils.ThrowException(!m_index.TryGetValue(tag, out model) ? new InvalidArgumentValueException("tag") : null);
            model.AddExample(tag, feature_vector);
        }

        /// <summary>
        /// Returns the first tag in the responsible model's classification that is also in
        /// <paramref name="possible_tags"/>. The model is chosen from one arbitrary element of
        /// <paramref name="possible_tags"/>, so all possible tags are expected to share one tag class.
        /// </summary>
        /// <exception cref="InvalidOperationException">The index is not trained, or the model returned none of the possible tags.</exception>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="InvalidArgumentValueException"><paramref name="possible_tags"/> is empty or its chosen tag is unknown.</exception>
        public string Classify(SparseVector2<double>.ReadOnly feature_vector, Set<string>.ReadOnly possible_tags)
        {
            Utils.ThrowException(!IsTrained ? new InvalidOperationException() : null);
            Utils.ThrowException(feature_vector == null ? new ArgumentNullException("feature_vector") : null);
            Utils.ThrowException(possible_tags == null ? new ArgumentNullException("possible_tags") : null);
            // check emptiness before touching Any: an empty set has no element to pick
            Utils.ThrowException(possible_tags.Count == 0 ? new InvalidArgumentValueException("possible_tags") : null);
            string key_tag = possible_tags.Any;
            PosModel model;
            Utils.ThrowException(!m_index.TryGetValue(key_tag, out model) ? new InvalidArgumentValueException("possible_tags") : null);
            // NOTE(review): assumes the classification is ordered best-first — confirm against Classification<T>
            foreach (KeyDat2<double, string> tag_score in model.Classify(feature_vector))
            {
                if (possible_tags.Contains(tag_score.Dat)) { return tag_score.Dat; }
            }
            // none of the possible tags appears in the model's classification
            throw new InvalidOperationException();
        }

        /// <summary>Trains every per-class model with the same parameters (see <see cref="PosModel.Train"/>).</summary>
        /// <exception cref="ArgumentOutOfRangeException">A numeric argument is not positive.</exception>
        /// <exception cref="InvalidOperationException">The index is already trained.</exception>
        public void Train(bool verbose, double c, double eps, int min_feat_freq)
        {
            Utils.ThrowException(c <= 0 ? new ArgumentOutOfRangeException("c") : null);
            Utils.ThrowException(eps <= 0 ? new ArgumentOutOfRangeException("eps") : null);
            Utils.ThrowException(min_feat_freq <= 0 ? new ArgumentOutOfRangeException("min_feat_freq") : null);
            Utils.ThrowException(IsTrained ? new InvalidOperationException() : null);
            foreach (PosModel model in m_models)
            {
                model.Train(verbose, c, eps, min_feat_freq);
            }
        }

        /// <summary>
        /// Whether the models are trained. All models are trained together, so the first one is
        /// representative; an index without models (e.g. loaded from a stream holding none) is not trained.
        /// </summary>
        public bool IsTrained
        {
            get { return m_models.Count > 0 && m_models[0].IsTrained; }
        }

        // *** ISerializable interface implementation ***

        /// <summary>
        /// Writes the number of models, then for each model its serialized form followed by the set
        /// of tags it is responsible for.
        /// </summary>
        /// <exception cref="InvalidOperationException">The index is not trained.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="writer"/> is null.</exception>
        public void Save(BinarySerializer writer)
        {
            Utils.ThrowException(!IsTrained ? new InvalidOperationException() : null);
            Utils.ThrowException(writer == null ? new ArgumentNullException("writer") : null);
            // invert the tag -> model index into model -> tags
            Dictionary<PosModel, Set<string>> inv_idx = new Dictionary<PosModel, Set<string>>();
            foreach (KeyValuePair<string, PosModel> item in m_index)
            {
                Set<string> tags;
                if (inv_idx.TryGetValue(item.Value, out tags))
                {
                    tags.Add(item.Key);
                }
                else
                {
                    inv_idx.Add(item.Value, new Set<string>(new string[] { item.Key }));
                }
            }
            // the following functions throw serialization-related exceptions
            writer.WriteInt(inv_idx.Count);
            foreach (KeyValuePair<PosModel, Set<string>> item in inv_idx)
            {
                item.Key.Save(writer);
                item.Value.Save(writer);
            }
        }

        /// <summary>Replaces the current contents with an index written by <see cref="Save"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="reader"/> is null.</exception>
        public void Load(BinarySerializer reader)
        {
            Utils.ThrowException(reader == null ? new ArgumentNullException("reader") : null);
            m_models.Clear();
            m_index.Clear();
            // the following functions throw serialization-related exceptions
            int n = reader.ReadInt();
            for (int i = 0; i < n; i++)
            {
                // same order as Save: the model first, then its tag set
                PosModel model = new PosModel(reader);
                m_models.Add(model);
                Set<string> tags = new Set<string>(reader);
                foreach (string tag in tags) { m_index.Add(tag, model); }
            }
        }
    }
}
