# Llama2 / app.py
# Hugging Face Space by MAsad789565 — "Update app.py" (commit 439424f, verified)
import gradio as gr
from huggingface_hub import InferenceClient

# Shared client for the hosted Llama-2-70B chat inference endpoint.
# Bug fix: AsyncInferenceClient is not part of the `openai` package — the
# inference clients live in `huggingface_hub`. The synchronous
# InferenceClient is required here because predict() consumes the token
# stream with a plain `for` loop; AsyncInferenceClient returns an async
# iterator and would fail in synchronous code.
client = InferenceClient("meta-llama/Llama-2-70b-chat-hf")
def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    """Stream a Llama-2 chat completion for *message* given prior history.

    Args:
        message: The user's new message.
        chatbot: Conversation history as a sequence of (user, assistant) pairs.
        system_prompt: Optional system instruction wrapped in <<SYS>> tags.
        temperature: Sampling temperature; clamped to a minimum of 0.01
            because the endpoint rejects temperature <= 0 when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text, growing one streamed token at a time.
    """
    input_prompt = _build_prompt(message, chatbot, system_prompt)
    # Guard against temperature=0, which text-generation-inference rejects
    # when do_sample=True.
    temperature = max(1e-2, float(temperature))
    top_p = float(top_p)

    partial_message = ""
    for token in client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    ):
        partial_message += token
        yield partial_message


def _build_prompt(message, chatbot, system_prompt=""):
    """Assemble a Llama-2 [INST]-format prompt from history plus the new message."""
    parts = []
    if system_prompt:
        parts.append(f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n ")
    else:
        parts.append("<s>[INST] ")
    for user_turn, assistant_turn in chatbot:
        parts.append(f"{user_turn} [/INST] {assistant_turn} </s><s>[INST] ")
    parts.append(f"{message} [/INST] ")
    return "".join(parts)
# Create a Gradio interface around the streaming predict() generator.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Message"),
        # NOTE(review): predict() iterates this input as (user, assistant)
        # pairs, but a Textbox delivers a plain string — confirm the
        # intended history format (gr.ChatInterface may be a better fit).
        gr.Textbox(label="Chatbot"),
        gr.Textbox(label="System Prompt"),
        # Bug fix: the original used gr.Number("slider", ..., default=...),
        # which passes the string "slider" as the numeric value and uses a
        # `default` keyword that gr.Number does not accept (TypeError). The
        # minimum/maximum ranges show the intent was gr.Slider, whose
        # initial-value keyword is `value`.
        gr.Slider(minimum=0.1, maximum=2, value=0.9, label="Temperature"),
        gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1, value=0.6, label="Top P"),
        gr.Slider(minimum=0.1, maximum=2, value=1.0, label="Repetition Penalty"),
    ],
    outputs=gr.Textbox(),
)

iface.launch()