import argparse
import glob
import os.path

import gradio as gr

import pickle
import tqdm
import json

import MIDI
import TMIDIX
from midi_synthesizer import synthesis

from fuzzywuzzy import process

import copy
from collections import Counter
import random
import statistics

import matplotlib.pyplot as plt

#==========================================================================================================

in_space = os.getenv("SYSTEM") == "spaces"  # True when running inside a Hugging Face Space

#==========================================================================================================

def find_midi(title, artist):
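    """Fuzzy-search AUX_DATA for the closest song title/artist match, decode the
    matched token sequence into a MIDI file, and yield the match metadata, the
    MIDI file, rendered audio and a piano-roll style plot to the Gradio outputs.

    AUX_DATA is assumed to hold [title_string, token_sequence] pairs loaded
    from the Giant Music Transformer aux-data pickle.
    """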
  
    print('=' * 70)
    print('Loading MIDI file...')
    
    #==================================================
    
    print('Searching titles...Please wait...')
    random.shuffle(AUX_DATA)
    
    titles_index = []
    
    for A in AUX_DATA:
      titles_index.append(A[0])
    
    search_string = ''
    
    if title != '' and artist != '':
      search_string = title + ' --- ' + artist
    
    else:
      search_string = title + artist
    
    search_match = process.extract(query=search_string, choices=titles_index, limit=1)
    search_index = titles_index.index(search_match[0][0])
    
    print('Done!')
    print('=' * 70)
    print('Selected title:', AUX_DATA[search_index][0])
    print('=' * 70)

    outy = AUX_DATA[search_index][1]

    print('Sample INTs', outy[:12])
    print('=' * 70)
    
    song = outy
    song_f = []

    if len(outy) != 0:
    
      time = 0
      dur = 0
      vel = 90
      pitch = 0
      channel = 0
    
      patches = [-1] * 16
    
      channels = [0] * 16
      channels[9] = 1
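
      # Each token in the sequence encodes one of three event types:
      #   0..255      -> delta start-time increment (multiplied by 16, assumed to be milliseconds)
      #   256..2303   -> packed note duration and velocity
      #   2304..18944 -> packed MIDI patch (0..128, where 128 = drums) and pitch
      # Patches are assigned to free MIDI channels on first use (channel 9 is
      # reserved for drums); once all 16 channels are in use, channel 15 is reused.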
    
      for ss in song:
    
          if 0 <= ss < 256:
    
              time += ss * 16
    
          if 256 <= ss < 2304:
    
              dur = ((ss-256) // 8) * 16
              vel = (((ss-256) % 8)+1) * 15
    
          if 2304 <= ss < 18945:
    
              patch = (ss-2304) // 129
    
              if patch < 128:
    
                  if patch not in patches:
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        cha = 15
    
                    patches[cha] = patch
                    channel = patches.index(patch)
                  else:
                    channel = patches.index(patch)
    
              if patch == 128:
                  channel = 9
    
              pitch = (ss-2304) % 129
    
              song_f.append(['note', time, dur, channel, pitch, vel, patch ])
    
    # Replace unused patch slots (-1) with patch 0 (Acoustic Grand Piano)
    patches = [0 if x==-1 else x for x in patches]
    
    # Write the decoded song to a standard MIDI file ('MIDI-Match-Sample.mid')
    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                            output_signature = 'Giant Music Transformer',
                                                            output_file_name = 'MIDI-Match-Sample',
                                                            track_name='Project Los Angeles',
                                                            list_of_MIDI_patches=patches
                                                            )

    
    x = []
    y = []
    c = []
    
    colors = ['red', 'yellow', 'green', 'cyan',
            'blue', 'pink', 'orange', 'purple',
            'gray', 'white', 'gold', 'silver',
            'lightgreen', 'indigo', 'maroon', 'turquoise']
    
    # Scatter-plot each decoded note: x = start time, y = pitch, color = channel
    for s in [m for m in song_f if m[0] == 'note']:
        x.append(s[1])
        y.append(s[4])
        c.append(colors[s[3]])

    plt.close()
    plt.figure(figsize=(14,5))
    ax=plt.axes(title='MIDI Match Plot')
    ax.set_facecolor('black')
    
    plt.scatter(x,y, c=c)
    plt.xlabel("Time in MIDI ticks")
    plt.ylabel("MIDI Pitch")
 
    with open(f"MIDI-Match-Sample.mid", 'wb') as f:
        f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))
    audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)
    yield txt_mdata, "MIDI-Match-Sample.mid", (44100, audio), plt

#==========================================================================================================
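# Example local usage (the script filename is an assumption):
#   python midi_search_app.py
#   python midi_search_app.py --port 7860 --share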

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    
    opt = parser.parse_args()
    
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "Giant_Music_Transformer_Aux_Data.pickle"
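    # Both the soundfont and the aux-data pickle are expected to be present
    # in the working directory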

    print('Loading meta-data...')
    with open(meta_data_path, 'rb') as f:
        AUX_DATA = pickle.load(f)
    print('Done!')
    
    app = gr.Blocks()
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Search</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Upload any MIDI file to find its closest match</h1>")
        
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Match&style=flat)\n\n"
                    "Los Angeles MIDI Dataset Search and Explore Demo\n\n"
                    "Please see [Los Angeles MIDI Dataset](https://github.com/asigalov61/Los-Angeles-MIDI-Dataset) for more information and features\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/Los-Angeles-MIDI-Dataset/blob/main/Los_Angeles_MIDI_Dataset_Search_and_Explore.ipynb)"
                    " for faster execution"
                    )

        gr.Markdown("# Upload MIDI")

        artist = gr.Textbox(label="Artist")
        title = gr.Textbox(label="Song title")

        run_btn = gr.Button("Search", variant="primary")

        gr.Markdown("# Match results")
        
        output_audio = gr.Audio(label="Output MIDI match sample audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI match sample plot")
        output_midi = gr.File(label="Output MIDI match sample MIDI", file_types=[".mid"])
        output_midi_seq = gr.Textbox(label="Output MIDI match metadata")
        
        run_event = run_btn.click(find_midi, [title, artist],
                                  [output_midi_seq, output_midi, output_audio, output_plot])
        
    app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)