import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Model and tokenizer loading (with error handling)
try:
    model_name = "google/gemma-3-1b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency, if the hardware supports it
        device_map="auto",           # Automatically use GPU if available, otherwise CPU
        # attn_implementation="flash_attention_2",  # Optional: enable only if flash-attn is installed and the GPU supports it
    )
    # Build the pipeline around the already-loaded model and tokenizer.
    # The dtype and device placement are already set on the model above, so they
    # are not repeated here; attn_implementation belongs in from_pretrained, not
    # in the pipeline call, once the model is already instantiated.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

except Exception as e:
    error_message = f"Error loading model or tokenizer: {e}"
    print(error_message)  # Log the error to the console

    # Fall back to a minimal interface that only reports the failure.
    def error_response(message, history):
        return f"Model loading failed. Error: {error_message}"

    with gr.Blocks() as demo:
        gr.ChatInterface(error_response)
    demo.launch()
    exit()  # Stop here so the rest of the (broken) script never runs


# Chat template handling (important for correct prompting)
def apply_chat_template(messages, chat_template=None):
    """Applies the chat template to the message history.

    Args:
        messages: A list of dictionaries, where each dictionary has a "role"
            ("user" or "model") and a "content" key.
        chat_template: The chat template string (optional). If None, the
            tokenizer's own template is used when available.

    Returns:
        A single string representing the formatted conversation.
    """
    if chat_template is None:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            chat_template = tokenizer.chat_template
        else:
            # Fall back to a simple Gemma-style template if the tokenizer has none.
            # Without a template the model would receive an unformatted prompt and
            # produce nonsensical output.
            chat_template = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
    )
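
# Illustrative example of the formatting above (a sketch based on the fallback
# template; the input message is made up, and the tokenizer's own Gemma template
# would additionally prepend the <bos> token):
#
#   apply_chat_template([{"role": "user", "content": "Hi"}])
#   # -> "<start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\n"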

# Prediction function used by the chat interface
def predict(message, history):
    """Generates a response to the user's message.

    Args:
        message: The user's input message (string).
        history: A list of (user_message, bot_response) tuples representing
            the conversation history.

    Returns:
        The generated bot response (string).
    """
    # Build the conversation history in the chat-template format. The
    # (user_message, bot_response) tuples match Gradio's tuple-style
    # ChatInterface history; Gemma uses "model" as the assistant role name.
    messages = []
    for user_msg, bot_response in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "model", "content": bot_response})
    messages.append({"role": "user", "content": message})

    # Apply the chat template.
    prompt = apply_chat_template(messages)

    # Generate the response using the pipeline.
    try:
        sequences = pipe(
            prompt,
            max_new_tokens=512,      # Limit response length
            do_sample=True,          # Use sampling for more diverse responses
            temperature=0.7,         # Control randomness (higher = more random)
            top_k=50,                # Top-k sampling
            top_p=0.95,              # Nucleus sampling
            repetition_penalty=1.2,  # Reduce repetition
            pad_token_id=tokenizer.eos_token_id,  # Avoid padding warnings
            return_full_text=False,  # Return only the newly generated text, not the prompt
        )
        # With return_full_text=False there is no need to slice the prompt back off,
        # which is fragile when the template inserts special tokens.
        response = sequences[0]["generated_text"].strip()
        return response

    except Exception as e:
        return f"An error occurred during generation: {e}"


# Gradio interface (gr.ChatInterface provides the chatbot UI)
with gr.Blocks() as demo:
    gr.ChatInterface(
        predict,
        chatbot=gr.Chatbot(height=500),  # Set a reasonable height
        textbox=gr.Textbox(placeholder="Ask me anything!", container=False, scale=7),
        title="Gemma-3-1b-it Chatbot",
        description="Chat with the Gemma-3-1b-it model.",
        retry_btn=None,   # Hide the extra chat buttons (Gradio 4.x-style options)
        undo_btn=None,
        clear_btn=None,
    )

demo.launch(share=False)  # Set share=True to create a publicly shareable link
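
# How to run (a sketch; package names are the usual PyPI distributions and the
# file name depends on what this script is saved as):
#
#   pip install gradio torch transformers accelerate
#   python app.py
#
# Note: google/gemma-3-1b-it is a gated model on the Hugging Face Hub, so you may
# need to accept its license and authenticate (e.g. `huggingface-cli login`)
# before the weights will download.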