kimhyunwoo's picture
Update app.py
da1470a verified
raw
history blame
5.97 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os
# --- 1. Authentication (Choose ONE method and follow the instructions) ---
# Method 1: Environment Variable (RECOMMENDED for security and Hugging Face Spaces)
# - Set the HF_TOKEN (or legacy HUGGING_FACE_HUB_TOKEN) environment variable *before* running.
# - Linux/macOS: `export HF_TOKEN=your_token` (in terminal)
# - Windows (PowerShell): `$env:HF_TOKEN = "your_token"`
# - Hugging Face Spaces: Add `HF_TOKEN` as a secret in your Space's settings.
# - The code below picks the token up automatically; no edits are needed.
_hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    # Pass the token explicitly: a bare `login()` prompts for a token
    # interactively, which cannot work on Spaces or other non-interactive hosts.
    login(token=_hf_token)
# Method 2: Direct Token (ONLY for local testing, NOT for deployment)
# - Replace "YOUR_HUGGING_FACE_TOKEN" with your actual token.
# - WARNING: Do NOT commit your token to a public repository!
# login(token="YOUR_HUGGING_FACE_TOKEN")
# Method 3: huggingface-cli (Interactive, one-time setup, good for local development)
# - Run `huggingface-cli login` in your terminal.
# - Paste your token when prompted.
# - No code changes are needed after this; the token is stored.
# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---
def _pick_attn_implementation():
    """Return "flash_attention_2" only when it can be used, else None.

    Flash Attention 2 needs a CUDA device and the optional ``flash_attn``
    package; requesting it without both makes ``from_pretrained`` fail.
    ``None`` lets transformers choose its default attention backend.
    """
    import importlib.util

    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return None


def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Load the causal-LM model and its tokenizer, exiting the process on failure.

    Args:
        model_name: Hugging Face Hub repo id (or local path) to load.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        SystemExit: with code 1 if the tokenizer or model cannot be loaded;
            troubleshooting hints are printed to stderr first.
    """
    import sys

    try:
        # Suppress unnecessary warning messages from transformers
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",  # Automatically use GPU if available, else CPU
            torch_dtype=torch.bfloat16,  # 2 bytes per weight instead of 4 (float32)
            # Request Flash Attention 2 only when it is actually usable, so
            # CPU-only hosts do not fail at load time.
            attn_implementation=_pick_attn_implementation(),
        )
    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}", file=sys.stderr)
        print("\nTroubleshooting Steps:", file=sys.stderr)
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.", file=sys.stderr)
        print("2. Verify your internet connection.", file=sys.stderr)
        print(f"3. Double-check the model name: '{model_name}'", file=sys.stderr)
        print("4. Ensure you are properly authenticated (see authentication section above).", file=sys.stderr)
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.", file=sys.stderr)
        raise SystemExit(1) from e  # Exit with an error code
    return model, tokenizer


# Loaded once at import time; main() reuses these globals for every turn.
model, tokenizer = load_model_and_tokenizer()
# --- 3. Chat Template Function (CRITICAL for conversational models) ---
def apply_chat_template(messages, tokenizer):
    """Render the conversation history into a single prompt string.

    Args:
        messages: A list of dicts with 'role' and 'content' keys. Roles are
            'user' / 'model'; the common 'assistant' role is also accepted and
            is mapped to 'model' when the Gemma-style fallback template is used.
        tokenizer: A tokenizer exposing ``apply_chat_template`` and, optionally,
            a ``chat_template`` attribute.

    Returns:
        The formatted prompt string, ending with the generation prompt for the
        model's next turn.

    Raises:
        SystemExit: with code 1 if rendering the template fails.
    """
    import sys

    try:
        if getattr(tokenizer, "chat_template", None):
            # Use the tokenizer's built-in chat template if available
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        # Fallback to a Gemma-style template if no specific one is found
        print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.", file=sys.stderr)
        chat_template = (
            "{% for message in messages %}"
            "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
        )
        # The fallback names the assistant turn "model"; translate "assistant"
        # so its turn markers are valid. Builds new dicts; the caller's list is untouched.
        gemma_messages = [
            {**message, "role": "model"} if message.get("role") == "assistant" else message
            for message in messages
        ]
        return tokenizer.apply_chat_template(
            gemma_messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
        )
    except Exception as e:
        print(f"ERROR: Failed to apply chat template: {e}", file=sys.stderr)
        raise SystemExit(1) from e
# --- 4. Text Generation Function ---
# Text-generation pipelines built by generate_response, keyed by
# (id(model), id(tokenizer)). Each cached pipeline holds references to its
# model and tokenizer, so those ids cannot be recycled while the entry exists.
_PIPELINE_CACHE = {}


def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generate the model's next turn for the given conversation history.

    Args:
        messages: Conversation history as a list of {'role', 'content'} dicts.
        model: An already-loaded causal-LM model.
        tokenizer: The tokenizer matching ``model``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_k, top_p, repetition_penalty: Sampling parameters
            forwarded to the pipeline (sampling is always enabled).

    Returns:
        The newly generated text with surrounding whitespace stripped, or a
        fixed apology string if generation raises.
    """
    prompt = apply_chat_template(messages, tokenizer)
    try:
        cache_key = (id(model), id(tokenizer))
        pipeline_instance = _PIPELINE_CACHE.get(cache_key)
        if pipeline_instance is None:
            # Build the pipeline once and reuse it on later turns. The model is
            # already loaded with its dtype, device placement and attention
            # backend, so no model-loading kwargs are needed here.
            pipeline_instance = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
            _PIPELINE_CACHE[cache_key] = pipeline_instance
        outputs = pipeline_instance(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,  # Important for proper padding
            return_full_text=False,  # Return only the continuation, not the prompt
        )
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        print(f"ERROR: Failed to generate response: {e}")
        return "Sorry, I encountered an error while generating a response."
# --- 5. Main Interaction Loop (for command-line interaction) ---
def main():
    """Run an interactive command-line chat until the user quits.

    The session ends on "exit", "quit" or "bye" (case-insensitive, surrounding
    whitespace ignored), on end of input (Ctrl-D / exhausted stdin), or on
    Ctrl-C while waiting for input. Blank lines are skipped rather than sent to
    the model. The whole history is passed to the model on every turn.
    """
    messages = []  # Conversation history: alternating user / model turns
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leave cleanly instead of dumping a traceback.
            print()
            break
        if not user_input:
            continue
        if user_input.lower() in ("exit", "quit", "bye"):
            break
        messages.append({"role": "user", "content": user_input})
        response = generate_response(messages, model, tokenizer)
        print(f"Model: {response}")
        messages.append({"role": "model", "content": response})


if __name__ == "__main__":
    main()