import gradio
import inseq
from inseq.data.aggregator import AggregatorPipeline, SubwordAggregator, SequenceAttributionAggregator, PairAggregator
import torch
from os.path import exists

# Run models on the GPU whenever torch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Start downloading the Hu-En model
# model_hu_en = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")

def swap_pronoun(sentence):
    """Swap the English pronouns "He" and "She" in *sentence*.

    Only whole, capitalized words are swapped. Words that merely contain the
    letters "He" (e.g. "Hello", "Helen") are left untouched. Both pronouns are
    swapped in a single pass, so a sentence containing "He" and "She" gets
    each one exchanged.

    Args:
        sentence: English text, e.g. a model translation.

    Returns:
        The sentence with every "He" replaced by "She" and vice versa. The
        input is returned unchanged if it contains neither pronoun.
    """
    import re

    swapped = {"He": "She", "She": "He"}
    # A single substitution pass with a callback makes the swap simultaneous;
    # chained str.replace calls would undo each other or hit partial words.
    return re.sub(r"\b(He|She)\b", lambda match: swapped[match.group(1)], sentence)

def run_counterfactual(occupation):
    """Attribute a contrastive "He"/"She" translation pair for one occupation.

    Translates the Hungarian sentence "Ő {occupation}." to English with the
    Helsinki-NLP hu→en model. It then builds a "He ..." and a "She ..."
    version of the generated translation and runs integrated-gradients
    attribution on both. The result is rendered as one paired HTML view.

    Args:
        occupation: A dropdown label such as "nővér (nurse)". Only the part
            before " (" (the Hungarian word) is used.

    Returns:
        The rendered attribution HTML. Results are cached in
        ``results/counterfactual_<occupation>.html`` and re-read on later
        calls for the same occupation.
    """
    import os
    import re

    occupation = occupation.split(" (")[0]
    result_fp = f"results/counterfactual_{occupation}.html"
    if exists(result_fp):
        # Explicit UTF-8: the cached HTML contains Hungarian characters
        # (e.g. "Ő") that the platform default encoding may not support.
        with open(result_fp, "r", encoding="utf-8") as f:
            return f.read()

    # "egy" means something like "a", but is used less frequently than in
    # English, so it is left out of the source sentence.
    source = f"Ő {occupation}."

    model = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
    model.device = DEVICE
    target = model.generate(source)[0]

    # Force the pronoun of the generated translation to "He" / "She". Word
    # boundaries stop words that merely contain "He" (e.g. "Head") from being
    # rewritten. If the translation contains neither pronoun, both targets
    # are identical to the generated text.
    pronoun = re.compile(r"\b(?:He|She)\b")
    masculine_target = pronoun.sub("He", target)
    feminine_target = pronoun.sub("She", target)

    out = model.attribute(
        [source, source],
        [masculine_target, feminine_target],
        n_steps=150,
        return_convergence_delta=False,
        attribute_target=False,
        step_scores=["probability"],
        internal_batch_size=100,
        include_eos_baseline=False,
        device=DEVICE,
    )

    # Merge subword tokens, then apply sequence-level aggregation, to each
    # of the two attributions.
    squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
    masculine = out.sequence_attributions[0].aggregate(aggregator=squeezesum)
    feminine = out.sequence_attributions[1].aggregate(aggregator=squeezesum)

    # Render the masculine attribution paired with the feminine one.
    html = masculine.show(aggregator=PairAggregator, paired_attr=feminine, return_html=True, display=True)

    # Cache the HTML; create the results directory on first use instead of
    # crashing with FileNotFoundError.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)

    return html

def run_simple(occupation, lang, aggregate):
    """Translate "Ő {occupation}." and show integrated-gradients attributions.

    Args:
        occupation: A dropdown label such as "nővér (nurse)". Only the part
            before " (" (the Hungarian word) is used.
        lang: Target-language code. It selects the
            ``Helsinki-NLP/opus-mt-hu-<lang>`` model.
        aggregate: "yes" to merge subword tokens (followed by sequence-level
            aggregation) before rendering. Any other value renders raw
            token attributions.

    Returns:
        The rendered attribution HTML. Results are cached in
        ``results/simple_<occupation>_<lang>[_aggregate].html`` and re-read
        on later calls with the same arguments.
    """
    import os

    aggregate = aggregate == "yes"
    occupation = occupation.split(" (")[0]

    result_fp = f"results/simple_{occupation}_{lang}{'_aggregate' if aggregate else ''}.html"
    if exists(result_fp):
        # Explicit UTF-8: the cached HTML contains Hungarian characters
        # (e.g. "Ő") that the platform default encoding may not support.
        with open(result_fp, "r", encoding="utf-8") as f:
            return f.read()

    model_name = f"Helsinki-NLP/opus-mt-hu-{lang}"

    # "egy" means something like "a", but is used less frequently than in
    # English, so it is left out of the source sentence.
    source = f"Ő {occupation}."

    model = inseq.load_model(model_name, "integrated_gradients")
    out = model.attribute([source], attribute_target=True, n_steps=150, device=DEVICE, return_convergence_delta=False)

    if aggregate:
        squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
        html = out.show(return_html=True, display=True, aggregator=squeezesum)
    else:
        html = out.show(return_html=True, display=True)

    # Cache the HTML; create the results directory on first use instead of
    # crashing with FileNotFoundError.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html



# Markdown fragments displayed in the UI. They are decoded explicitly as
# UTF-8 so any non-ASCII (e.g. Hungarian) text reads the same regardless of
# the platform's default locale encoding.
with open("description.md", encoding="utf-8") as fh:
    desc = fh.read()

with open("simple_translation.md", encoding="utf-8") as fh:
    simple_translation = fh.read()

with open("contrastive_pair.md", encoding="utf-8") as fh:
    contrastive_pair = fh.read()

with open("notice.md", encoding="utf-8") as fh:
    notice = fh.read()

# (Hungarian word, English gloss) pairs offered in the occupation dropdowns.
_OCCUPATION_GLOSSES = [
    ("nő", "woman"),
    ("férfi", "man"),
    ("nővér", "nurse"),
    ("tudós", "scientist"),
    ("mérnök", "engineer"),
    ("pék", "baker"),
    ("tanár", "teacher"),
    ("esküvőszervező", "wedding organizer"),
    ("vezérigazgató", "CEO"),
]

# Dropdown labels of the form "<hungarian> (<english>)". The handlers keep
# only the text before " (" to recover the Hungarian word.
OCCUPATIONS = [f"{hungarian} ({english})" for hungarian, english in _OCCUPATION_GLOSSES]

# Target-language codes; each selects a Helsinki-NLP/opus-mt-hu-<code> model.
LANGS = ["en", "fr", "de"]

with gradio.Blocks(title="Gender Bias in MT: Hungarian to English") as iface:
    gradio.Markdown(desc)

    # Collapsible explanations, one per tab below.
    with gradio.Accordion("Simple translation", open=True):
        gradio.Markdown(simple_translation)

    with gradio.Accordion("Contrastive pair", open=False):
        gradio.Markdown(contrastive_pair)

    gradio.Markdown("**Does the model seem to rely on gender stereotypes in its translations?**")

    # Tab 1: single translation into a chosen target language (run_simple).
    with gradio.Tab("Simple translation"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                simple_occupation = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
            with gradio.Column(scale=4):
                simple_lang = gradio.Dropdown(label="Target Language", choices=LANGS, value=LANGS[0])
        simple_aggregate = gradio.Radio(
            ["yes", "no"], label="Aggregate subwords?", value="yes"
        )
        simple_button = gradio.Button("Translate & Attribute")
        simple_output = gradio.HTML()
        simple_button.click(
            run_simple,
            inputs=[simple_occupation, simple_lang, simple_aggregate],
            outputs=simple_output,
        )

    # Tab 2: paired "He"/"She" attribution for one occupation (run_counterfactual).
    with gradio.Tab("Contrastive pair"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                pair_occupation = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
        pair_button = gradio.Button("Translate & Attribute")
        pair_output = gradio.HTML()
        pair_button.click(run_counterfactual, inputs=[pair_occupation], outputs=pair_output)

    with gradio.Accordion("Notes & References", open=False):
        gradio.Markdown(notice)


# Start the (blocking) web server only when run as a script, so importing
# this module does not launch the app.
if __name__ == "__main__":
    iface.launch()