from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os
import gradio as gr
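
# The imports above map to a requirements.txt roughly like the following
# (a sketch; pin the versions you have actually tested):
#   transformers
#   accelerate    # needed for device_map="auto"
#   torch
#   gradio
#   flash-attn    # only for attn_implementation="flash_attention_2"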

# --- 1. Authentication (via environment variable - the only option on Spaces) ---

# Hugging Face Spaces cannot use interactive login; the token must come from a
# Repository Secret, which the Space exposes to the app as an environment variable:
# 1. Go to your Space's settings and click "Repository Secrets".
# 2. Click "New Secret" and name it HUGGING_FACE_HUB_TOKEN.
# 3. Paste a Hugging Face API token with read access as the value and save.

# login() with no arguments tries to prompt interactively, which fails on a
# Space, so pass the token explicitly. Gated models such as Gemma cannot be
# downloaded without it.
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("WARNING: HUGGING_FACE_HUB_TOKEN is not set; gated models may fail to load.")

# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---

def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Loads the model and tokenizer, handling potential errors."""
    try:
        # Suppress non-error log output from transformers
        logging.set_verbosity_error()

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",  # Use the GPU if one is available, else the CPU
            torch_dtype=torch.bfloat16,  # Halves memory on hardware with bfloat16 support
            # Flash Attention 2 requires the flash-attn package and a recent
            # NVIDIA GPU; loading raises an error if either is missing. Drop
            # this argument (or use "sdpa") on CPU-only hardware.
            attn_implementation="flash_attention_2",
        )
        return model, tokenizer

    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}")
        print("\nTroubleshooting Steps:")
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
        print("2. Verify your internet connection.")
        print("3. Double-check the model name: 'google/gemma-3-1b-it'")
        print("4. Ensure you are properly authenticated using a Repository Secret (see above).")
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
        # Instead of exiting, raise the exception to be caught by Gradio
        raise

model, tokenizer = load_model_and_tokenizer()
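
# Optional smoke test (illustrative; uncomment to check generation end-to-end
# before wiring up the UI):
# inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0], skip_special_tokens=True))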


# --- 3. Chat Template Function (CRITICAL for conversational models) ---

def apply_chat_template(messages, tokenizer):
    """Applies the appropriate chat template."""
    try:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
            chat_template = "{% for message in messages %}" \
                            "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
                            "{% endfor %}" \
                            "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template)

    except Exception as e:
        print(f"ERROR: Failed to apply chat template: {e}")
        raise # Re-raise to be caught by Gradio
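
# Illustration: for messages = [{"role": "user", "content": "Hi"}], a
# Gemma-style template renders roughly (exact bos/whitespace handling
# depends on the tokenizer):
#
#   <start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
#
# The open "model" turn at the end is the generation prompt the model completes.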


# --- 4. Text Generation Function ---

# Build the pipeline once at startup; recreating it on every request adds
# noticeable latency. The dtype, device placement, and attention implementation
# are inherited from the already-loaded model, so they are not repeated here.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)


def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generates a response to the given chat messages."""
    prompt = apply_chat_template(messages, tokenizer)

    try:
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )

        generated_text = outputs[0]["generated_text"][len(prompt):].strip()
        return generated_text

    except Exception as e:
        print(f"ERROR: Failed to generate response: {e}")
        raise # Re-raise the exception
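
# Usage sketch (illustrative):
#   messages = [{"role": "user", "content": "Name three primary colors."}]
#   print(generate_response(messages, model, tokenizer, max_new_tokens=64))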


# --- 5. Gradio Interface ---

def predict(message, history):
    """Gradio callback: rebuilds the chat history in the messages format and appends the new turn."""
    history = history or []
    messages = []
    for user_msg, bot_response in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_response:  # Skip turns whose response is still None
            # "assistant" is the standard role name; Gemma's chat template maps it to "model".
            messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})

    try:
        response = generate_response(messages, model, tokenizer)
        history.append((message, response))
    except Exception as e:
        # Surface generation errors in the chat window rather than the textbox
        history.append((message, f"Error: {e}"))
    return "", history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Gemma Chatbot", height=500)
    msg = gr.Textbox(placeholder="Ask me anything!", container=False, scale=7)
    clear = gr.ClearButton([msg, chatbot])

    msg.submit(predict, [msg, chatbot], [msg, chatbot])

demo.launch()
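
# Note: on Spaces, launch() binds to the host/port the platform provides.
# If you expect concurrent users, demo.queue().launch() serializes requests
# so a single GPU is not oversubscribed.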