import argparse
import glob
import os.path

import gradio as gr

import pickle
import tqdm
import json

import MIDI
from midi_synthesizer import synthesis

from fuzzywuzzy import process

import copy
from collections import Counter
import random
import statistics

import matplotlib.pyplot as plt

#==========================================================================================================

# True when the SYSTEM environment variable equals "spaces" (not referenced elsewhere in this file).
in_space = os.environ.get("SYSTEM") == "spaces"

#==========================================================================================================

def _decode_tokens(tokens):
    """Decode a flat list of aux-data INT tokens into MIDI.py 'note' events.

    Token ranges handled (anything else is ignored):
      * 0..255       time shift: running time advances by token * 16 ticks.
      * 256..2303    sets duration ((t-256) // 8 * 16 ticks) and velocity
                     (((t-256) % 8 + 1) * 15, i.e. 15..120) for following notes.
      * 2304..18944  a note: patch, pitch = divmod(t - 2304, 129).
                     Patch 128 (presumably drums) is routed to channel 9; other
                     patches get a channel on first sight.

    Returns:
        (notes, patches): notes is a list of
        ['note', start, duration, channel, pitch, velocity, patch] events and
        patches[ch] is the program assigned to channel ch, or -1 if unused.
    """
    notes = []
    time = 0
    dur = 0
    vel = 90

    patches = [-1] * 16
    channels = [0] * 16
    channels[9] = 1  # keep channel 9 out of the free pool; patch 128 goes there

    for ss in tokens:
        if 0 <= ss < 256:
            time += ss * 16
        elif 256 <= ss < 2304:
            dur = ((ss - 256) // 8) * 16
            vel = (((ss - 256) % 8) + 1) * 15
        elif 2304 <= ss < 18945:
            patch, pitch = divmod(ss - 2304, 129)
            if patch < 128:
                if patch not in patches:
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        # All free channels used: channel 15 is reused and its
                        # patch overwritten, so the last patch assigned wins.
                        cha = 15
                    patches[cha] = patch
                channel = patches.index(patch)
            else:  # patch == 128
                channel = 9
            notes.append(['note', time, dur, channel, pitch, vel, patch])

    return notes, patches


def _plot_notes(notes):
    """Draw a time/pitch scatter of the notes on a fresh pyplot figure, coloured by channel."""
    colors = ['red', 'yellow', 'green', 'cyan',
              'blue', 'pink', 'orange', 'purple',
              'gray', 'white', 'gold', 'silver',
              'lightgreen', 'indigo', 'maroon', 'turquoise']

    note_events = [m for m in notes if m[0] == 'note']
    x = [s[1] for s in note_events]
    y = [s[4] for s in note_events]
    c = [colors[s[3]] for s in note_events]

    # Close the previous figure so repeated searches don't accumulate figures.
    plt.close()
    plt.figure(figsize=(14, 5))
    ax = plt.axes(title='MIDI Match Plot')
    ax.set_facecolor('black')

    plt.scatter(x, y, c=c)
    plt.xlabel("Time in MIDI ticks")
    plt.ylabel("MIDI Pitch")


def _build_score(notes, patches, output_signature):
    """Wrap decoded notes in a MIDI.py score with header and per-channel patch changes."""
    track_name = 'Project Los Angeles'
    text_encoding = 'ISO-8859-1'

    # Channels never assigned by the decoder (-1) fall back to program 0.
    channel_patches = [p if p != -1 else 0 for p in patches]

    # 1000 ticks per beat with 1,000,000 us per beat -> one tick is one millisecond.
    output_header = [1000,
                     [['set_tempo', 0, 1000000],
                      ['time_signature', 0, 4, 2, 24, 8],
                      # errors='replace': matched titles may contain non-Latin-1 chars.
                      ['track_name', 0, output_signature.encode(text_encoding, errors='replace')]]]

    patch_list = [['patch_change', 0, ch, p] for ch, p in enumerate(channel_patches)]
    patch_list.append(['track_name', 0, track_name.encode(text_encoding)])

    return output_header + [patch_list + notes]


def find_midi(title, artist):
    """Fuzzy-search AUX_DATA for a song and yield it as title, MIDI file, audio and plot.

    Reads the module globals AUX_DATA (a list whose entries hold a title at
    index 0 and a token list at index 1) and soundfont_path. Both are set in
    the __main__ block.

    Args:
        title: desired song title (may be empty).
        artist: desired artist (may be empty). When both are given, the query
            is "title --- artist".

    Yields:
        One tuple: (matched title, path of the written .mid file,
        (44100, synthesized audio), the pyplot module holding the plot).
    """
    print('=' * 70)
    print('Loading MIDI file...')

    #==================================================

    print('Searching titles...Please wait...')
    # Mutates the global list in place. Presumably this is so that ties or duplicate
    # titles resolve to a random entry, since the first best match is taken below.
    random.shuffle(AUX_DATA)

    titles_index = [entry[0] for entry in AUX_DATA]

    if title != '' and artist != '':
        search_string = title + ' --- ' + artist
    else:
        search_string = title + artist

    search_match = process.extract(query=search_string, choices=titles_index, limit=1)
    search_index = titles_index.index(search_match[0][0])
    matched_title = AUX_DATA[search_index][0]
    tokens = AUX_DATA[search_index][1]

    print('Done!')
    print('=' * 70)
    print('Selected title:', matched_title)
    print('=' * 70)

    print('Sample INTs', tokens[:12])
    print('=' * 70)

    # An empty token list now yields an empty song instead of an UnboundLocalError.
    song_f, patches = _decode_tokens(tokens)

    _plot_notes(song_f)

    output = _build_score(song_f, patches, matched_title)

    midi_path = "MIDI-Search-Sample.mid"
    with open(midi_path, 'wb') as f:
        f.write(MIDI.score2midi(output))
    audio = synthesis(MIDI.score2opus(output), soundfont_path)
    yield matched_title, midi_path, (44100, audio), plt

#==========================================================================================================

if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true", default=False, help="share gradio app")
    cli.add_argument("--port", type=int, default=7860, help="gradio server port")
    cli.add_argument("--max-gen", type=int, default=1024, help="max")
    args = cli.parse_args()

    # These two names are module globals read by find_midi(); keep them here.
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "Giant_Music_Transformer_Aux_Data.pickle"

    print('Loading meta-data...')
    # NOTE(review): pickle.load can execute code from the file; only ship a trusted pickle.
    with open(meta_data_path, 'rb') as meta_file:
        AUX_DATA = pickle.load(meta_file)
    print('Done!')

    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Search</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Search and explore 160k+ MIDI titles</h1>")

        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Search&style=flat)\n\n"
                    "Giant Music Transformer Aux Data Demo\n\n"
                    "Please see [Giant Music Transformer](https://github.com/asigalov61/Giant-Music-Transformer) for more information and features\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/Giant-Music-Transformer/blob/main/Giant_Music_Transformer_TTM.ipynb)"
                    " for all features"
                    )

        # Search inputs.
        title_box = gr.Textbox(label="Desired Song Title", value="Family Guy")
        artist_box = gr.Textbox(label="Desired Song Artist", value="TV Themes")
        search_button = gr.Button()

        gr.Markdown("# Search results")

        # Outputs, in the order find_midi() yields them (title, MIDI, audio, plot).
        found_title = gr.Textbox(label="Found MIDI search title")
        sample_audio = gr.Audio(label="Output MIDI search sample audio", format="mp3", elem_id="midi_audio")
        sample_plot = gr.Plot(label="Output MIDI search sample plot")
        sample_midi = gr.File(label="Output MIDI search sample MIDI", file_types=[".mid"])

        search_button.click(find_midi,
                            [title_box, artist_box],
                            [found_title, sample_midi, sample_audio, sample_plot])

    app.queue(1).launch(server_port=args.port, share=args.share, inbrowser=True)