kimhyunwoo's picture
Update app.py
c5ec987 verified
raw
history blame
4.63 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os
import gradio as gr
# --- 1. Authentication (Using User-Provided Token) ---
def authenticate(token):
    """Log in to the Hugging Face Hub with a user-supplied access token.

    Args:
        token: Access token typed by the user. May be ``None`` or blank.

    Returns:
        ``True`` if ``login`` accepted the token, ``False`` otherwise
        (including when no usable token was supplied).
    """
    # Tolerate stray whitespace from copy/paste; treat None/blank as "no token".
    token = (token or "").strip()
    if not token:
        # A blank token can never be valid, and login(token=None) would fall
        # back to an interactive prompt (per huggingface_hub docs), which must
        # not happen inside a server callback. Reject locally instead.
        return False
    try:
        login(token=token)
    except Exception as e:  # login can fail with several error types (HTTP, ValueError, ...)
        print(f"Authentication failed: {e}")
        return False
    return True
# --- 2. Model and Tokenizer Setup ---
def load_model_and_tokenizer(model_name="google/gemma-3-1b-it", attn_implementation="flash_attention_2"):
    """Load the causal-LM model and its tokenizer.

    The model is loaded in bfloat16 with ``device_map="auto"``.

    Args:
        model_name: Hub repository id (or local path) of the model.
        attn_implementation: Preferred attention backend. If transformers
            cannot enable it in this environment (it raises ImportError or
            ValueError, e.g. when ``flash_attn`` is not installed or no
            supported GPU is present), loading is retried once with the
            transformers default backend instead of failing.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: Whatever the final loading attempt raised is printed and
            re-raised so the caller can surface it.
    """
    try:
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        load_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16}
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, attn_implementation=attn_implementation, **load_kwargs
            )
        except (ImportError, ValueError) as e:
            if attn_implementation is None:
                # Already on the default backend; nothing left to fall back to.
                raise
            print(
                f"WARNING: attn_implementation={attn_implementation!r} unavailable "
                f"({e}); falling back to the default attention implementation."
            )
            model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model/tokenizer: {e}")
        raise  # Re-raise for Gradio
# --- 3. Chat Template Function ---
def apply_chat_template(messages, tokenizer):
    """Render a conversation into one untokenized prompt string.

    If the tokenizer has a non-empty ``chat_template`` it is used as-is.
    Otherwise a warning is printed and a Gemma-style
    ``<start_of_turn>``/``<end_of_turn>`` template is supplied instead.
    In both cases the opening of the model's turn is appended
    (``add_generation_prompt=True``).

    Raises:
        Exception: Any rendering error is printed and re-raised.
    """
    try:
        own_template = getattr(tokenizer, "chat_template", None)
        if own_template:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        print("WARNING: Tokenizer lacks chat_template. Using fallback.")
        fallback_template = (
            "{% for message in messages %}"
            "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
        )
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            chat_template=fallback_template,
        )
    except Exception as err:
        print(f"ERROR: Chat template application failed: {err}")
        raise
# --- 4. Text Generation Function ---
def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Sample the model's reply to a conversation.

    Args:
        messages: Conversation as a list of ``{"role", "content"}`` dicts, in
            the form accepted by ``apply_chat_template``.
        model: Loaded causal-LM model (anything exposing ``generate`` and
            ``device``).
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens, temperature, top_k, top_p, repetition_penalty:
            Sampling settings forwarded to ``model.generate``.

    Returns:
        Only the newly generated text (the prompt is excluded), decoded with
        special tokens skipped and surrounding whitespace stripped.

    Raises:
        Exception: Errors from prompt rendering propagate unchanged. Errors
            from tokenization/generation are printed and re-raised.
    """
    prompt = apply_chat_template(messages, tokenizer)
    try:
        # The rendered prompt may already start with BOS (chat templates
        # usually emit it). Let the tokenizer add special tokens only when it
        # does not, so the model never sees a duplicated BOS.
        bos = getattr(tokenizer, "bos_token", None)
        add_special_tokens = not (bos and prompt.startswith(bos))
        inputs = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(model.device)

        # Reuse the already-loaded model directly instead of building a new
        # pipeline object on every request.
        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the tokens produced after the prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"ERROR: Response generation failed: {e}")
        raise
# --- 5. Gradio Interface ---
# Module-level cache for the model/tokenizer pair: both start unset and are
# meant to be filled in lazily on first use by the chat callback.
model = tokenizer = None
def chat(token, message, history):
    """Gradio callback: answer ``message`` in the context of ``history``.

    Authenticates with ``token`` on every call. On first use it loads the
    model/tokenizer into the module-level globals ``model`` and ``tokenizer``.

    Args:
        token: Hugging Face access token entered by the user.
        message: The new user message.
        history: Chatbot history as a sequence of ``(user, bot)`` pairs, or
            empty/``None`` for a fresh conversation.

    Returns:
        ``(textbox_value, history)``. On success: ``""`` (clears the message
        box) and the history with ``(message, response)`` appended. On any
        failure: an error string for the message box and the history as
        received.
    """
    global model, tokenizer  # Access the global model and tokenizer
    if not authenticate(token):
        return "Authentication failed. Please enter a valid Hugging Face token.", history
    if model is None or tokenizer is None:
        try:
            model, tokenizer = load_model_and_tokenizer()
        except Exception as e:
            return f"Model loading error: {e}", history
    if not history:
        history = []

    # Rebuild the conversation in chronological order: each past user turn is
    # immediately followed by the model's reply to it. Chat templates such as
    # Gemma's reject conversations whose user/model roles do not alternate, so
    # a missing reply is kept as an empty model turn rather than dropped.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "model", "content": bot_msg or ""})
    messages.append({"role": "user", "content": message})

    try:
        response = generate_response(messages, model, tokenizer)
    except Exception as e:
        return f"Error during generation: {e}", history
    history.append((message, response))
    return "", history
with gr.Blocks() as demo:
    # Title and instructions.
    gr.Markdown("# Gemma Chatbot")
    gr.Markdown("Enter your Hugging Face API token (read access required):")

    # Input widgets; the token field is masked so it is not shown on screen.
    token_input = gr.Textbox(label="Hugging Face Token", type="password")
    chatbot = gr.Chatbot(label="Chat", height=400)
    msg_input = gr.Textbox(label="Message", placeholder="Ask me anything!")
    clear_btn = gr.ClearButton([msg_input, chatbot])

    # Pressing Enter in the message box calls chat(); its two return values
    # go to the message box and the conversation display, respectively.
    msg_input.submit(
        fn=chat,
        inputs=[token_input, msg_input, chatbot],
        outputs=[msg_input, chatbot],
    )
    # Clearing also explicitly resets the message box and the conversation.
    clear_btn.click(
        fn=lambda: (None, []),
        inputs=[],
        outputs=[msg_input, chatbot],
    )

demo.launch()