# Hugging Face Spaces page metadata (scrape residue, kept for reference):
#   Spaces:
#   Runtime error
#   Runtime error
import gradio as gr
# BUG FIX: `AsyncInferenceClient` is not part of the `openai` package — it lives in
# `huggingface_hub`. Additionally, predict() iterates client.text_generation(...)
# with a plain `for` loop, so the *synchronous* InferenceClient is required; the
# async variant returns a coroutine that a sync loop cannot iterate.
from huggingface_hub import InferenceClient

# Module-level client shared by predict(); targets the hosted Llama-2-70b chat endpoint.
client = InferenceClient("meta-llama/Llama-2-70b-chat-hf")
def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    """Stream a Llama-2 chat completion, yielding the growing partial response.

    Parameters
    ----------
    message : str
        The new user message.
    chatbot : list
        Conversation history as (user_text, assistant_text) pairs.
    system_prompt : str
        Optional system instruction, wrapped in Llama-2 ``<<SYS>>`` tags when non-empty.
    temperature, max_new_tokens, top_p, repetition_penalty
        Sampling controls forwarded to the inference endpoint.

    Yields
    ------
    str
        The response accumulated so far — Gradio renders each yield as a
        streaming update of the output textbox.
    """
    input_prompt = _build_llama2_prompt(message, chatbot, system_prompt)

    # The endpoint rejects temperature == 0 when do_sample=True; clamp to a
    # small positive floor.
    temperature = max(1e-2, float(temperature))

    partial_message = ""
    for token in client.text_generation(
        prompt=input_prompt,
        # Gradio numeric components deliver floats; the API requires an int.
        max_new_tokens=int(max_new_tokens),
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=float(top_p),
        # Cast for consistency with the other sampling parameters.
        repetition_penalty=float(repetition_penalty),
    ):
        partial_message += token
        yield partial_message


def _build_llama2_prompt(message, chatbot, system_prompt=""):
    """Assemble the Llama-2 ``[INST]``/``<<SYS>>`` chat template for *message* plus history."""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n " if system_prompt else "<s>[INST] "
    for interaction in chatbot:
        prompt += f"{interaction[0]} [/INST] {interaction[1]} </s><s>[INST] "
    return prompt + f"{message} [/INST] "
# Build the Gradio UI. Fixes vs. original:
#  - gr.Textbox("text", ...) passed "text" as the component's *value*; drop it.
#  - gr.Number("slider", minimum=..., default=...) is invalid — Gradio has no
#    "slider" mode on Number and the parameter is `value`, not `default`;
#    the intended component is gr.Slider.
#  - The conversation history is a list of (user, bot) pairs — predict() indexes
#    interaction[0]/interaction[1] — so it must be a gr.Chatbot, not a Textbox.
# predict() is a generator, which gr.Interface streams into the output textbox.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Message"),
        gr.Chatbot(label="Chatbot"),
        gr.Textbox(label="System Prompt"),
        gr.Slider(minimum=0.1, maximum=2, value=0.9, label="Temperature"),
        gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1, value=0.6, label="Top P"),
        gr.Slider(minimum=0.1, maximum=2, value=1.0, label="Repetition Penalty"),
    ],
    outputs=gr.Textbox(label="Response"),
)

iface.launch()