import gradio as gr
from huggingface_hub import InferenceClient

# Assuming client is a global variable
client = InferenceClient("meta-llama/Llama-2-70b-chat-hf")

def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    # Build the Llama 2 chat prompt; the system prompt goes inside <<SYS>> tags
    input_prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n " if system_prompt else "[INST] "
    temperature = max(1e-2, float(temperature))  # a temperature of 0 is invalid when sampling
    top_p = float(top_p)

    # Replay earlier turns so the model sees the whole conversation
    for interaction in chatbot:
        input_prompt += f"{interaction[0]} [/INST] {interaction[1]} [INST] "
    input_prompt += f"{message} [/INST] "

    # Stream tokens back to the UI as they arrive
    partial_message = ""
    for token in client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    ):
        partial_message += token
        yield partial_message

# Create a Gradio chat interface; ChatInterface supplies the message box and
# the chat history, and the remaining knobs are exposed as additional inputs
iface = gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox(label="System Prompt"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.9, label="Temperature"),
        gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Top P"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Repetition Penalty"),
    ],
)

iface.queue().launch()
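
# Note: meta-llama/Llama-2-70b-chat-hf is a gated model, so the Inference API
# call typically needs a Hugging Face access token. A minimal sketch, assuming
# the token is exported in an HF_TOKEN environment variable (that variable
# name is an assumption, not part of the original snippet):
#
#     import os
#     client = InferenceClient(
#         "meta-llama/Llama-2-70b-chat-hf",
#         token=os.environ["HF_TOKEN"],
#     )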