import gradio as gr
from huggingface_hub import InferenceClient

# Module-level client for the hosted Llama-2 chat model; the streaming loop
# below consumes text_generation with a plain for-loop, so the synchronous
# InferenceClient (not AsyncInferenceClient) is the right choice
client = InferenceClient("meta-llama/Llama-2-70b-chat-hf")
def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    # Start the Llama-2 chat prompt, with an optional <<SYS>> system block
    input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n " if system_prompt else "<s>[INST] "
    temperature = max(1e-2, float(temperature))  # sampling requires temperature > 0
    top_p = float(top_p)

    # Replay the conversation so far as alternating [INST] ... [/INST] turns
    for user_message, bot_response in chatbot:
        input_prompt += f"{user_message} [/INST] {bot_response} </s><s>[INST] "
    input_prompt += f"{message} [/INST] "

    # Stream tokens from the endpoint, yielding the growing partial reply so
    # the UI updates while the model generates
    partial_message = ""
    for token in client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    ):
        partial_message += token
        yield partial_message
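
# For illustration (an assumed example, matching the Llama-2 chat template the
# code above emits): with system_prompt="Be concise.", one prior exchange
# ("Hi", "Hello!") and message="How are you?", input_prompt becomes:
#
#   <s>[INST] <<SYS>>
#   Be concise.
#   <</SYS>>
#
#    Hi [/INST] Hello! </s><s>[INST] How are you? [/INST]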
# Create the Gradio chat UI. ChatInterface calls predict with (message, history)
# followed by each additional input, matching the signature above; the original
# gr.Interface wiring passed the history as a Textbox string, which would crash
# the turn-replay loop.
iface = gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox(label="System Prompt"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.9, label="Temperature"),
        gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Top P"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Repetition Penalty"),
    ],
)

iface.launch()
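
# Optional smoke test (a sketch, not part of the Space itself): drive the
# generator directly instead of through the UI. Assumes the environment is
# authenticated for the gated meta-llama repo (e.g. via HF_TOKEN).
#
# for partial in predict("Hello!", [], system_prompt="Be brief."):
#     print(partial)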