import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Model and tokenizer loading (with error handling)
try:
    model_name = "google/gemma-3-1b-it"  # Instruction-tuned Gemma 3 1B checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency, if supported
        device_map="auto",           # Automatically use GPU if available, otherwise CPU
        # attn_implementation="flash_attention_2",  # Uncomment to enable Flash Attention 2,
        # if supported by your hardware and transformers version
    )

    # Create the pipeline. The dtype and device placement are inherited from the
    # already-loaded model, so they do not need to be repeated here.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

except Exception as e:
    error_message = f"Error loading model or tokenizer: {e}"
    print(error_message)  # Log the error to the console

    # Provide a fallback, even if it's just displaying the error.
    def error_response(message, history):
        return f"Model loading failed. Error: {error_message}"

    # Minimal Gradio interface to show the error
    with gr.Blocks() as demo:
        gr.ChatInterface(error_response)
    demo.launch()
    exit()  # Important: exit to prevent running the rest of the (broken) code


# Chat template handling (important for correct prompting)
def apply_chat_template(messages, chat_template=None):
    """Applies the chat template to the message history.

    Args:
        messages: A list of dictionaries, each with a "role" key ("user" or
            "model", the role Gemma uses for assistant turns) and a "content" key.
        chat_template: The chat template string (optional). If None, try to get
            it from the tokenizer.

    Returns:
        A single string representing the formatted conversation.
    """
    if chat_template is None:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            chat_template = tokenizer.chat_template
        else:
            # Fallback to a simple template (using Gemma's <start_of_turn> /
            # <end_of_turn> markers) if no chat template is found. This is
            # *critical* to prevent the model from generating nonsensical output.
            chat_template = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )

    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        chat_template=chat_template,
    )


# Prediction function (modified for chat)
def predict(message, history):
    """Generates a response to the user's message.

    Args:
        message: The user's input message (string).
        history: A list of (user_message, bot_response) tuples representing the
            conversation history.

    Returns:
        The generated bot response (string).
    """
    # Build the conversation history in the required format.
    messages = []
    for user_msg, bot_response in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "model", "content": bot_response})
    messages.append({"role": "user", "content": message})

    # Apply the chat template.
    prompt = apply_chat_template(messages)

    # Generate the response using the pipeline (much cleaner).
    try:
        sequences = pipe(
            prompt,
            max_new_tokens=512,      # Limit response length
            do_sample=True,          # Use sampling for more diverse responses
            temperature=0.7,         # Control randomness (higher = more random)
            top_k=50,                # Top-k sampling
            top_p=0.95,              # Nucleus sampling
            repetition_penalty=1.2,  # Reduce repetition
            pad_token_id=tokenizer.eos_token_id,  # Ensure padding is correct.
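            # Note: return_full_text=False (a standard text-generation pipeline
            # argument) could be used instead of slicing the prompt off manually below.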
        )
        # Extract *only* the newly generated text by stripping the prompt prefix.
        response = sequences[0]['generated_text'][len(prompt):].strip()
        return response
    except Exception as e:
        return f"An error occurred during generation: {e}"


# Gradio interface (using gr.ChatInterface for a chatbot UI)
with gr.Blocks() as demo:
    gr.ChatInterface(
        predict,
        chatbot=gr.Chatbot(height=500),  # Set a reasonable height
        textbox=gr.Textbox(placeholder="Ask me anything!", container=False, scale=7),
        title="Gemma-3-1b-it Chatbot",
        description="Chat with the Gemma-3-1b-it model.",
        # The three button kwargs below exist in Gradio 4.x but were removed in
        # Gradio 5.x; drop them if you are on a newer release.
        retry_btn=None,  # Remove redundant buttons
        undo_btn=None,
        clear_btn=None,
    )

demo.launch(share=False)  # Set share=True to create a publicly shareable link
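
# Rough local setup (assumed dependency set and filename; adjust as needed):
#   pip install gradio torch transformers accelerate
#   python app.py
# Note: Gemma checkpoints on the Hugging Face Hub are gated, so you may need to
# accept the license and authenticate (e.g. `huggingface-cli login`) first.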