File size: 3,028 Bytes
8f29b1d 3f6f474 cf575f8 ec1d54e cf575f8 d654474 cf575f8 e32f803 cf575f8 613e689 dd2409d 613e689 8f29b1d e691ea0 8f29b1d ec1d54e 8f29b1d e691ea0 8f29b1d d654474 cf575f8 0e97d35 d654474 3f6f474 d654474 cf575f8 dd2409d 3f6f474 8920952 cf575f8 3f6f474 613e689 3f6f474 0e97d35 613e689 cf575f8 dd2409d b49f004 a4780a0 ec1d54e cf575f8 ec1d54e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import traceback
from io import StringIO
from typing import Optional
import gradio as gr
import pandas as pd
from loguru import logger
from utils import pipeline
from utils.models import list_models
def read_data(filepath: str) -> Optional[pd.DataFrame]:
    """Load a tabular input file into a DataFrame.

    The reader is picked from the file extension. The extension is compared
    case-insensitively so that names such as ``DATA.CSV`` are accepted.

    Args:
        filepath: Path to a ``.xlsx`` or ``.csv`` file.

    Returns:
        The DataFrame produced by ``pd.read_excel`` or ``pd.read_csv``.

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``.
    """
    # Only the comparison is lower-cased; the original path is what gets opened.
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')
def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> 'tuple[Optional[dict], Optional[pd.DataFrame], Optional[str]]':
    """Run the selected scoring task on uploaded or pasted data.

    An uploaded ``file`` takes precedence over ``text``. The result of the
    pipeline call is written to ``output.csv`` and its first 10 rows are
    returned for preview.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``.
        model_name: Model identifier forwarded to the pipeline function.
        pooling: Pooling option forwarded to the pipeline function.
        text: CSV content as a string; used only when ``file`` is falsy.
        file: Optional upload; either an object exposing ``.name`` (the path
            on disk) or a plain path string.

    Returns:
        On success ``(None, preview, path)`` where ``preview`` is
        ``df.iloc[:10]`` and ``path`` is the written CSV file.
        On failure ``(error_info, None, None)`` where ``error_info`` is a dict
        with keys ``'Info'`` and ``'Error'`` (the formatted traceback).
        Exceptions are not propagated to the caller.
    """
    try:
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')
        # load file
        if file:
            # Tolerate both a wrapper with `.name` and a bare path string.
            df = read_data(getattr(file, 'name', file))
        elif text:
            df = pd.read_csv(StringIO(text))
        else:
            raise ValueError('No input data')
        # check (explicit raises: `assert` would be stripped under `python -O`)
        if len(df) < 1:
            raise ValueError('No input data')
        if len(df) > 10000:
            raise ValueError('Data exceeds 10,000 rows')
        # process
        if task_name == 'Originality':
            df = pipeline.p0_originality(df, model_name, pooling)
        elif task_name == 'Flexibility':
            df = pipeline.p1_flexibility(df, model_name, pooling)
        else:
            raise ValueError('Task not supported')
        # save
        # NOTE(review): fixed output name is shared by every request; this
        # presumably relies on requests being processed one at a time.
        path = 'output.csv'
        df.to_csv(path, index=False, encoding='utf-8-sig')
        return None, df.iloc[:10], path
    except Exception:
        # Boundary handler: report the failure through the UI outputs instead
        # of raising, while letting KeyboardInterrupt/SystemExit pass through.
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        return {'Info': 'Something wrong', 'Error': error}, None, None
# input
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility']
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value='mean',
    choices=['mean', 'cls']
)
# Read the example CSV inside a context manager so the handle is closed
# promptly; it pre-fills the text box.
with open('data/example_xlm.csv', 'r') as example_file:
    example_text = example_file.read()
text_input = gr.components.Textbox(
    value=example_text,
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])
# output
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])
# The description shown in the interface is loaded from a text file.
with open('data/description.txt', 'r') as description_file:
    description_text = description_file.read()
# Input order must match the positional parameters of `process`;
# output order must match the three values it returns.
app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=description_text,
    title='TransDis-CreativityAutoAssessment',
    concurrency_limit=1,
)
app.launch(max_threads=1)
|