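"""Gradio chat demo for google/gemma-3-1b-it.

Loads the model and tokenizer with transformers, wraps them in a
text-generation pipeline, and serves a gr.ChatInterface. Note: the
button-removal arguments used below (retry_btn, undo_btn, clear_btn)
assume a Gradio 4.x ChatInterface; adjust them if the Space pins a
different Gradio version.
"""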
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Model and tokenizer loading (with error handling)
try:
    model_name = "google/gemma-3-1b-it"  # Correct model name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency, if supported
        device_map="auto",  # Automatically use GPU if available, otherwise CPU
    )

    # Create the pipeline
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,  # Make sure the pipeline also uses the correct dtype
        device_map="auto",  # and device mapping
        model_kwargs={"attn_implementation": "flash_attention_2"},  # Enable Flash Attention 2 if supported by your hardware and transformers version
    )
except Exception as e:
    error_message = f"Error loading model or tokenizer: {e}"
    print(error_message)  # Log the error to the console

    # Provide a fallback, even if it's just displaying the error.
    def error_response(message, history):
        return f"Model loading failed. Error: {error_message}"

    # Minimal Gradio interface to show the error
    with gr.Blocks() as demo:
        gr.ChatInterface(error_response)
    demo.launch()
    exit()  # Important: exit to prevent running the rest of the (broken) code
# Chat template handling (important for correct prompting)
def apply_chat_template(messages, chat_template=None):
"""Applies the chat template to the message history.
Args:
messages: A list of dictionaries, where each dictionary has a "role"
("user" or "assistant") and "content" key.
chat_template: The chat template string (optional). If None,
try to get from tokenizer.
Returns:
A single string representing the formatted conversation.
"""
if chat_template is None:
if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
chat_template = tokenizer.chat_template
else:
# Fallback to a simple template if no chat template is found. This is
# *critical* to prevent the model from generating nonsensical output.
chat_template = "{% for message in messages %}" \
"{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
"{% endfor %}" \
"{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
)
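# For example, with the fallback template above, messages=[{"role": "user", "content": "Hi"}]
# renders roughly as:
#   <start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
# (the trailing model turn comes from add_generation_prompt=True).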
# Prediction function (modified for chat)
def predict(message, history):
"""Generates a response to the user's message.
Args:
message: The user's input message (string).
history: A list of (user_message, bot_response) tuples representing
the conversation history.
Returns:
The generated bot response (string).
"""
# Build the conversation history in the required format.
messages = []
for user_msg, bot_response in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "model", "content": bot_response})
messages.append({"role": "user", "content": message})
# Apply the chat template.
prompt = apply_chat_template(messages)
# Generate the response using the pipeline (much cleaner).
try:
sequences = pipe(
prompt,
max_new_tokens=512, # Limit response length
do_sample=True, # Use sampling for more diverse responses
temperature=0.7, # Control randomness (higher = more random)
top_k=50, # Top-k sampling
top_p=0.95, # Nucleus sampling
repetition_penalty=1.2, # Reduce repetition
pad_token_id=tokenizer.eos_token_id, # Ensure padding is correct.
)
response = sequences[0]['generated_text'][len(prompt):].strip() # Extract *only* generated text
return response
except Exception as e:
return f"An error occurred during generation: {e}"
# Gradio interface (using gr.ChatInterface for a chatbot UI)
with gr.Blocks() as demo:
    gr.ChatInterface(
        predict,
        chatbot=gr.Chatbot(height=500),  # Set a reasonable height
        textbox=gr.Textbox(placeholder="Ask me anything!", container=False, scale=7),
        title="Gemma-3-1b-it Chatbot",
        description="Chat with the Gemma-3-1b-it model.",
        retry_btn=None,  # Remove redundant buttons
        undo_btn=None,
        clear_btn=None,
    )
demo.launch(share=False)  # Set share=True to create a publicly shareable link