import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Model and tokenizer loading (with error handling)
try:
    model_name = "google/gemma-3-1b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency, if supported
        device_map="auto",  # Automatically use GPU if available, otherwise CPU
        # Enable Flash Attention 2 if supported by your hardware and
        # transformers version (requires the flash-attn package). This must be
        # set here at load time: passing it later via the pipeline's
        # model_kwargs has no effect once the model is already instantiated.
        attn_implementation="flash_attention_2",
    )

    # Create the pipeline. The dtype and device placement are taken from the
    # preloaded model, so they don't need to be passed again here.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
except Exception as e:
    error_message = f"Error loading model or tokenizer: {e}"
    print(error_message)  # Log the error to the console

    # Provide a fallback, even if it's just displaying the error.
    def error_response(message, history):
        return f"Model loading failed. Error: {error_message}"

    # Minimal Gradio interface to show the error. gr.ChatInterface is itself
    # a complete Blocks app, so no gr.Blocks() wrapper is needed.
    gr.ChatInterface(error_response).launch()
    raise SystemExit(1)  # Important: exit to prevent running the rest of the (broken) code
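# Note (an assumption about the deployment, not stated in the original):
# google/gemma-3-1b-it is a gated model on the Hugging Face Hub, so loading it
# requires accepting the model license and authenticating (e.g. via an
# HF_TOKEN secret on the Space); otherwise the except branch above is what
# users will see.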
# Chat template handling (important for correct prompting)
def apply_chat_template(messages, chat_template=None):
    """Applies the chat template to the message history.

    Args:
        messages: A list of dictionaries, where each dictionary has a "role"
            ("user" or "model") and a "content" key.
        chat_template: The chat template string (optional). If None, try to
            get it from the tokenizer.

    Returns:
        A single string representing the formatted conversation.
    """
    if chat_template is None:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            chat_template = tokenizer.chat_template
        else:
            # Fall back to a simple Gemma-style template if the tokenizer does
            # not define one. This is *critical* to prevent the model from
            # generating nonsensical output.
            chat_template = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
    )
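# For illustration, a short history rendered with the fallback template above
# looks roughly like the following (the tokenizer's built-in template may
# differ in details such as BOS handling):
#
#   <start_of_turn>user
#   Hello!<end_of_turn>
#   <start_of_turn>model
#   Hi!<end_of_turn>
#   <start_of_turn>user
#   How are you?<end_of_turn>
#   <start_of_turn>model
#
# The trailing '<start_of_turn>model\n' is the generation prompt that the
# model completes.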
# Prediction function (modified for chat)
def predict(message, history):
    """Generates a response to the user's message.

    Args:
        message: The user's input message (string).
        history: A list of (user_message, bot_response) pairs representing
            the conversation history.

    Returns:
        The generated bot response (string).
    """
    # Build the conversation history in the required format. Gemma's chat
    # format uses "model" (rather than "assistant") for the model's turns.
    messages = []
    for user_msg, bot_response in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "model", "content": bot_response})
    messages.append({"role": "user", "content": message})
    # Apply the chat template.
    prompt = apply_chat_template(messages)

    # Generate the response using the pipeline (much cleaner).
    try:
        sequences = pipe(
            prompt,
            max_new_tokens=512,  # Limit response length
            do_sample=True,  # Use sampling for more diverse responses
            temperature=0.7,  # Control randomness (higher = more random)
            top_k=50,  # Top-k sampling
            top_p=0.95,  # Nucleus sampling
            repetition_penalty=1.2,  # Reduce repetition
            pad_token_id=tokenizer.eos_token_id,  # Ensure padding is correct.
        )
        # The pipeline returns the prompt plus the completion; slice off the
        # prompt so that *only* the generated text is returned.
        response = sequences[0]["generated_text"][len(prompt):].strip()
        return response
    except Exception as e:
        return f"An error occurred during generation: {e}"
# Gradio interface (using gr.ChatInterface for a chatbot UI). As in the error
# fallback above, ChatInterface is a complete Blocks app on its own.
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500),  # Set a reasonable height
    textbox=gr.Textbox(placeholder="Ask me anything!", container=False, scale=7),
    title="Gemma-3-1b-it Chatbot",
    description="Chat with the Gemma-3-1b-it model.",
    retry_btn=None,  # Remove redundant buttons (these kwargs match the Gradio 4.x API)
    undo_btn=None,
    clear_btn=None,
)

demo.launch(share=False)  # Set share=True to create a publicly shareable link
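# A minimal requirements.txt for this Space might look like the following
# sketch (version pins are assumptions, not taken from the original):
#
#   gradio>=4.0,<5   # the ChatInterface kwargs above match the 4.x API
#   torch
#   transformers
#   accelerate       # needed for device_map="auto"
#   flash-attn       # optional, only for attn_implementation="flash_attention_2"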