from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os

# --- 1. Authentication (Choose ONE method and follow the instructions) ---
# Method 1: Environment Variable (RECOMMENDED for security and Hugging Face Spaces)
# - Set the HUGGING_FACE_HUB_TOKEN environment variable *before* running.
# - Linux/macOS: `export HUGGING_FACE_HUB_TOKEN=your_token` (in terminal)
# - Windows (PowerShell): `$env:HUGGING_FACE_HUB_TOKEN = "your_token"`
# - Hugging Face Spaces: Add `HUGGING_FACE_HUB_TOKEN` as a secret in your Space's settings.
# - The call below then logs in with the token read from the environment:
if os.environ.get("HUGGING_FACE_HUB_TOKEN"):
    login(token=os.environ["HUGGING_FACE_HUB_TOKEN"])

# Method 2: Direct Token (ONLY for local testing, NOT for deployment)
# - Replace "YOUR_HUGGING_FACE_TOKEN" with your actual token.
# - WARNING: Do NOT commit your token to a public repository!
# login(token="YOUR_HUGGING_FACE_TOKEN")

# Method 3: huggingface-cli (Interactive, one-time setup, good for local development)
# - Run `huggingface-cli login` in your terminal.
# - Paste your token when prompted.
# - No code changes are needed after this; the token is stored.
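
# Optional sanity check (a minimal sketch, commented out): `whoami()` from huggingface_hub
# raises if no valid token is available, so authentication problems surface before the
# model download starts.
# from huggingface_hub import whoami
# print(f"Authenticated to the Hub as: {whoami()['name']}")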

# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---
def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Loads the model and tokenizer, handling potential errors."""
    try:
        # Suppress unnecessary warning messages from transformers
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",           # Automatically use GPU if available, else CPU
            torch_dtype=torch.bfloat16,  # Use bfloat16 for speed/memory if supported
            # Flash Attention 2 requires the flash-attn package and a compatible GPU;
            # remove this argument to fall back to the default attention implementation.
            attn_implementation="flash_attention_2",
        )
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}")
        print("\nTroubleshooting Steps:")
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
        print("2. Verify your internet connection.")
        print("3. Double-check the model name: 'google/gemma-3-1b-it'")
        print("4. Ensure you are properly authenticated (see authentication section above).")
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
        exit(1)  # Exit with an error code


model, tokenizer = load_model_and_tokenizer()
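
# Optional smoke test (commented out): a minimal check that the weights loaded and can
# generate before wiring up the chat loop below.
# print(f"Model on {model.device}, dtype {model.dtype}")
# _inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**_inputs, max_new_tokens=10)[0], skip_special_tokens=True))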

# --- 3. Chat Template Function (CRITICAL for conversational models) ---
def apply_chat_template(messages, tokenizer):
    """Applies the appropriate chat template to the message history.

    Args:
        messages: A list of dictionaries, where each dictionary has 'role' (user/model)
            and 'content' keys.
        tokenizer: The tokenizer object.

    Returns:
        A formatted prompt string ready for the model.
    """
    try:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            # Use the tokenizer's built-in chat template if available
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Fallback to a standard chat template if no specific one is found
            print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
            chat_template = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
            )
    except Exception as e:
        print(f"ERROR: Failed to apply chat template: {e}")
        exit(1)
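
# Example of the prompt produced by the fallback template above (the tokenizer's own
# chat_template may add extra tokens such as <bos>):
#   apply_chat_template([{"role": "user", "content": "Hi"}], tokenizer)
#   -> "<start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\n"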

# --- 4. Text Generation Function ---
def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7,
                      top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generates a response using the model and tokenizer."""
    prompt = apply_chat_template(messages, tokenizer)
    try:
        # Wrap the already-loaded model in a pipeline. The dtype, device placement, and
        # attention implementation were set in load_model_and_tokenizer(), so they are not
        # passed again here. Rebuilding the pipeline on every call keeps the example simple;
        # for heavier use, create it once and reuse it.
        pipeline_instance = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )
        outputs = pipeline_instance(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,  # Important for proper padding
        )
        # Extract *only* the generated text (remove the prompt)
        generated_text = outputs[0]["generated_text"][len(prompt):].strip()
        return generated_text
    except Exception as e:
        print(f"ERROR: Failed to generate response: {e}")
        return "Sorry, I encountered an error while generating a response."

# --- 5. Main Interaction Loop (for command-line interaction) ---
def main():
    """Main function for interactive command-line chat."""
    messages = []  # Initialize the conversation history
    while True:
        user_input = input("You: ")
        if user_input.lower() in ("exit", "quit", "bye"):
            break
        messages.append({"role": "user", "content": user_input})
        response = generate_response(messages, model, tokenizer)
        print(f"Model: {response}")
        messages.append({"role": "model", "content": response})


if __name__ == "__main__":
    main()