kimhyunwoo's picture
Update app.py
a41650d verified
raw
history blame
5.3 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os
import gradio as gr
# --- 1. Authentication (non-interactive, via a Space secret) ---
# A Space cannot answer an interactive login prompt, so the token must come
# from an environment variable populated by a Space secret:
#   1. Open your Space's Settings.
#   2. Add a new secret named HF_TOKEN (the legacy name HUGGING_FACE_HUB_TOKEN
#      is also accepted below).
#   3. Paste a Hugging Face access token with read access as its value and save.
# login() called WITHOUT a token falls back to an interactive prompt, so the
# token is read here and passed explicitly. If no token is configured, login is
# skipped; downloads of gated models (such as Gemma) are then likely to fail.
_hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("WARNING: No HF_TOKEN / HUGGING_FACE_HUB_TOKEN secret found; skipping Hugging Face login.")
# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---
def _pick_attn_implementation():
    """Return "flash_attention_2" when it looks usable, otherwise None.

    Flash Attention 2 needs the optional ``flash_attn`` package (transformers
    raises when it is requested but not installed) and a CUDA GPU. Returning
    None means no ``attn_implementation`` is requested, so transformers keeps
    its default attention backend.
    """
    import importlib.util

    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return None


def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Load a causal language model and its tokenizer from the Hugging Face Hub.

    Args:
        model_name: Hub repository id of the model to load.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: whatever ``from_pretrained`` raises. Troubleshooting hints
            are printed first, then the original exception is re-raised so
            the caller (or Gradio) can surface it.
    """
    try:
        # Suppress non-error log messages from transformers.
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        load_kwargs = {
            "device_map": "auto",  # Place weights on GPU(s) if available, else CPU.
            "torch_dtype": torch.bfloat16,  # Half the memory of float32.
        }
        attn_implementation = _pick_attn_implementation()
        if attn_implementation is not None:
            # Only request Flash Attention 2 when it can actually be used;
            # requesting it unconditionally makes loading fail elsewhere.
            load_kwargs["attn_implementation"] = attn_implementation
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}")
        print("\nTroubleshooting Steps:")
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
        print("2. Verify your internet connection.")
        print(f"3. Double-check the model name: '{model_name}'")
        print("4. Ensure you are properly authenticated using a Repository Secret (see above).")
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
        # Re-raise instead of exiting so the error reaches the caller / Gradio.
        raise


# Load once at import time; every chat request reuses these objects.
model, tokenizer = load_model_and_tokenizer()
# --- 3. Chat Template Function (CRITICAL for conversational models) ---
def apply_chat_template(messages, tokenizer):
    """Render chat ``messages`` into one prompt string, ready for generation.

    The tokenizer's own ``chat_template`` is used when it is present and
    non-empty. Otherwise a warning is printed and a Gemma-style
    ``<start_of_turn>``/``<end_of_turn>`` template is supplied in its place.
    Either way, the model-turn generation prompt is appended and the result
    is returned as text (not token ids).

    Raises:
        Exception: any error from ``tokenizer.apply_chat_template`` is printed
            and re-raised unchanged.
    """
    render_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        if not getattr(tokenizer, "chat_template", None):
            print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
            render_kwargs["chat_template"] = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )
        return tokenizer.apply_chat_template(messages, **render_kwargs)
    except Exception as e:
        print(f"ERROR: Failed to apply chat template: {e}")
        raise  # Propagate so the caller (ultimately Gradio) can report it.
# --- 4. Text Generation Function ---
# Text-generation pipelines keyed by the identity of the (model, tokenizer)
# pair, so chat turns reuse one pipeline instead of rebuilding it per request.
# Each cached pipeline holds references to both objects, so their ids cannot
# be recycled while the entry exists.
_PIPELINE_CACHE = {}


def _get_pipeline(model, tokenizer):
    """Return a cached text-generation pipeline wrapping an already-loaded model."""
    key = (id(model), id(tokenizer))
    pipe = _PIPELINE_CACHE.get(key)
    if pipe is None:
        # The model is already instantiated and placed, so load-time options
        # (dtype, device_map, attention implementation) do not apply here.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        _PIPELINE_CACHE[key] = pipe
    return pipe


def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generate the model's reply to a chat conversation.

    Args:
        messages: List of {"role": ..., "content": ...} dicts, ending with the
            user's latest message.
        model: Loaded causal language model.
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens, temperature, top_k, top_p, repetition_penalty:
            Sampling settings forwarded to generation.

    Returns:
        The newly generated reply text, stripped of surrounding whitespace.

    Raises:
        Exception: anything raised by templating or generation (generation
            errors are printed before being re-raised).
    """
    prompt = apply_chat_template(messages, tokenizer)
    # A chat template may already emit the BOS token (Gemma's does). If so,
    # the pipeline must not let the tokenizer prepend a second one.
    bos_token = getattr(tokenizer, "bos_token", None)
    prompt_has_bos = bool(bos_token) and prompt.startswith(bos_token)
    try:
        outputs = _get_pipeline(model, tokenizer)(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
            add_special_tokens=not prompt_has_bos,
            # Return only the continuation, not prompt + continuation.
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        print(f"ERROR: Failed to generate response: {e}")
        raise  # Re-raise so predict() can show it in the UI.
# --- 5. Gradio Interface ---
def predict(message, history):
    """Handle one chat turn submitted from the Gradio UI.

    Args:
        message: Text the user just submitted.
        history: Earlier turns as (user_text, bot_text) pairs; may be None or
            empty. Turns whose bot_text is falsy contribute only the user side.

    Returns:
        ("", history) on success, where ``history`` has the new
        (message, reply) pair appended; this clears the textbox. On a
        generation error, (f"Error: {e}", history) so the error text appears
        in the textbox.
    """
    history = history or []

    conversation = []
    for user_text, bot_text in history:
        conversation.append({"role": "user", "content": user_text})
        # Skip the model side of turns that have no reply yet.
        if bot_text:
            conversation.append({"role": "model", "content": bot_text})
    conversation.append({"role": "user", "content": message})

    try:
        reply = generate_response(conversation, model, tokenizer)
        history.append((message, reply))
        return "", history
    except Exception as e:
        # Surface any failure in the UI instead of crashing the request.
        return f"Error: {e}", history
# Chat UI: a chatbot pane, an input textbox wired to predict(), and a button
# that clears both. Submitting the textbox sends (message, history) to
# predict(); its (textbox_value, history) result updates both components.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=500, label="Gemma Chatbot")
    msg = gr.Textbox(
        scale=7,
        container=False,
        placeholder="Ask me anything!",
    )
    clear = gr.ClearButton(components=[msg, chatbot])
    msg.submit(fn=predict, inputs=[msg, chatbot], outputs=[msg, chatbot])

demo.launch()