from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import sys
import os

# --- 1. Authentication (Choose ONE method and follow the instructions) ---

# Method 1: Environment Variable (RECOMMENDED for security and Hugging Face Spaces)
#   - Set the HUGGING_FACE_HUB_TOKEN environment variable *before* running.
#     - Linux/macOS: `export HUGGING_FACE_HUB_TOKEN=your_token` (in terminal)
#     - Windows (PowerShell): `$env:HUGGING_FACE_HUB_TOKEN = "your_token"`
#     - Hugging Face Spaces: Add `HUGGING_FACE_HUB_TOKEN` as a secret in your Space's settings.
#   - Then, uncomment the following line:
# login()

# Method 2: Direct Token (ONLY for local testing, NOT for deployment)
#   - Replace "YOUR_HUGGING_FACE_TOKEN" with your actual token.
#   - WARNING: Do NOT commit your token to a public repository!
# login(token="YOUR_HUGGING_FACE_TOKEN")

# Method 3: huggingface-cli (Interactive, one-time setup, good for local development)
#   - Run `huggingface-cli login` in your terminal.
#   - Paste your token when prompted.
#   - No code changes are needed after this; the token is stored.


# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---

def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Loads the model and tokenizer, handling potential errors."""
    try:
        # Suppress unnecessary warning messages from transformers
        logging.set_verbosity_error()

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",           # Automatically use GPU if available, else CPU
            torch_dtype=torch.bfloat16,  # Use bfloat16 for speed/memory if supported
            attn_implementation="flash_attention_2",  # Flash Attention 2; requires the flash-attn package
        )
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}")
        print("\nTroubleshooting Steps:")
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
        print("2. Verify your internet connection.")
        print("3. Double-check the model name: 'google/gemma-3-1b-it'")
        print("4. Ensure you are properly authenticated (see authentication section above).")
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
        sys.exit(1)  # Exit with an error code


model, tokenizer = load_model_and_tokenizer()


# --- 3. Chat Template Function (CRITICAL for conversational models) ---

def apply_chat_template(messages, tokenizer):
    """Applies the appropriate chat template to the message history.

    Args:
        messages: A list of dictionaries, where each dictionary has
            'role' (user/model) and 'content' keys.
        tokenizer: The tokenizer object.

    Returns:
        A formatted prompt string ready for the model.
    """
    try:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            # Use the tokenizer's built-in chat template if available
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Fall back to a generic Gemma-style template if none is defined
            print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
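            # Illustration: for messages = [{"role": "user", "content": "Hello!"}]
            # with add_generation_prompt=True, the fallback template below renders:
            #   <start_of_turn>user
            #   Hello!<end_of_turn>
            #   <start_of_turn>model
            # i.e. Gemma's turn format, with the model turn left open for generation.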
Using a fallback.") chat_template = "{% for message in messages %}" \ "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \ "{% endfor %}" \ "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}" return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template) except Exception as e: print(f"ERROR: Failed to apply chat template: {e}") exit(1) # --- 4. Text Generation Function --- def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2): """Generates a response using the model and tokenizer.""" prompt = apply_chat_template(messages, tokenizer) try: pipeline_instance = pipeline( "text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, # Make sure pipeline also uses correct dtype device_map="auto", # and device mapping model_kwargs={"attn_implementation": "flash_attention_2"} ) outputs = pipeline_instance( prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id, # Important for proper padding ) # Extract *only* the generated text (remove the prompt) generated_text = outputs[0]["generated_text"][len(prompt):].strip() return generated_text except Exception as e: print(f"ERROR: Failed to generate response: {e}") return "Sorry, I encountered an error while generating a response." # --- 5. Main Interaction Loop (for command-line interaction) --- def main(): """Main function for interactive command-line chat.""" messages = [] # Initialize the conversation history while True: user_input = input("You: ") if user_input.lower() in ("exit", "quit", "bye"): break messages.append({"role": "user", "content": user_input}) response = generate_response(messages, model, tokenizer) print(f"Model: {response}") messages.append({"role": "model", "content": response}) if __name__ == "__main__": main()