using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using UnityEditor;
#if UNITY_2020_2_OR_NEWER
using UnityEditor.AssetImporters;
using UnityEditor.Experimental.AssetImporters;
#else
using UnityEditor.Experimental.AssetImporters;
#endif
using UnityEngine;
using System;

namespace Unity.Barracuda
{
/// <summary>
/// Asset Importer Editor of ONNX models
/// </summary>
[CustomEditor(typeof(ONNXModelImporter))]
public class ONNXModelImporterEditor : ScriptedImporterEditor
{
    /// <summary>
    /// Draws every visible serialized property of the importer (except the script
    /// reference) followed by the Apply/Revert controls.
    /// </summary>
    public override void OnInspectorGUI()
    {
        var onnxModelImporter = target as ONNXModelImporter;
        if (onnxModelImporter == null)
            return;

#if UNITY_2019_2_OR_NEWER
        // From 2019.2 on, importer editors are expected to refresh the serialized
        // object themselves before drawing any property.
        serializedObject.Update();
#endif

        SerializedProperty iterator = serializedObject.GetIterator();
        for (bool enterChildren = true; iterator.NextVisible(enterChildren); enterChildren = false)
        {
            // The script reference field is not a user-editable import setting.
            if (iterator.propertyPath != "m_Script")
                EditorGUILayout.PropertyField(iterator, true);
        }

#if UNITY_2019_2_OR_NEWER
        // Push the edits made through the property fields back to the importer
        // settings so that Apply/Revert can detect them.
        serializedObject.ApplyModifiedProperties();
#endif
        ApplyRevertGUI();
    }

}

/// <summary>
/// Asset Importer Editor of NNModel (the serialized file generated by ONNXModelImporter)
/// </summary>
[CustomEditor(typeof(NNModel))]
public class NNModelEditor : Editor
{
    // Model loaded in OnEnable; stays null when the asset has no model data or loading fails.
    private Model m_Model;

    // Each section is stored as two parallel lists: a short name and a description per row.
    private List<string> m_Inputs = new List<string>();
    private List<string> m_InputsDesc = new List<string>();
    private List<string> m_Outputs = new List<string>();
    private List<string> m_OutputsDesc = new List<string>();
    private List<string> m_Memories = new List<string>();
    private List<string> m_MemoriesDesc = new List<string>();
    private List<string> m_Layers = new List<string>();
    private List<string> m_LayersDesc = new List<string>();
    private List<string> m_Constants = new List<string>();
    private List<string> m_ConstantsDesc = new List<string>();
    private List<string> m_Warnings = new List<string>();
    private List<string> m_WarningsDesc = new List<string>();

    // Weight statistics computed once in OnEnable.
    private long m_NumEmbeddedWeights;
    private long m_NumConstantWeights;
    private long m_TotalWeightsSizeInBytes;

    // Scroll state of each section's scroll view.
    private Vector2 m_WarningsScrollPosition = Vector2.zero;
    private Vector2 m_InputsScrollPosition = Vector2.zero;
    private Vector2 m_OutputsScrollPosition = Vector2.zero;
    private Vector2 m_MemoriesScrollPosition = Vector2.zero;
    private Vector2 m_LayerScrollPosition = Vector2.zero;
    private Vector2 m_ConstantScrollPosition = Vector2.zero;

    // Vertical gap inserted before each section title.
    private const float k_Space = 5f;

    // Cached icon, resolved lazily by LoadIconTexture.
    private Texture2D m_IconTexture;

    /// <summary>
    /// Returns the icon texture, searching the asset database for
    /// <c>ONNXModelImporter.iconName</c> on first use and caching the result.
    /// </summary>
    /// <returns>The icon texture, or null when no matching texture asset could be loaded.</returns>
    private Texture2D LoadIconTexture()
    {
        if (m_IconTexture != null)
            return m_IconTexture;

        string[] allCandidates = AssetDatabase.FindAssets(ONNXModelImporter.iconName);
        if (allCandidates.Length > 0)
            m_IconTexture = AssetDatabase.LoadAssetAtPath(AssetDatabase.GUIDToAssetPath(allCandidates[0]), typeof(Texture2D)) as Texture2D;

        return m_IconTexture;
    }

    /// <summary>
    /// Produces the static preview for the asset by copying the icon texture.
    /// </summary>
    /// <param name="assetPath">Path of the asset being previewed (unused).</param>
    /// <param name="subAssets">Sub-assets of the asset being previewed (unused).</param>
    /// <param name="width">Width of the texture to create.</param>
    /// <param name="height">Height of the texture to create.</param>
    /// <returns>A copy of the icon texture, or null when the icon could not be found.</returns>
    public override Texture2D RenderStaticPreview(string assetPath, UnityEngine.Object[] subAssets, int width, int height)
    {
        // Without an icon there is nothing to copy from; bail out before allocating
        // a texture and before handing a null source to CopySerialized.
        Texture2D icon = LoadIconTexture();
        if (icon == null)
            return null;

        Texture2D tex = new Texture2D(width, height);
        EditorUtility.CopySerialized(icon, tex);
        return tex;
    }

    /// <summary>
    /// Loads the model (weights skipped) and caches the strings and statistics
    /// displayed by <see cref="OnInspectorGUI"/>.
    /// </summary>
    void OnEnable()
    {
        var nnModel = target as NNModel;
        if (nnModel == null)
            return;
        if (nnModel.modelData == null)
            return;

        m_Model = ModelLoader.Load(nnModel, verbose:false, skipWeights:true);
        if (m_Model == null)
            return;

        m_Inputs = m_Model.inputs.Select(i => i.name).ToList();
        m_InputsDesc = m_Model.inputs.Select(i => $"shape: ({String.Join(",", i.shape)})").ToList();
        m_Outputs = m_Model.outputs.ToList();

        // Output shapes are only evaluated when no input shape contains a -1 or 0 entry.
        bool allKnownShapes = true;
        var inputShapes = new Dictionary<string, TensorShape>();
        foreach (var i in m_Model.inputs)
        {
            if (i.shape.Contains(-1) || i.shape.Contains(0))
            {
                allKnownShapes = false;
                break;
            }
            inputShapes.Add(i.name, new TensorShape(i.shape));
        }

        if (allKnownShapes)
        {
            m_OutputsDesc = m_Model.outputs.Select(i => {
                // Placeholder kept when the shape cannot be determined.
                string output = "(-1,-1,-1,-1)";
                try
                {
                    TensorShape shape;
                    if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, i, out shape))
                        output = shape.ToString();
                }
                catch (Exception e)
                {
                    Debug.LogError($"Unexpected error while evaluating model output {i}. {e}");
                }
                return $"shape: {output}"; }).ToList();
        }
        else
        {
            m_OutputsDesc = m_Model.outputs.Select(i => "shape: (-1,-1,-1,-1)").ToList();
        }

        m_Memories = m_Model.memories.Select(i => i.input).ToList();
        m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

        // Layers of type Load are listed separately as "constants".
        // Materialized once because each sequence is traversed several times below.
        var layers = m_Model.layers.Where(i => i.type != Layer.Type.Load).ToList();
        var constants = m_Model.layers.Where(i => i.type == Layer.Type.Load).ToList();

        m_Layers        = layers.Select(i => i.type.ToString()).ToList();
        m_LayersDesc    = layers.Select(i => i.ToString()).ToList();
        m_Constants     = constants.Select(i => i.type.ToString()).ToList();
        m_ConstantsDesc = constants.Select(i => i.ToString()).ToList();

        m_NumEmbeddedWeights = layers.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));
        m_NumConstantWeights = constants.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));

        // weights are not loaded for UI, recompute size
        // NOTE(review): this sums dataset `length`, the same quantity used for the weight
        // counts above, yet it is displayed as bytes - confirm the intended unit.
        m_TotalWeightsSizeInBytes = 0;
        for (var l = 0; l < m_Model.layers.Count; ++l)
            for (var d = 0; d < m_Model.layers[l].datasets.Length; ++d)
                m_TotalWeightsSizeInBytes += m_Model.layers[l].datasets[d].length;

        m_Warnings = m_Model.Warnings.Select(i => i.LayerName).ToList();
        m_WarningsDesc = m_Model.Warnings.Select(i => i.Message).ToList();
    }

    /// <summary>
    /// Draws the model summary: source information, warnings (if any), inputs, outputs,
    /// memories, layers, constants and the total weight size. Draws nothing when no
    /// model was loaded.
    /// </summary>
    public override void OnInspectorGUI()
    {
        if (m_Model == null)
            return;

        GUI.enabled = true;
        GUILayout.Label($"Source: {m_Model.IrSource}");
        GUILayout.Label($"Version: {m_Model.IrVersion}");
        GUILayout.Label($"Producer Name: {m_Model.ProducerName}");

        if (m_Warnings.Count > 0)
        {
            ListUIHelper($"Warnings {m_Warnings.Count}", m_Warnings, m_WarningsDesc, ref m_WarningsScrollPosition);
            EditorGUILayout.HelpBox("Model contains warnings. Behavior might be incorrect", MessageType.Warning, true);
        }
        var constantWeightInfo = m_Constants.Count > 0 ? $" using {m_NumConstantWeights:n0} weights" : "";
        ListUIHelper($"Inputs ({m_Inputs.Count})", m_Inputs, m_InputsDesc, ref m_InputsScrollPosition);
        ListUIHelper($"Outputs ({m_Outputs.Count})", m_Outputs, m_OutputsDesc, ref m_OutputsScrollPosition);
        ListUIHelper($"Memories ({m_Memories.Count})", m_Memories, m_MemoriesDesc, ref m_MemoriesScrollPosition);
        // The layer list is allowed to grow taller when there is no constants section.
        ListUIHelper($"Layers ({m_Layers.Count} using {m_NumEmbeddedWeights:n0} embedded weights)", m_Layers, m_LayersDesc, ref m_LayerScrollPosition, m_Constants.Count == 0 ? 1.5f: 1f);
        ListUIHelper($"Constants ({m_Constants.Count}{constantWeightInfo})", m_Constants, m_ConstantsDesc, ref m_ConstantScrollPosition);

        GUILayout.Label($"Total weight size: {m_TotalWeightsSizeInBytes:n0} bytes");
    }

    /// <summary>
    /// Builds the clipboard text for a whole section: the title on the first line,
    /// then one "name description" line per row.
    /// </summary>
    private static string BuildSectionText(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions)
    {
        var text = new StringBuilder();
        text.AppendLine(sectionTitle);
        for (int i = 0; i < names.Count; ++i)
            text.Append(names[i]).Append(' ').AppendLine(descriptions[i]);
        return text.ToString();
    }

    /// <summary>
    /// Draws one titled section as a scrollable list of "name description" rows with a
    /// context menu to copy a single row or the whole section.
    /// </summary>
    /// <param name="sectionTitle">Bold title drawn above the list.</param>
    /// <param name="names">Row names, drawn in bold at the start of each row.</param>
    /// <param name="descriptions">Row descriptions; must have at least as many entries as <paramref name="names"/>.</param>
    /// <param name="scrollPosition">Scroll state of the section, updated in place.</param>
    /// <param name="maxHeightMultiplier">Scales the 150-unit cap applied to the list's minimum height.</param>
    private static void ListUIHelper(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions, ref Vector2 scrollPosition, float maxHeightMultiplier = 1f)
    {
        int n = names.Count;
        Debug.Assert(descriptions.Count == n);
        if (descriptions.Count < n)
            return;

        GUILayout.Space(k_Space);
        GUILayout.Label(sectionTitle, EditorStyles.boldLabel);
        // An empty section shows only its title.
        if (n == 0)
            return;

        float height = Mathf.Min(n * 20f + 2f, 150f * maxHeightMultiplier);
        scrollPosition = GUILayout.BeginScrollView(scrollPosition, GUI.skin.box, GUILayout.MinHeight(height));
        Event e = Event.current;
        const float lineHeight = 16.0f;

        // Row background style is only needed while repainting; look it up once
        // instead of once per row.
        GUIStyle evenRowStyle = null;
        if (e.type == EventType.Repaint)
            evenRowStyle = "CN EntryBackEven";

        for (int i = 0; i < n; ++i)
        {
            Rect r = EditorGUILayout.GetControlRect(false, lineHeight);

            // Declared inside the loop so each delegate below captures its own row.
            string name = names[i];
            string description = descriptions[i];

            // Context menu, "Copy"
            if (e.type == EventType.ContextClick && r.Contains(e.mousePosition))
            {
                e.Use();
                var menu = new GenericMenu();
                menu.AddItem(new GUIContent("Copy current line"), false, delegate {
                    EditorGUIUtility.systemCopyBuffer = $"{name} {description}";
                });
                // The section text is only assembled if the item is actually chosen.
                menu.AddItem(new GUIContent("Copy section"), false, delegate {
                    EditorGUIUtility.systemCopyBuffer = BuildSectionText(sectionTitle, names, descriptions);
                });
                menu.ShowAsContext();
            }

            // Color even line for readability
            if (evenRowStyle != null && (i & 1) == 0)
                evenRowStyle.Draw(r, false, false, false, false);

            // name at the start (left side) of the row
            Rect locRect = r;
            locRect.xMax = locRect.xMin;
            GUIContent gc = new GUIContent(name);

            // calculate size so we can left-align it
            Vector2 size = EditorStyles.miniBoldLabel.CalcSize(gc);
            locRect.xMax += size.x;
            GUI.Label(locRect, gc, EditorStyles.miniBoldLabel);
            locRect.xMax += 2;

            // message, filling the remainder of the row
            Rect msgRect = r;
            msgRect.xMin = locRect.xMax;
            GUI.Label(msgRect, new GUIContent(description), EditorStyles.miniLabel);
        }
        GUILayout.EndScrollView();
    }
}

}
