File size: 4,806 Bytes
93a8576
03844f7
 
93a8576
03844f7
 
 
 
 
 
 
 
 
 
 
 
 
 
93a8576
03844f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29462dd
dddcb06
 
 
 
 
03844f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15a9b76
 
 
 
 
 
03844f7
 
 
 
 
 
cdf3b28
 
 
 
 
 
d9ee818
 
 
 
cdf3b28
2b2223b
cdf3b28
 
 
 
d9ee818
2b2223b
cdf3b28
03844f7
 
5265793
d9ee818
 
03844f7
 
5265793
d9ee818
 
dddcb06
03844f7
5265793
cdf3b28
15a9b76
03844f7
15a9b76
 
 
 
cdf3b28
03844f7
5265793
cdf3b28
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import huggingface_hub as hfh
from requests.exceptions import HTTPError

# =====================================================================================================================
# DATA
# =====================================================================================================================
# Dict with the tasks considered in this Space, {pretty name: space tag}
TASK_TYPES = {
    "✍️ Text Generation": "txtgen",
    "🤏 Summarization": "summ",
    "🫂 Translation": "trans",
    "💬 Conversational / Chatbot": "chat",
    "🤷 Text Question Answering": "txtqa",
    "🕵️ (Table/Document/Visual) Question Answering": "otherqa",
    "🎤 Automatic Speech Recognition": "asr",
    "🌇 Image to Text": "img2txt",
}

# Dict matching all task types with their possible hub tags, {space tag: (possible hub tags)}
# NOTE: every value must be a tuple — a bare parenthesized string is just a string, and
# membership checks downstream would then iterate over its *characters*.
HUB_TAGS = {
    "txtgen": ("text-generation", "text2text-generation"),
    "summ": ("summarization", "text-generation", "text2text-generation"),
    "trans": ("translation", "text-generation", "text2text-generation"),
    "chat": ("conversational", "text-generation", "text2text-generation"),
    "txtqa": ("text-generation", "text2text-generation"),
    "otherqa": ("table-question-answering", "document-question-answering", "visual-question-answering"),
    "asr": ("automatic-speech-recognition",),  # trailing comma: 1-tuple, not a string
    "img2txt": ("image-to-text",),  # trailing comma: 1-tuple, not a string
}
# Sanity checks: the two dicts must stay in sync (was `len(TASK_TYPES) == len(TASK_TYPES)`,
# a tautology that validated nothing).
assert len(TASK_TYPES) == len(HUB_TAGS)
assert all(tag in HUB_TAGS for tag in TASK_TYPES.values())

# Dict with the problems considered in this Space, {problem: space tag}
PROBLEMS = {
    "I would like a ChatGPT-like model": "chatgpt",
    "I want to improve the overall quality of the output": "quality",
    "Speed! Make it faster 🚀": "faster",
    "I would like to reduce model hallucinations": "hallucinations",
    "I want to control the length of the output": "length",
    "The model is returning nothing / random words 🤔": "random",
    "Other": "other",
}

# =====================================================================================================================


# =====================================================================================================================
# SUGGESTIONS LOGIC
# =====================================================================================================================
def is_valid_task_for_model(model_name, task_type):
    """Check whether a Hub model is tagged for the selected task type.

    Args:
        model_name: Hub repo id typed by the user (e.g. "google/flan-t5-xl").
            May be empty/None/whitespace when the user hasn't decided yet.
        task_type: One of the pretty-name keys of TASK_TYPES.

    Returns:
        True if the model's Hub tags intersect the tags allowed for the task,
        or if the check can't be performed (no model given, or the Hub
        request fails) — i.e. we fail open rather than block the user.
    """
    # Gradio textboxes can yield None or whitespace; treat all of those as "no model".
    if not model_name or not model_name.strip():
        return True
    try:
        model_tags = hfh.HfApi().model_info(model_name.strip()).tags
    except HTTPError:
        # Repo may be private, gated, or nonexistent — assume everything is okay.
        return True

    possible_tags = HUB_TAGS[TASK_TYPES[task_type]]
    return any(tag in model_tags for tag in possible_tags)


def get_suggestions(task_type, model_name, problem_type):
    """Build the markdown suggestions shown in the right-hand column.

    Args:
        task_type: Selected task (pretty-name key of TASK_TYPES), or None/"" if unselected.
        model_name: Optional model repo id from the textbox (currently unused here).
        problem_type: Selected problem (key of PROBLEMS), or None/"" if unselected.

    Returns:
        A markdown string: a prompt to fill in the dropdowns when either is
        missing, otherwise the suggestions text (currently empty — content TBD).
    """
    # Gradio dropdowns default to None (not ""), so use falsy checks to catch
    # both the unselected and the empty-string cases.
    if not task_type or not problem_type:
        return "👈 Please select a task type and a problem type."
    return ""
# =====================================================================================================================


# =====================================================================================================================
# GRADIO
# =====================================================================================================================
# Build the Gradio UI: a header, three inputs (task, model, problem), a button,
# and a markdown output fed by get_suggestions.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # 🚀💬 Improving Text Generation 💬🚀

        This is an ever-evolving guide on how to improve your text generation results. It is community-led and
        curated by Hugging Face 🤗

        How to use it:
        1. Answer the questions using the dropdown menus
        2. Click on "Get Suggestions" button
        3. Explore the suggestions 🤗
        """
    )

    with gr.Row():
        with gr.Column():
            # Left column: the three user inputs plus the trigger button.
            task_type = gr.Dropdown(
                label="What is the task you're trying to accomplish?",
                choices=list(TASK_TYPES.keys()),
                interactive=True,
            )
            model_name = gr.Textbox(
                label="Which model are you using? (leave blank if you haven't decided)",
                placeholder="e.g. google/flan-t5-xl",
                interactive=True,
            )
            problem_type = gr.Dropdown(
                label="What would you like to improve?",
                choices=list(PROBLEMS.keys()),
                interactive=True,
            )
            button = gr.Button(value="Get Suggestions!")
        with gr.Column(scale=2):
            # Right column (wider): markdown area that displays the suggestions.
            suggestions = gr.Markdown(value="")

    # Wire the button to the suggestions logic.
    button.click(get_suggestions, inputs=[task_type, model_name, problem_type], outputs=suggestions)


# =====================================================================================================================

if __name__ == "__main__":
    demo.launch()