"""
@Author     : karan kumar pathak
@Contact    : 2020fc04334@wilp.bits-pilani.com
@Description: Gradio demo app for the LOREN fact-verification system.
"""

import os
import gradio as gr
from huggingface_hub import snapshot_download
from prettytable import PrettyTable
import pandas as pd
import torch
import traceback

# Number of CUDA devices torch reports on this machine.
visible_gpus = torch.cuda.device_count()

# Settings dictionary handed to ``Loren(...)`` further down in this file.
# Later set-up code adds the checkpoint directories (fc_dir, mrc_dir, er_dir).
config = {
    # Backbone selection.
    'model_type': 'roberta',
    'model_name_or_path': 'roberta-large',
    # Model hyper-parameters (their meaning is defined by the LOREN code base).
    'logic_lambda': 0.5,
    'prior': 'random',
    'mask_rate': 0.0,
    'cand_k': 1,
    # Input size limits.
    'max_seq1_length': 256,
    'max_seq2_length': 128,
    'max_num_questions': 8,
    # Runtime switches.
    'do_lower_case': False,
    'seed': 42,
    'n_gpu': visible_gpus,
}

# Shell steps, run in order: clone the LOREN repository, delete its data/,
# results/ and models/ directories, then move what is left into the current
# working directory (presumably so the later `from src.loren import Loren`
# resolves). Exit codes are ignored, so a failing step does not stop the script.
for shell_command in (
    'git clone https://github.com/jiangjiechen/LOREN/',
    'rm -r LOREN/data/',
    'rm -r LOREN/results/',
    'rm -r LOREN/models/',
    'mv LOREN/* ./',
):
    os.system(shell_command)

# Download the 'Jiangjie/loren' snapshot from the Hugging Face Hub and record
# where each sub-model lives inside it; the paths are added to `config`.
model_dir = snapshot_download('Jiangjie/loren')
config.update(
    fc_dir=os.path.join(model_dir, 'fact_checking/roberta-large/'),
    mrc_dir=os.path.join(model_dir, 'mrc_seq2seq/bart-base/'),
    er_dir=os.path.join(model_dir, 'evidence_retrieval/'),
)


from src.loren import Loren


# Build the pipeline once at import time; the request handler `run` and the
# `highlight_phrase` helper both use this module-level instance.
loren = Loren(config, verbose=False)

# Push one claim through the pipeline at start-up so that a broken set-up
# aborts the script here instead of surfacing on the first user request.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Keep the ValueError type, but chain explicitly so the traceback marks
    # the original failure as the direct cause.
    raise ValueError(e) from e


def highlight_phrase(text, phrase):
    """Fill the ``<mask>`` slot(s) of ``text`` with ``phrase`` in bold italics.

    ``text`` is first passed through ``clean_up_tokenization`` of the
    fact-checking client's tokenizer; every ``<mask>`` in the result is then
    replaced by ``<i><b>phrase</b></i>``.

    Args:
        text: String that may contain ``<mask>`` placeholders.
        phrase: Value interpolated into the highlight markup.

    Returns:
        The cleaned string with the placeholders replaced.
    """
    tokenizer = loren.fc_client.tokenizer
    cleaned = tokenizer.clean_up_tokenization(text)
    marked_phrase = f'<i><b>{phrase}</b></i>'
    return cleaned.replace('<mask>', marked_phrase)


def highlight_entity(text, entity):
    """Wrap every occurrence of ``entity`` in ``text`` in bold-italic HTML tags.

    Args:
        text: String to search.
        entity: Substring to highlight. If it is empty or ``None`` there is
            nothing to highlight and ``text`` is returned unchanged.

    Returns:
        ``text`` with each occurrence of ``entity`` replaced by
        ``<i><b>entity</b></i>``.
    """
    if not entity:
        # str.replace with an empty needle matches between every character,
        # which would scatter empty <i><b></b></i> pairs through the text;
        # a None needle would raise TypeError.
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')


def gradio_formatter(js, output_type):
    """Render one view of a LOREN result as an HTML table.

    Args:
        js: Result mapping produced by ``loren.check``. For ``'e'`` the keys
            ``evidence`` and ``entities`` are read; for ``'z'`` the keys
            ``phrase_veracity``, ``claim_phrases``, ``cloze_qs`` and
            ``evidential`` are read.
        output_type: ``'e'`` for the evidence table, ``'z'`` for the
            per-phrase veracity table.

    Returns:
        An HTML string: a ``<head>`` element carrying zebra-stripe CSS,
        followed by the table.

    Raises:
        NotImplementedError: If ``output_type`` is neither ``'e'`` nor ``'z'``.
    """
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''
    if output_type == 'e':
        data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for probs in js['phrase_veracity']:
            max_idx = torch.argmax(torch.tensor(probs)).tolist()
            cells = ['%.4f' % p for p in probs]
            # Emphasise the largest probability of the triple.
            cells[max_idx] = f'<i><b>{cells[max_idx]}</b></i>'
            # Position in each triple: 0 -> REF, 1 -> NEI, 2 -> SUP.
            p_sup.append(cells[2])
            p_ref.append(cells[0])
            p_nei.append(cells[1])

        data = {
            'Claim Phrase': js['claim_phrases'],
            # The first element of each `evidential` entry fills the <mask>.
            'Local Premise': [highlight_phrase(q, x[0]) for q, x in zip(js['cloze_qs'], js['evidential'])],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError(f'Unsupported output_type: {output_type!r}')

    # Going through a DataFrame makes columns of unequal length fail here.
    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns),
                     align='l', border=True, hrules=1, vrules=1)
    for row in data.values:
        pt.add_row(list(row))
    html = pt.get_html_string(attributes={
        # `border-color` is the CSS property; `bordercolor` is ignored by browsers.
        'style': 'border-width: 2px; border-color: black'
    }, format=True)
    html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
    # PrettyTable HTML-escapes cell text. Restore only the highlight tags this
    # module inserts; un-escaping every &lt;/&gt; would let markup contained in
    # the claim or in retrieved evidence be rendered as live HTML.
    for escaped, tag in (('&lt;i&gt;&lt;b&gt;', '<i><b>'),
                         ('&lt;/b&gt;&lt;/i&gt;', '</b></i>')):
        html = html.replace(escaped, tag)
    return html


def run(claim):
    """Check one claim with LOREN and build the three values shown in the UI.

    Args:
        claim: Claim text entered by the user.

    Returns:
        A 3-tuple ``(label, phrase_table_html, evidence_table_html)``. If
        ``loren.check`` raises, the failure is logged and
        ``('Oops, something went wrong.', '', '')`` is returned instead.
    """
    try:
        result = loren.check(claim)
    except Exception as error_msg:
        # Log the offending claim and the full traceback, then degrade to a
        # generic message instead of propagating the exception.
        trace = traceback.format_exc()
        loren.logger.error(claim)
        loren.logger.error(f'[Error]: {error_msg}.\n[Traceback]: {trace}')
        return 'Oops, something went wrong.', '', ''

    verdict = result['claim_veracity']
    loren.logger.warning(verdict + str(result))
    evidence_table = gradio_formatter(result, 'e')
    phrase_table = gradio_formatter(result, 'z')
    return verdict, phrase_table, evidence_table


# Claims offered as clickable examples in the UI.
example_claims = [
    'Donald Trump won the U.S. 2020 presidential election.',
    'The first inauguration of Bill Clinton was in the United States.',
    'The Cry of the Owl is based on a book by an American.',
    'Smriti Mandhana is an Indian woman.',
]

# Introductory text for the page.
demo_description = (
    "LOREN is an interpretable Fact Verification model using Wikipedia as its knowledge source. "
    "This is a demo system for the AAAI 2022 paper: \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\"(https://arxiv.org/abs/2012.13577). "
    "See the paper for more details. You can add a *FLAG* on the bottom to record interesting or bad cases! "
    "(Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper.)"
)

# Labels a user can attach when flagging a result.
flag_choices = [
    'Interesting!',
    'Error: Claim Phrase Parsing',
    'Error: Local Premise',
    'Error: Require Commonsense',
    'Error: Evidence Retrieval',
]

# Output components follow the order of `run`'s return value:
# label (text), phrase table (html), evidence table (html).
output_components = ['text', 'html', 'html']

iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=output_components,
    examples=example_claims,
    title="LOREN",
    layout='horizontal',
    description=demo_description,
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=flag_choices,
    enable_queue=True,
)
iface.launch()