# Gradio Space: chat demo for the Meta-Llama-3.1-8B Alpaca fine-tune, served via Unsloth.
import gradio as gr
import torch
from unsloth import FastLanguageModel

# --- Load Model ---
max_seq_length = 4096   # maximum context length used for tokenization/generation
dtype = torch.float16   # half precision to cut memory use
load_in_4bit = True     # quantize base weights to 4-bit on load

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Ahmed-El-Sharkawy/Meta-Llama-3.1-8B-alpaca",  # directly load your uploaded model
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Enable 2x faster inference
# Define Alpaca prompt — generate_response fills the two placeholders below.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input_text}
### Response:
"""
def generate_response(instruction, input_text):
    """Generate a model completion for an Alpaca-style (instruction, input) pair.

    Args:
        instruction: Task description entered in the UI.
        input_text: Additional context for the task (may be empty).

    Returns:
        The decoded completion string with Llama begin/end-of-text markers
        removed and, if the model echoed the prompt back, the prompt stripped.
    """
    prompt = alpaca_prompt.format(instruction=instruction, input_text=input_text)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        use_cache=True,
    )
    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Strip any begin/end-of-text markers that survive skip_special_tokens.
    response = decoded_output[0].replace("<|begin_of_text|>", "").replace("<|end_of_text|>", "").strip()
    # Optional: Remove the prompt part if model echoes it back
    if prompt.strip() in response:
        response = response.replace(prompt.strip(), "").strip()
    return response
# --- Gradio UI ---
# Two input boxes (instruction + optional context), one output box, and a
# button wired to generate_response.
with gr.Blocks() as demo:
    gr.Markdown("# π LLaMA-3 Alpaca Fine-tuned Chatbot")
    with gr.Row():
        instruction = gr.Textbox(label="Instruction", lines=2)
        input_text = gr.Textbox(label="Input", lines=5)
    output = gr.Textbox(label="Response", lines=10)
    btn = gr.Button("Generate")
    btn.click(generate_response, [instruction, input_text], output)

demo.launch()