from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os
import gradio as gr
# --- 1. Authentication (Using User-Provided Token) ---
def authenticate(token):
    """Attempts to authenticate with the provided token."""
    try:
        login(token=token)
        return True
    except Exception as e:
        print(f"Authentication failed: {e}")
        return False
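# `login()` validates the token against the Hub and raises on failure, so the
# try/except above is sufficient. A minimal sketch of an optional fallback,
# assuming the Space defines the conventional HF_TOKEN secret (an assumption,
# not part of the original app):
def get_default_token():
    """Returns the HF_TOKEN environment variable, or None if it is unset."""
    return os.environ.get("HF_TOKEN")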
# --- 2. Model and Tokenizer Setup ---
def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Loads the model and tokenizer."""
    try:
        logging.set_verbosity_error()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        )
        return model, tokenizer
    except Exception as e:
        print(f"ERROR: Failed to load model/tokenizer: {e}")
        raise  # Re-raise so the Gradio handler can surface the error
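# flash_attention_2 requires the optional flash-attn package and a recent GPU;
# on CPU-only Spaces the load above will fail. A hedged variant that falls back
# to the default attention implementation (a sketch, not part of the original app):
def load_model_and_tokenizer_safe(model_name="google/gemma-3-1b-it"):
    """Like load_model_and_tokenizer, but degrades gracefully without flash-attn."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
    except (ImportError, ValueError) as e:  # flash-attn missing or unsupported
        print(f"WARNING: flash_attention_2 unavailable ({e}); using default attention.")
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16,
        )
    return model, tokenizer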
# --- 3. Chat Template Function ---
def apply_chat_template(messages, tokenizer):
    """Applies the chat template."""
    try:
        if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            print("WARNING: Tokenizer lacks chat_template. Using fallback.")
            chat_template = (
                "{% for message in messages %}"
                "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
            )
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
            )
    except Exception as e:
        print(f"ERROR: Chat template application failed: {e}")
        raise
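# For reference, the fallback template renders a one-turn conversation as
# (derived directly from the template string above; Gemma uses the "model" role):
#   <start_of_turn>user
#   Hello!<end_of_turn>
#   <start_of_turn>model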
# --- 4. Text Generation Function ---
def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generates a response."""
    prompt = apply_chat_template(messages, tokenizer)
    try:
        # The model instance already carries its dtype, device placement, and
        # attention implementation from loading, so the pipeline only needs
        # the model and tokenizer themselves.
        pipeline_instance = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
        outputs = pipeline_instance(
            prompt, max_new_tokens=max_new_tokens, do_sample=True,
            temperature=temperature, top_k=top_k, top_p=top_p,
            repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id
        )
        return outputs[0]["generated_text"][len(prompt):].strip()
    except Exception as e:
        print(f"ERROR: Response generation failed: {e}")
        raise
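# Note: the pipeline is rebuilt on every call, which is simple but adds latency;
# caching one instance per model would be the usual optimization. Standalone
# usage sketch (kept as comments so it does not run at import time):
#   model, tokenizer = load_model_and_tokenizer()
#   reply = generate_response(
#       [{"role": "user", "content": "What is the capital of France?"}],
#       model, tokenizer, max_new_tokens=64,
#   )
#   print(reply)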
# --- 5. Gradio Interface ---
model = None  # Model and tokenizer are lazily loaded into these globals
tokenizer = None
def chat(token, message, history):
    global model, tokenizer  # Access the global model and tokenizer
    if not authenticate(token):
        return "Authentication failed. Please enter a valid Hugging Face token.", history
    if model is None or tokenizer is None:
        try:
            model, tokenizer = load_model_and_tokenizer()
        except Exception as e:
            return f"Model loading error: {e}", history
    if not history:
        history = []
    # Rebuild the conversation in turn order: each history entry is a
    # (user message, model response) pair, so the roles must interleave.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "model", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    try:
        response = generate_response(messages, model, tokenizer)
        history.append((message, response))
        return "", history
    except Exception as e:
        return f"Error during generation: {e}", history
with gr.Blocks() as demo:
    gr.Markdown("# Gemma Chatbot")
    gr.Markdown("Enter your Hugging Face API token (read access required):")
    token_input = gr.Textbox(label="Hugging Face Token", type="password")  # Mask the token in the UI
    chatbot = gr.Chatbot(label="Chat", height=400)
    msg_input = gr.Textbox(label="Message", placeholder="Ask me anything!")
    clear_btn = gr.ClearButton([msg_input, chatbot])  # Clears both components; no extra handler needed
    msg_input.submit(chat, [token_input, msg_input, chatbot], [msg_input, chatbot])
demo.launch()
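# On Spaces, enabling Gradio's request queue helps avoid request timeouts while
# the model generates; a common variant (assumption: default queue settings suffice):
# demo.queue().launch()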