Spaces:
Runtime error
Runtime error
import gradio as gr | |
import huggingface_hub as hfh | |
from requests.exceptions import HTTPError | |
# ===================================================================================================================== | |
# DATA | |
# ===================================================================================================================== | |
# Dict with the tasks considered in this Space, {pretty name: space tag}
TASK_TYPES = {
    "✍️ Text Generation": "txtgen",
    "🤏 Summarization": "summ",
    "🫂 Translation": "trans",
    "💬 Conversational / Chatbot": "chat",
    "🤷 Text Question Answering": "txtqa",
    "🕵️ (Table/Document/Visual) Question Answering": "otherqa",
    "🎤 Automatic Speech Recognition": "asr",
    "🌇 Image to Text": "img2txt",
}

# Dict matching all task types with their possible hub tags, {space tag: (possible hub tags)}.
# Every value must be a tuple: a single tag needs a trailing comma, because `("image-to-text")` is just a string,
# and iterating a string yields its characters instead of whole tags.
HUB_TAGS = {
    "txtgen": ("text-generation", "text2text-generation"),
    "summ": ("summarization", "text-generation", "text2text-generation"),
    "trans": ("translation", "text-generation", "text2text-generation"),
    "chat": ("conversational", "text-generation", "text2text-generation"),
    "txtqa": ("text-generation", "text2text-generation"),
    "otherqa": ("table-question-answering", "document-question-answering", "visual-question-answering"),
    "asr": ("automatic-speech-recognition",),
    "img2txt": ("image-to-text",),
}

# Import-time sanity checks: each task type has exactly one HUB_TAGS entry, and every entry is a tuple of tags
assert len(TASK_TYPES) == len(HUB_TAGS)
assert all(tag in HUB_TAGS for tag in TASK_TYPES.values())
assert all(isinstance(tags, tuple) for tags in HUB_TAGS.values())
# Problems the user can ask for help with in this Space, mapping the label shown in the UI dropdown
# to the internal space tag: {pretty problem label: space tag}
PROBLEMS = {
    "🌈 I would like a ChatGPT-like model": "chatgpt",
    "📈 I want to improve the overall quality of the output": "quality",
    "🏎 Speed! Make it faster!": "faster",
    "😵💫 I would like to reduce model hallucinations": "hallucinations",
    "📏 I want to control the length of the output": "length",
    "🤔 The model is returning nothing / random words": "random",
    "❓ Other": "other",
}
# Placeholder Markdown for the suggestions panel, shown until a task type and a problem type have been chosen.
# NOTE(review): the runs of blank lines presumably try to line each 👈 up with its input widget on the left, but
# Markdown treats consecutive blank lines as a single paragraph break -- confirm how this actually renders.
INIT_MARKDOWN = (
    "\n"
    "\n"
    "👈 Please select a task type...\n"
    "\n"
    "\n"
    "\n"
    "\n"
    "\n"
    "👈 ... and a problem type...\n"
    "\n"
    "\n"
    "👈 ... then click here!\n"
)
# ===================================================================================================================== | |
# ===================================================================================================================== | |
# SUGGESTIONS LOGIC | |
# ===================================================================================================================== | |
def is_valid_task_for_model(model_name, task_type):
    """Check whether a Hub model looks compatible with the selected task type.

    The check is best-effort: when no model name is given, or when the Hub lookup fails, the model is
    assumed to be compatible, so a lookup problem never blocks the user.

    Args:
        model_name: Hub repository id typed by the user (e.g. "google/flan-t5-xl"). Blank, whitespace-only
            or None means "not decided yet". Surrounding whitespace is ignored.
        task_type: A pretty task name, i.e. a key of `TASK_TYPES`.

    Returns:
        True if the model carries at least one of the Hub tags associated with `task_type`, or if that cannot
        be verified; False otherwise.

    Raises:
        KeyError: if a model name is given and `task_type` is not a key of `TASK_TYPES`.
    """
    model_name = (model_name or "").strip()
    if not model_name:
        # No model chosen yet: nothing to validate
        return True

    from requests.exceptions import RequestException

    try:
        model_tags = hfh.HfApi().model_info(model_name).tags
    except (RequestException, ValueError):
        # RequestException is the base class of HTTPError and also covers connection failures. Malformed repo
        # ids are rejected by huggingface_hub with a ValueError subclass (HFValidationError) before any request.
        return True  # Assume everything is okay

    possible_tags = HUB_TAGS[TASK_TYPES[task_type]]
    # Guard against a single tag written without a trailing comma, e.g. `("image-to-text")`, which is a plain
    # string: compare whole tags, not individual characters
    if isinstance(possible_tags, str):
        possible_tags = (possible_tags,)
    # `tags` may be missing (None) on the returned model info; treat that as "no tags"
    model_tags = model_tags or []
    return any(tag in model_tags for tag in possible_tags)
def get_suggestions(task_type, model_name, problem_type):
    """Build the Markdown shown in the suggestions panel.

    Args:
        task_type: Selected pretty task name (a key of `TASK_TYPES`); a falsy value (None or "") means
            nothing is selected.
        model_name: Hub model id typed by the user; not used yet.
        problem_type: Selected pretty problem label (a key of `PROBLEMS`); a falsy value (None or "") means
            nothing is selected.

    Returns:
        `INIT_MARKDOWN` while the task type or the problem type is missing. Otherwise the suggestions text,
        which is currently an empty string because suggestions are not implemented yet.
    """
    # A plain falsiness check covers both an unselected dropdown (None) and an empty string.
    # The previous `x in ("" or None)` evaluated to `x in None` and raised TypeError.
    if not task_type or not problem_type:
        return INIT_MARKDOWN
    return ""
# ===================================================================================================================== | |
# ===================================================================================================================== | |
# GRADIO | |
# ===================================================================================================================== | |
demo = gr.Blocks()

with demo:
    gr.Markdown(
        "# 🚀💬 Improving Text Generation 💬🚀\n"
        "This is an ever-evolving guide on how to improve your text generation results. It is community-led and\n"
        "curated by Hugging Face 🤗\n"
    )

    # Two columns: the user inputs and the button on the left, the wider (scale=2) suggestions panel on the right
    with gr.Row():
        with gr.Column():
            task_type = gr.Dropdown(
                label="What is the task you're trying to accomplish?",
                choices=list(TASK_TYPES.keys()),
                interactive=True,
            )
            model_name = gr.Textbox(
                label="Which model are you using? (leave blank if you haven't decided)",
                placeholder="e.g. google/flan-t5-xl",
                interactive=True,
            )
            problem_type = gr.Dropdown(
                label="What would you like to improve?",
                choices=list(PROBLEMS.keys()),
                interactive=True,
            )
            button = gr.Button(value="Get Suggestions!")
        with gr.Column(scale=2):
            suggestions = gr.Markdown(value=INIT_MARKDOWN)

    # Clicking the button re-renders the suggestions panel from the three inputs.
    # It is registered inside the Blocks context, where Gradio expects event listeners.
    button.click(get_suggestions, inputs=[task_type, model_name, problem_type], outputs=suggestions)

    gr.Markdown(
        "Is your problem not on the list? Need more suggestions? Have you spotted an error? Please open a\n"
        "[new discussion](https://huggingface.co/spaces/joaogante/generate_quality_improvement/discussions) 🙏\n"
    )

# =====================================================================================================================


if __name__ == "__main__":
    demo.launch()