from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os
import gradio as gr
# --- 1. Authentication (Using Environment Variable - the ONLY correct way for Spaces) ---
# Hugging Face Spaces CANNOT use interactive login. You MUST use an environment variable.
# 1. Go to your Space's settings.
# 2. Click on "Repository Secrets".
# 3. Click "New Secret".
# 4. Name the secret: HUGGING_FACE_HUB_TOKEN
# 5. Paste your Hugging Face API token (with read access) as the value.
# 6. Save the secret.
# The login() call below reads the token from that environment variable.
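# Optional sanity check (illustrative sketch; assumes the secret is named
# HUGGING_FACE_HUB_TOKEN exactly as described above): fail fast with a clear
# message if the secret was never configured.
if "HUGGING_FACE_HUB_TOKEN" not in os.environ:
    raise RuntimeError(
        "HUGGING_FACE_HUB_TOKEN is not set. Add it as a Repository Secret in the Space settings."
    )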
login(token=os.environ.get("HUGGING_FACE_HUB_TOKEN"))
# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---
def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Loads the model and tokenizer, handling potential errors."""
    try:
        # Suppress unnecessary warning messages from transformers
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",           # Automatically use GPU if available, else CPU
            torch_dtype=torch.bfloat16,  # Use bfloat16 to reduce memory use
            attn_implementation="flash_attention_2",  # Requires the flash-attn package and a compatible GPU; remove this argument otherwise
        )
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}")
        print("\nTroubleshooting Steps:")
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
        print("2. Verify your internet connection.")
        print("3. Double-check the model name: 'google/gemma-3-1b-it'")
        print("4. Ensure you are properly authenticated using a Repository Secret (see above).")
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
        # Re-raise instead of exiting so the Space fails fast and the error shows up in the logs
        raise

model, tokenizer = load_model_and_tokenizer()
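# The call above runs once at module import, i.e. when the Space starts up;
# every Gradio request below reuses the same loaded model and tokenizer.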
# --- 3. Chat Template Function (CRITICAL for conversational models) ---
def apply_chat_template(messages, tokenizer):
    """Applies the appropriate chat template."""
    try:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
            chat_template = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
            )
    except Exception as e:
        print(f"ERROR: Failed to apply chat template: {e}")
        raise  # Re-raise so predict() can surface the error in the UI
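# For reference, the fallback template above renders a history such as
#   [{"role": "user", "content": "Hi"}, {"role": "model", "content": "Hello!"}]
# (with add_generation_prompt=True) as:
#   <start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
#   Hello!<end_of_turn>
#   <start_of_turn>model
# which matches the Gemma-style turn markers.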
# --- 4. Text Generation Function ---
def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generates a response."""
    prompt = apply_chat_template(messages, tokenizer)
    try:
        # The model already carries its dtype, device placement, and attention implementation
        # from load time, so they are not passed again here. (Building the pipeline once at
        # module level would also avoid re-creating it on every request.)
        pipeline_instance = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )
        outputs = pipeline_instance(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
        # The pipeline returns the prompt plus the completion; keep only the new text.
        generated_text = outputs[0]["generated_text"][len(prompt):].strip()
        return generated_text
    except Exception as e:
        print(f"ERROR: Failed to generate response: {e}")
        raise  # Re-raise so predict() can surface the error in the UI
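# Illustrative usage:
#   generate_response([{"role": "user", "content": "Hello"}], model, tokenizer)
# returns only the newly generated text, since the prompt prefix is stripped off above.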
# --- 5. Gradio Interface ---
def predict(message, history):
    if not history:
        history = []
    messages = []
    for user_msg, bot_response in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_response:  # Check if bot_response is not None
            messages.append({"role": "model", "content": bot_response})
    messages.append({"role": "user", "content": message})
    try:
        response = generate_response(messages, model, tokenizer)
        history.append((message, response))
        return "", history
    except Exception as e:
        # Catch any exception during generation and show it in the chat window
        history.append((message, f"Error: {e}"))
        return "", history
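# Note: history is the list of (user_message, bot_response) tuples that gr.Chatbot uses
# by default; predict() converts it to the role/content format the chat template expects.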
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Gemma Chatbot", height=500)
    msg = gr.Textbox(placeholder="Ask me anything!", container=False, scale=7)
    clear = gr.ClearButton([msg, chatbot])
    msg.submit(predict, [msg, chatbot], [msg, chatbot])

demo.launch()
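# Assumed dependencies for the Space's requirements.txt (not shown in this file):
# transformers, torch, gradio, and accelerate (needed for device_map="auto").
# Add flash-attn only if attn_implementation="flash_attention_2" is kept above.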