from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import gradio as gr
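
# NOTE: this script assumes the following packages are installed (the install
# command below is illustrative, not part of the original source):
#   pip install torch transformers accelerate gradio huggingface_hub
# attn_implementation="flash_attention_2" additionally requires the flash-attn
# package and a compatible CUDA GPU; drop that argument if it is unavailable.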

# --- 1. Authentication (Using User-Provided Token) ---

def authenticate(token):
    """Attempts to authenticate with the provided token."""
    try:
        login(token=token)
        return True
    except Exception as e:
        print(f"Authentication failed: {e}")
        return False

# --- 2. Model and Tokenizer Setup ---

def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Loads the model and tokenizer."""
    try:
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        )
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model/tokenizer: {e}")
        raise  # Re-raise for Gradio
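
# Illustrative standalone usage (assumes a Hugging Face account that has
# accepted the Gemma license and enough GPU memory for the chosen model):
#   model, tokenizer = load_model_and_tokenizer("google/gemma-3-1b-it")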

# --- 3. Chat Template Function ---

def apply_chat_template(messages, tokenizer):
    """Applies the chat template."""
    try:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            print("WARNING: Tokenizer lacks chat_template. Using fallback.")
            chat_template = "{% for message in messages %}" \
                            "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
                            "{% endfor %}" \
                            "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template)
    except Exception as e:
        print(f"ERROR: Chat template application failed: {e}")
        raise
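
# The message format expected above, shown only as an illustration (Gemma's
# template uses the roles "user" and "model"):
#   messages = [
#       {"role": "user", "content": "Hello"},
#       {"role": "model", "content": "Hi! How can I help?"},
#       {"role": "user", "content": "Tell me a joke."},
#   ]
#   prompt = apply_chat_template(messages, tokenizer)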

# --- 4. Text Generation Function ---

def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generates a response."""
    prompt = apply_chat_template(messages, tokenizer)
    try:
        # The model was already loaded with its dtype, device map, and attention
        # implementation, so the pipeline only needs the model and tokenizer.
        pipeline_instance = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
        outputs = pipeline_instance(
            prompt, max_new_tokens=max_new_tokens, do_sample=True,
            temperature=temperature, top_k=top_k, top_p=top_p,
            repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id
        )
        return outputs[0]["generated_text"][len(prompt):].strip()
    except Exception as e:
        print(f"ERROR: Response generation failed: {e}")
        raise
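
# Possible optimization (not part of the original flow): construct the
# text-generation pipeline once, e.g. alongside load_model_and_tokenizer or in
# a module-level cache, and reuse it here instead of rebuilding it per call.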

# --- 5. Gradio Interface ---
model = None  # Initialize model and tokenizer as global variables
tokenizer = None

def chat(token, message, history):
    global model, tokenizer  # Access the global model and tokenizer

    if not authenticate(token):
        return "Authentication failed. Please enter a valid Hugging Face token.", history

    if model is None or tokenizer is None:
        try:
            model, tokenizer = load_model_and_tokenizer()
        except Exception as e:
            return f"Model loading error: {e}", history

    if not history:
        history = []
    # Rebuild the conversation as alternating user/model turns for the chat template.
    messages = []
    for user_msg, bot_resp in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_resp:
            messages.append({"role": "model", "content": bot_resp})
    messages.append({"role": "user", "content": message})

    try:
        response = generate_response(messages, model, tokenizer)
        history.append((message, response))
        return "", history
    except Exception as e:
        return f"Error during generation: {e}", history

with gr.Blocks() as demo:
    gr.Markdown("# Gemma Chatbot")
    gr.Markdown("Enter your Hugging Face API token (read access required):")
    token_input = gr.Textbox(label="Hugging Face Token", type="password") # Use type="password"
    chatbot = gr.Chatbot(label="Chat", height=400)
    msg_input = gr.Textbox(label="Message", placeholder="Ask me anything!")
    clear_btn = gr.ClearButton([msg_input, chatbot])

    msg_input.submit(chat, [token_input, msg_input, chatbot], [msg_input, chatbot])


demo.launch()
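
# When running in a notebook or on a remote host, demo.launch(share=True)
# exposes a temporary public URL; adding demo.queue() before launch() is a
# reasonable option if several users may generate responses at once.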