import gradio as gr
import spacy
from spacy import displacy
from spacy.tokens import Span
import pandas as pd
import base64
import random


# Base name of the default pipeline (no size suffix); the handlers append
# "_sm" or "_md" before calling spacy.load.
DEFAULT_MODEL = "en_core_web"
# Sample sentence shown in the input box on start-up.
DEFAULT_TEXT = "Apple is looking at buying U.K. startup for $1 billion."
# Sample sentence per language code. get_text() looks entries up by the part
# of the pipeline name before the first "_" (e.g. "de_core_news" -> "de").
texts = {"en": DEFAULT_TEXT, "ca": "Apple està buscant comprar una startup del Regne Unit per mil milions de dòlars", "da": "Apple overvejer at købe et britisk startup for 1 milliard dollar.", "de": "Die ganze Stadt ist ein Startup: Shenzhen ist das Silicon Valley für Hardware-Firmen",
         "el": "Η άνιση κατανομή του πλούτου και του εισοδήματος, η οποία έχει λάβει τρομερές διαστάσεις, δεν δείχνει τάσεις βελτίωσης.", "es": "Apple está buscando comprar una startup del Reino Unido por mil millones de dólares.", "fi": "Itseajavat autot siirtävät vakuutusvastuun autojen valmistajille", "fr": "Apple cherche à acheter une start-up anglaise pour 1 milliard de dollars", "it": "Apple vuole comprare una startup del Regno Unito per un miliardo di dollari",
         "ja": "アップルがイギリスの新興企業を10億ドルで購入を検討", "ko": "애플이 영국의 스타트업을 10억 달러에 인수하는 것을 알아보고 있다.", "lt": "Jaunikis pirmąją vestuvinę naktį iškeitė į areštinės gultą", "nb": "Apple vurderer å kjøpe britisk oppstartfirma for en milliard dollar.", "nl": "Apple overweegt om voor 1 miljard een U.K. startup te kopen",
         "pl": "Poczuł przyjemną woń mocnej kawy.", "pt": "Apple está querendo comprar uma startup do Reino Unido por 100 milhões de dólares", "ro": "Apple plănuiește să cumpere o companie britanică pentru un miliard de dolari", "ru": "Apple рассматривает возможность покупки стартапа из Соединённого Королевства за $1 млрд", "sv": "Apple överväger att köpa brittisk startup för 1 miljard dollar.", "zh": "作为语言而言,为世界使用人数最多的语言,目前世界有五分之一人口做为母语。"}
# Inline CSS applied to the "Download as SVG" link built by download_svg().
button_css = "float: right; --tw-border-opacity: 1; border-color: rgb(229 231 235 / var(--tw-border-opacity)); --tw-gradient-from: rgb(243 244 246 / 0.7); --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to, rgb(243 244 246 / 0)); --tw-gradient-to: rgb(229 231 235 / 0.8); --tw-text-opacity: 1; color: rgb(55 65 81 / var(--tw-text-opacity));    border-width: 1px; --tw-bg-opacity: 1; background-color: rgb(255 255 255 / var(--tw-bg-opacity)); background-image: linear-gradient(to bottom right, var(--tw-gradient-stops)); display: inline-flex; flex: 1 1 0%; align-items: center; justify-content: center;    --tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); --tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color); box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); -webkit-appearance: button; border-radius: 0.5rem; padding-top: 0.5rem; padding-bottom: 0.5rem; padding-left: 1rem; padding-right: 1rem; font-size: 1rem; line-height: 1.5rem; font-weight: 600;"
# Default background for the dependency visualizer (passed as displaCy's "bg" option).
DEFAULT_COLOR = "linear-gradient(90deg, #FFCA74, #7AECEC)"
# Entity labels offered (and pre-selected) in the entity filter.
# NOTE(review): presumably the label set of the English pipelines; other
# languages' pipelines may emit different labels — confirm.
DEFAULT_ENTS = ['CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'MONEY',
                'NORP', 'ORDINAL', 'ORG', 'PERCENT', 'PERSON', 'PRODUCT', 'QUANTITY', 'TIME', 'WORK_OF_ART']
# Token attribute names read with getattr() to build the token table.
DEFAULT_TOK_ATTR = ['idx', 'text', 'pos_', 'lemma_', 'shape_', 'dep_']
# Column headers of the noun-chunk table (one row per chunk).
NOUN_ATTR = ['text', 'root.text', 'root.dep_', 'root.head.text']




# get the huggingface models specified in the requirements.txt file
def get_all_models(path="requirements.txt"):
    """Return the base names of the spaCy pipelines listed in a requirements file.

    Lines mentioning ``huggingface.co`` are treated as wheel URLs shaped like
    ``https://huggingface.co/spacy/<package>/...``. The fifth ``/``-separated
    segment is the package (e.g. ``en_core_web_sm``). Its first three
    ``_``-separated parts form the base name (e.g. ``en_core_web``). The size
    suffix is dropped because the handlers append ``_sm``/``_md`` themselves.

    Args:
        path: Requirements file to scan. Defaults to ``requirements.txt`` in
            the current working directory.

    Returns:
        Unique base names, in the order they first appear.
    """
    models = []
    seen = set()  # O(1) duplicate check while keeping first-seen order in `models`
    with open(path, encoding="utf-8") as f:
        for line in f:
            if "huggingface.co" not in line:
                continue
            parts = line.split("/")
            # A line that mentions the host but is not a full package URL is
            # skipped instead of aborting app start-up with an IndexError.
            if len(parts) < 5:
                continue
            model = "_".join(parts[4].split("_")[:3])
            if model not in seen:
                seen.add(model)
                models.append(model)
    return models

# Pipeline base names (e.g. "en_core_web") offered in the model dropdown.
# Read once at import time; start-up fails if requirements.txt cannot be opened.
models = get_all_models()

# when clicked, download as SVG. Rendered as HTML on the page
def download_svg(svg, style=None):
    """Build an HTML link that downloads *svg* as ``displacy.svg``.

    The SVG is embedded in the link as a base64 ``data:`` URI, so no file is
    written on the server.

    Args:
        svg: SVG markup as a string (e.g. the output of ``displacy.render``).
        style: Inline CSS for the link. ``None`` uses the module's ``button_css``.

    Returns:
        An ``<a download="displacy.svg" ...>`` element as an HTML string.
    """
    if style is None:
        style = button_css
    # Decode the base64 bytes explicitly rather than slicing the "b'...'" repr.
    encoded = base64.b64encode(svg.encode("utf-8")).decode("ascii")
    img = "data:image/svg+xml;base64," + encoded
    return f'<a download="displacy.svg" href="{img}" style="{style}">Download as SVG</a>'

# create dependency graph, inputs are text, collapse punctuation, 
# collapse phrases, compact, background color, font color, and model
def dependency(text, col_punct, col_phrase, compact, bg, font, model):
    """Parse *text* and render its dependency tree with displaCy.

    Args:
        text: Text to parse.
        col_punct: Forwarded as displaCy's ``collapse_punct`` option.
        col_phrase: Forwarded as displaCy's ``collapse_phrases`` option.
        compact: Forwarded as displaCy's ``compact`` option.
        bg: Background color/CSS value (``bg`` option).
        font: Text color (``color`` option).
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (rendered SVG markup, download-link HTML, loaded pipeline name).
    """
    pipeline_name = model + "_sm"
    parsed = spacy.load(pipeline_name)(text)
    render_opts = {
        "compact": compact,
        "collapse_phrases": col_phrase,
        "collapse_punct": col_punct,
        "bg": bg,
        "color": font,
    }
    markup = displacy.render(parsed, style="dep", options=render_opts)
    return markup, download_svg(markup), pipeline_name

# returns the NER displacy, inputs are text, checked ents, and model
def entity(text, ents, model):
    """Render the named entities of *text* with displaCy.

    Args:
        text: Text to analyse.
        ents: Entity labels to show; passed as displaCy's ``ents`` option.
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (rendered HTML markup, loaded pipeline name).
    """
    pipeline_name = model + "_sm"
    pipeline = spacy.load(pipeline_name)
    markup = displacy.render(pipeline(text), style="ent", options={"ents": ents})
    return markup, pipeline_name

# returns token attributes for the user inputs
def token(text, attributes, model):
    """Tabulate the requested attributes of every token in *text*.

    Args:
        text: Text to tokenize and annotate.
        attributes: Token attribute names, read with ``getattr``; they also
            become the DataFrame's column names.
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (DataFrame with one row per token, loaded pipeline name).
    """
    pipeline_name = model + "_sm"
    pipeline = spacy.load(pipeline_name)
    doc = pipeline(text)
    rows = [[getattr(tok, attr_name) for attr_name in attributes] for tok in doc]
    return pd.DataFrame(rows, columns=attributes), pipeline_name

# returns token attributes for the initial (default) state of the table
# the return value is not a pandas DataFrame
def default_token(text, attributes, model):
    """Return the requested token attributes of *text* as plain nested lists.

    Same rows as ``token`` but without wrapping them in a DataFrame. It is used
    to seed the token table when the UI is built.

    Args:
        text: Text to tokenize and annotate.
        attributes: Token attribute names, read with ``getattr``.
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (list of per-token rows, loaded pipeline name).
    """
    pipeline_name = model + "_sm"
    pipeline = spacy.load(pipeline_name)
    doc = pipeline(text)
    rows = []
    for tok in doc:
        rows.append([getattr(tok, attr_name) for attr_name in attributes])
    return rows, pipeline_name

# Get the similarity of two randomly chosen words or phrases from the text
def random_vectors(text, model):
    """Pick two random words/phrases from *text* and score their similarity.

    The candidates are the doc's noun chunks, when the language provides them,
    plus every token that is neither a stop word nor tagged ``PUNCT``/``PROPN``.

    Args:
        text: Text to draw candidates from.
        model: Pipeline base name; its ``_md`` package is loaded.

    Returns:
        Tuple of (similarity rounded to 2 decimals, first text, second text,
        loaded pipeline name).

    Raises:
        ValueError: If *text* yields no candidate words at all.
    """
    model_name = model + "_md"
    nlp = spacy.load(model_name)
    doc = nlp(text)
    try:
        n_chunks = list(doc.noun_chunks)
    except NotImplementedError:
        # Some languages ship no noun-chunk iterator; fall back to single words.
        n_chunks = []
    words = [tok for tok in doc
             if not tok.is_stop and tok.pos_ not in ("PUNCT", "PROPN")]
    candidates = n_chunks + words
    if not candidates:
        raise ValueError("No candidate words found in the input text to compare.")
    if len(candidates) >= 2:
        # Draw without replacement so an item is not compared with itself.
        first, second = random.sample(candidates, 2)
    else:
        first = second = candidates[0]
    return round(first.similarity(second), 2), first.text, second.text, model_name

# Get similarity of two inputted vectors
def vectors(input1, input2, model):
    """Score the vector similarity of two user-supplied texts.

    Args:
        input1: First word or phrase.
        input2: Second word or phrase.
        model: Pipeline base name; its ``_md`` package is loaded.

    Returns:
        Tuple of (similarity rounded to 2 decimals, loaded pipeline name).
    """
    pipeline_name = model + "_md"
    pipeline = spacy.load(pipeline_name)
    first_doc = pipeline(input1)
    second_doc = pipeline(input2)
    score = first_doc.similarity(second_doc)
    return round(score, 2), pipeline_name

# display spans, inputs are text, spans, labels, and model
def span(text, span1, span2, label1, label2, model):
    """Render two labelled spans over *text* with displaCy's span visualizer.

    Each span is located by searching for its text (runs of whitespace
    collapsed to single spaces) in the doc. A token-aligned occurrence is
    preferred; otherwise the first occurrence is expanded to whole tokens.
    A span whose text is empty or not found is left out instead of raising.
    When *span1* is empty, two placeholder spans are shown instead: the first
    half of the doc (*label1*) and its first token (*label2*).

    Args:
        text: Text to analyse.
        span1, span2: Text of the spans to highlight.
        label1, label2: Labels for the respective spans.
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (rendered HTML markup, loaded pipeline name).
    """
    model_name = model + "_sm"
    nlp = spacy.load(model_name)
    doc = nlp(text)
    if span1:
        found = [_locate_span(doc, span1, label1), _locate_span(doc, span2, label2)]
        doc.spans["sc"] = [s for s in found if s is not None]
    else:
        n_tokens = len(doc)
        doc.spans["sc"] = [
            Span(doc, 0, round(n_tokens / 2), label1),
            # Clamp so an empty doc does not produce an out-of-range span.
            Span(doc, 0, min(1, n_tokens), label2),
        ]

    svg = displacy.render(doc, style="span")
    return svg, model_name


def _locate_span(doc, query, label):
    """Return a Span labelled *label* covering *query* in *doc*, or None if absent."""
    needle = " ".join(query.split()) if query else ""
    if not needle:
        return None
    haystack = doc.text
    first = haystack.find(needle)
    start = first
    # Prefer an occurrence whose edges fall exactly on token boundaries
    # (char_span's default "strict" mode returns None otherwise).
    while start >= 0:
        match = doc.char_span(start, start + len(needle), label=label)
        if match is not None:
            return match
        start = haystack.find(needle, start + 1)
    if first < 0:
        return None
    # Fallback: widen a partial-token match (e.g. "U.K" for "U.K.") to whole tokens.
    return doc.char_span(first, first + len(needle), label=label,
                         alignment_mode="expand")

# returns noun chunks in text
def noun_chunks(text, model):
    """Tabulate the noun chunks of *text*.

    Args:
        text: Text to analyse.
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (DataFrame with columns ``NOUN_ATTR`` and one row per chunk,
        loaded pipeline name). The table is empty when the pipeline's language
        has no noun-chunk iterator.
    """
    model_name = model + "_sm"
    nlp = spacy.load(model_name)
    doc = nlp(text)
    try:
        chunks = list(doc.noun_chunks)
    except NotImplementedError:
        # The language provides no noun-chunk iterator; show an empty table
        # instead of failing the whole update.
        chunks = []
    data = [[chunk.text, chunk.root.text, chunk.root.dep_, chunk.root.head.text]
            for chunk in chunks]
    return pd.DataFrame(data, columns=NOUN_ATTR), model_name

# returns noun chunks for the initial (default) state of the table
# the return value is not a pandas DataFrame
def default_noun_chunks(text, model):
    """Return the noun chunks of *text* as plain nested lists.

    Same rows as ``noun_chunks`` but not wrapped in a DataFrame. It is used to
    seed the noun-chunk table when the UI is built.

    Args:
        text: Text to analyse.
        model: Pipeline base name; its ``_sm`` package is loaded.

    Returns:
        Tuple of (list of [text, root.text, root.dep_, root.head.text] rows,
        loaded pipeline name). The list is empty when the pipeline's language
        has no noun-chunk iterator.
    """
    model_name = model + "_sm"
    nlp = spacy.load(model_name)
    doc = nlp(text)
    try:
        chunks = list(doc.noun_chunks)
    except NotImplementedError:
        # The language provides no noun-chunk iterator; return no rows.
        chunks = []
    data = [[chunk.text, chunk.root.text, chunk.root.dep_, chunk.root.head.text]
            for chunk in chunks]
    return data, model_name

# get default text based on language model
def get_text(model):
    """Return the sample sentence for the language of pipeline *model*.

    The language code is the part of the name before the first underscore
    (``"de_core_news"`` -> ``"de"``). Codes missing from ``texts`` fall back to
    ``DEFAULT_TEXT``. Previously an unknown code raised ``KeyError``, and an
    empty ``models`` list left the result unbound.
    """
    lang = model.split("_")[0]
    return texts.get(lang, DEFAULT_TEXT)

# Build the UI. NOTE(review): css= receives a file name here; presumably the
# pinned Gradio version loads the file's contents when that path exists — confirm.
demo = gr.Blocks(css="scrollbar.css")

with demo:
    # Header: title and pipeline illustration.
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown("# Pipeline Visualizer")
                gr.Markdown(
                    "### Visualize parts of the spaCy pipeline in an interactive Gradio demo")
            with gr.Column():
                gr.Image("pipeline.svg")
    # Shared inputs: the pipeline base name (no size suffix) and the text that
    # every panel below analyses.
    with gr.Box():
        with gr.Column():
            gr.Markdown(" ## Choose a language model and the inputted text")
            with gr.Row():
                with gr.Column(scale=0.25):
                    model_input = gr.Dropdown(
                        choices=models, value=DEFAULT_MODEL, interactive=True, label="Pretrained Pipelines")
            with gr.Row():
                with gr.Column(scale=0.5):
                    text_input = gr.Textbox(
                        value=DEFAULT_TEXT, interactive=True, label="Input Text")
            with gr.Row():
                with gr.Column(scale=0.25):
                    button = gr.Button("Update", variant="primary").style(full_width=False)
    # Dependency parser panel.
    with gr.Box():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=0.75):
                    gr.Markdown(
                        "## [🔗 Dependency Parser](https://spacy.io/usage/visualizers#dep)")
                    gr.Markdown(
                        "The dependency visualizer shows part-of-speech tags and syntactic dependencies")
                with gr.Column(scale=0.25):
                    # Shows the concrete pipeline name; overwritten by the dependency handler.
                    dep_model = gr.Textbox(
                        label="Model", value="en_core_web_sm")
            with gr.Row():
                with gr.Column():
                    col_punct = gr.Checkbox(
                        label="Collapse Punctuation", value=True)
                    col_phrase = gr.Checkbox(
                        label="Collapse Phrases", value=True)
                    compact = gr.Checkbox(label="Compact", value=False)
                with gr.Column():
                    bg = gr.Textbox(
                        label="Background Color", value=DEFAULT_COLOR)
                with gr.Column():
                    # NOTE: `text` is the font-color input, not the text being analysed.
                    text = gr.Textbox(
                        label="Text Color", value="black")
            with gr.Row():
                # Initial SVG is rendered at build time with the defaults, so
                # start-up already loads the default "_sm" pipeline.
                dep_output = gr.HTML(value=dependency(
                    DEFAULT_TEXT, True, True, False, DEFAULT_COLOR, "black", DEFAULT_MODEL)[0])
            with gr.Row():
                with gr.Column(scale=0.25):
                    dep_button = gr.Button(
                        "Update Dependency Parser", variant="primary").style(full_width=False)
                with gr.Column():
                    dep_download_button = gr.HTML(
                        value=download_svg(dep_output.value))
    # Entity recognizer panel.
    with gr.Box():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=0.75):
                    gr.Markdown(
                        "## [🔗 Entity Recognizer](https://spacy.io/usage/visualizers#ent)")
                    gr.Markdown(
                        "The entity visualizer highlights named entities and their labels in a text")
                with gr.Column(scale=0.25):
                        ent_model = gr.Textbox(
                            label="Model", value="en_core_web_sm")
            ent_input = gr.CheckboxGroup(
                DEFAULT_ENTS, value=DEFAULT_ENTS, label="Entity Types")
            ent_output = gr.HTML(value=entity(
                DEFAULT_TEXT, DEFAULT_ENTS, DEFAULT_MODEL)[0])
            with gr.Row():
                with gr.Column(scale=0.25):
                    ent_button = gr.Button(
                        "Update Entity Recognizer", variant="primary")
    # Token properties panel.
    with gr.Box():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=0.75):
                    gr.Markdown(
                        "## [🔗 Token Properties](https://spacy.io/usage/linguistic-features)")
                    gr.Markdown(
                        "When you put in raw text to spaCy, it returns a Doc object with different linguistic features")
                with gr.Column(scale=0.25):
                        tok_model = gr.Textbox(
                                label="Model", value="en_core_web_sm")               
            with gr.Row():
                with gr.Column(scale=0.5):
                    tok_input = gr.CheckboxGroup(
                        DEFAULT_TOK_ATTR, value=DEFAULT_TOK_ATTR, label="Token Attributes", interactive=True)
            # Seeded with plain lists from default_token(); later updates receive a DataFrame.
            tok_output = gr.Dataframe(headers=DEFAULT_TOK_ATTR, value=default_token(
                DEFAULT_TEXT, DEFAULT_TOK_ATTR, DEFAULT_MODEL)[0], overflow_row_behaviour="paginate")
            with gr.Row():
                with gr.Column(scale=0.25):
                    tok_button = gr.Button(
                        "Update Token Properties", variant="primary")
    # Word/phrase similarity panel (uses the "_md" pipeline).
    with gr.Box():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=0.75):
                    gr.Markdown(
                        "## [🔗 Word and Phrase Similarity](https://spacy.io/usage/linguistic-features#vectors-similarity)")
                    gr.Markdown(
                        "Words and spans have similarity ratings based on their word vectors")
                with gr.Column(scale=0.25):
                        sim_model = gr.Textbox(
                            label="Model", value="en_core_web_md")             
            with gr.Row():
                with gr.Column(scale=0.25):
                    sim_text1 = gr.Textbox(
                        value="Apple", label="Word 1", interactive=True,)
                with gr.Column(scale=0.25):
                    sim_text2 = gr.Textbox(
                        value="U.K. startup", label="Word 2", interactive=True,)
                with gr.Column(scale=0.25):
                    # NOTE(review): the initial score is hard-coded, not computed from
                    # the model — confirm it matches vectors("Apple", "U.K. startup", ...).
                    sim_output = gr.Textbox(
                        label="Similarity Score", value="0.12")
            with gr.Row():
                with gr.Column(scale=0.25):
                    sim_random_button = gr.Button("Update random words")
                with gr.Column(scale=0.25):
                    sim_button = gr.Button("Update similarity", variant="primary")
    # Spans panel.
    with gr.Box():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=0.75):
                    gr.Markdown(
                        "## [🔗 Spans](https://spacy.io/usage/visualizers#span)")
                    gr.Markdown(
                        "The span visualizer highlights overlapping spans in a text")
                with gr.Column(scale=0.25):
                        span_model = gr.Textbox(
                                label="Model", value="en_core_web_sm")
            with gr.Row():
                with gr.Column(scale=0.3):
                    span1 = gr.Textbox(
                        label="Span 1", value="U.K. startup", placeholder="Input a part of the sentence")
                with gr.Column(scale=0.3):
                    label1 = gr.Textbox(value="ORG",
                                        label="Label for Span 1")
            with gr.Row():
                with gr.Column(scale=0.3):
                    span2 = gr.Textbox(
                        label="Span 2", value="U.K.", placeholder="Input another part of the sentence")
                with gr.Column(scale=0.3):
                    label2 = gr.Textbox(value="GPE",
                                        label="Label for Span 2")
            span_output = gr.HTML(value=span(
                DEFAULT_TEXT, "U.K. startup", "U.K.", "ORG", "GPE", DEFAULT_MODEL)[0])
            with gr.Row():
                with gr.Column(scale=0.25):
                    span_button = gr.Button("Update Spans", variant="primary")
    # Noun chunks panel.
    with gr.Box():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=0.75):
                    gr.Markdown(
                        "## [🔗 Noun chunks](https://spacy.io/usage/linguistic-features#noun-chunks)")
                    gr.Markdown(
                        "You can use `doc.noun_chunks` to extract noun phrases from a doc object")
                with gr.Column(scale=0.25):
                        noun_model = gr.Textbox(
                                label="Model", value="en_core_web_sm") 
            noun_output = gr.Dataframe(headers=NOUN_ATTR, value=default_noun_chunks(
                DEFAULT_TEXT, DEFAULT_MODEL)[0], overflow_row_behaviour="paginate")
            with gr.Row():
                with gr.Column(scale=0.25):
                    noun_button = gr.Button(
                        "Update Noun Chunks", variant="primary")

    # change text based on model input
    model_input.change(get_text, inputs=[model_input], outputs=text_input)
    # main button - update all components
    # Each handler also writes the concrete pipeline name into its panel's
    # "Model" box. The random-words handler is not wired to the main button.
    button.click(dependency, inputs=[
        text_input, col_punct, col_phrase, compact, bg, text, model_input], outputs=[dep_output, dep_download_button, dep_model])
    button.click(
        entity, inputs=[text_input, ent_input, model_input], outputs=[ent_output, ent_model])
    button.click(
        token, inputs=[text_input, tok_input, model_input], outputs=[tok_output, tok_model])
    button.click(vectors, inputs=[sim_text1,
                 sim_text2, model_input], outputs=[sim_output, sim_model])
    button.click(
        span, inputs=[text_input, span1, span2, label1, label2, model_input], outputs=[span_output, span_model])
    button.click(
        noun_chunks, inputs=[text_input, model_input], outputs=[noun_output, noun_model])
    
    # individual component buttons
    dep_button.click(dependency, inputs=[
        text_input, col_punct, col_phrase, compact, bg, text, model_input], outputs=[dep_output, dep_download_button, dep_model])
    ent_button.click(
        entity, inputs=[text_input, ent_input, model_input], outputs=[ent_output, ent_model])
    tok_button.click(
        token, inputs=[text_input, tok_input, model_input], outputs=[tok_output, tok_model])
    sim_button.click(vectors, inputs=[
                     sim_text1, sim_text2, model_input], outputs=[sim_output, sim_model])
    # Fills both word boxes with the randomly chosen words as well as the score.
    sim_random_button.click(random_vectors, inputs=[text_input, model_input], outputs=[
                            sim_output, sim_text1, sim_text2, sim_model])
    span_button.click(
        span, inputs=[text_input, span1, span2, label1, label2, model_input], outputs=[span_output, span_model])
    noun_button.click(
        noun_chunks, inputs=[text_input, model_input], outputs=[noun_output, noun_model])
      
# Start the Gradio server. This runs at import time; there is no __main__ guard.
demo.launch()