import gradio as gr
from huggingface_hub import InferenceClient

# Client for the text-generation-inference endpoint serving the model
client = InferenceClient("meta-llama/Llama-2-70b-chat-hf")

def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
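    # Llama-2 chat template: an optional <<SYS>> system block inside the first
    # [INST] turn, then alternating user/model turns separated by </s><s>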
    input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n " if system_prompt else "<s>[INST] "
    temperature = max(1e-2, float(temperature))  # sampling requires a strictly positive temperature
    top_p = float(top_p)

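    # Replay earlier turns so the model sees the whole conversation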
    for interaction in chatbot:
        input_prompt += f"{interaction[0]} [/INST] {interaction[1]} </s><s>[INST] "

    input_prompt += f"{message} [/INST] "
    partial_message = ""

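    # Stream tokens from the endpoint, yielding the growing reply so the UI
    # can render it incrementally instead of waiting for the full generation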
    for token in client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    ):
        partial_message += token
        yield partial_message

# Create a Gradio chat interface: ChatInterface calls predict with
# (message, history) plus the additional inputs below the chat box
iface = gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox(label="System Prompt"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.9, label="Temperature"),
        gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Top P"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Repetition Penalty"),
    ],
)

iface.queue().launch()