kimhyunwoo committed (verified)
Commit da1470a · 1 Parent(s): 9ec0a3a

Update app.py

Files changed (1)
  1. app.py +120 -98
app.py CHANGED
@@ -1,121 +1,143 @@
- import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
- # Model and tokenizer loading (with error handling)
- try:
-     model_name = "google/gemma-3-1b-it"  # Correct model name
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency, if supported
-         device_map="auto",  # Automatically use GPU if available, otherwise CPU
-     )
-     # Create the pipeline
-     pipe = pipeline(
-         "text-generation",
-         model=model,
-         tokenizer=tokenizer,
-         torch_dtype=torch.bfloat16,  # Make sure pipeline also uses correct dtype
-         device_map="auto",  # and device mapping
-         model_kwargs={"attn_implementation": "flash_attention_2"}  # Enable Flash Attention 2 if supported by your hardware and transformers version
-     )
-
- except Exception as e:
-     error_message = f"Error loading model or tokenizer: {e}"
-     print(error_message)  # Log the error to the console
-     # Provide a fallback, even if it's just displaying the error.
-     def error_response(message, history):
-         return f"Model loading failed. Error: {error_message}"
-
-     # Minimal Gradio interface to show the error
-     with gr.Blocks() as demo:
-         gr.ChatInterface(error_response)
-     demo.launch()
-     exit()  # Important: exit to prevent running the rest of the (broken) code
-
-
- # Chat template handling (important for correct prompting)
- def apply_chat_template(messages, chat_template=None):
-     """Applies the chat template to the message history.

      Args:
-         messages: A list of dictionaries, where each dictionary has a "role"
-             ("user" or "assistant") and "content" key.
-         chat_template: The chat template string (optional). If None,
-             try to get from tokenizer.

      Returns:
-         A single string representing the formatted conversation.
      """
-     if chat_template is None:
          if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
-             chat_template = tokenizer.chat_template
          else:
-             # Fallback to a simple template if no chat template is found. This is
-             # *critical* to prevent the model from generating nonsensical output.
              chat_template = "{% for message in messages %}" \
                  "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
                  "{% endfor %}" \
                  "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"

-     return tokenizer.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
-     )

- # Prediction function (modified for chat)
- def predict(message, history):
-     """Generates a response to the user's message.

-     Args:
-         message: The user's input message (string).
-         history: A list of (user_message, bot_response) tuples representing
-             the conversation history.

-     Returns:
-         The generated bot response (string).
-     """
-     # Build the conversation history in the required format.
-     messages = []
-     for user_msg, bot_response in history:
-         messages.append({"role": "user", "content": user_msg})
-         messages.append({"role": "model", "content": bot_response})
-     messages.append({"role": "user", "content": message})

-     # Apply the chat template.
-     prompt = apply_chat_template(messages)

-     # Generate the response using the pipeline (much cleaner).
      try:
-         sequences = pipe(
              prompt,
-             max_new_tokens=512,  # Limit response length
-             do_sample=True,  # Use sampling for more diverse responses
-             temperature=0.7,  # Control randomness (higher = more random)
-             top_k=50,  # Top-k sampling
-             top_p=0.95,  # Nucleus sampling
-             repetition_penalty=1.2,  # Reduce repetition
-             pad_token_id=tokenizer.eos_token_id,  # Ensure padding is correct.
-
          )
-         response = sequences[0]['generated_text'][len(prompt):].strip()  # Extract *only* generated text
-         return response

      except Exception as e:
-         return f"An error occurred during generation: {e}"
-
-
- # Gradio interface (using gr.ChatInterface for a chatbot UI)
- with gr.Blocks() as demo:
-     gr.ChatInterface(
-         predict,
-         chatbot=gr.Chatbot(height=500),  # Set a reasonable height
-         textbox=gr.Textbox(placeholder="Ask me anything!", container=False, scale=7),
-         title="Gemma-3-1b-it Chatbot",
-         description="Chat with the Gemma-3-1b-it model.",
-         retry_btn=None,  # Remove redundant buttons
-         undo_btn=None,
-         clear_btn=None,
-     )
-
- demo.launch(share=False)  # Set share=True to create a publicly shareable link
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
+ from huggingface_hub import login
  import torch
+ import os
+
+ # --- 1. Authentication (Choose ONE method and follow the instructions) ---
+
+ # Method 1: Environment Variable (RECOMMENDED for security and Hugging Face Spaces)
+ # - Set the HUGGING_FACE_HUB_TOKEN environment variable *before* running.
+ # - Linux/macOS: `export HUGGING_FACE_HUB_TOKEN=your_token` (in terminal)
+ # - Windows (PowerShell): `$env:HUGGING_FACE_HUB_TOKEN = "your_token"`
+ # - Hugging Face Spaces: Add `HUGGING_FACE_HUB_TOKEN` as a secret in your Space's settings.
+ # - Then, uncomment the following line:
+ # login()
+
+ # Method 2: Direct Token (ONLY for local testing, NOT for deployment)
+ # - Replace "YOUR_HUGGING_FACE_TOKEN" with your actual token.
+ # - WARNING: Do NOT commit your token to a public repository!
+ # login(token="YOUR_HUGGING_FACE_TOKEN")
+
+ # Method 3: huggingface-cli (Interactive, one-time setup, good for local development)
+ # - Run `huggingface-cli login` in your terminal.
+ # - Paste your token when prompted.
+ # - No code changes are needed after this; the token is stored.
+
+ # --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---
+
+ def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
+     """Loads the model and tokenizer, handling potential errors."""
+     try:
+         # Suppress unnecessary warning messages from transformers
+         logging.set_verbosity_error()
+
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             device_map="auto",  # Automatically use GPU if available, else CPU
+             torch_dtype=torch.bfloat16,  # Use bfloat16 for speed/memory if supported
+             attn_implementation="flash_attention_2"  # Use Flash Attention 2 if supported
+         )
+         return model, tokenizer
+
+     except Exception as e:
+         print(f"ERROR: Failed to load model or tokenizer: {e}")
+         print("\nTroubleshooting Steps:")
+         print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
+         print("2. Verify your internet connection.")
+         print("3. Double-check the model name: 'google/gemma-3-1b-it'")
+         print("4. Ensure you are properly authenticated (see authentication section above).")
+         print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
+         exit(1)  # Exit with an error code
+
+ model, tokenizer = load_model_and_tokenizer()
+
+
+ # --- 3. Chat Template Function (CRITICAL for conversational models) ---
+
+ def apply_chat_template(messages, tokenizer):
+     """Applies the appropriate chat template to the message history.

      Args:
+         messages: A list of dictionaries, where each dictionary has 'role' (user/model)
+             and 'content' keys.
+         tokenizer: The tokenizer object.

      Returns:
+         A formatted prompt string ready for the model.
      """
+     try:
          if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+             # Use the tokenizer's built-in chat template if available
+             return tokenizer.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
+             )
          else:
+             # Fallback to a standard chat template if no specific one is found
+             print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
              chat_template = "{% for message in messages %}" \
                  "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
                  "{% endfor %}" \
                  "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
+             return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template)

+     except Exception as e:
+         print(f"ERROR: Failed to apply chat template: {e}")
+         exit(1)


+ # --- 4. Text Generation Function ---

+ def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
+     """Generates a response using the model and tokenizer."""

+     prompt = apply_chat_template(messages, tokenizer)

      try:
+         pipeline_instance = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             torch_dtype=torch.bfloat16,  # Make sure pipeline also uses correct dtype
+             device_map="auto",  # and device mapping
+             model_kwargs={"attn_implementation": "flash_attention_2"}
+         )
+
+         outputs = pipeline_instance(
              prompt,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             pad_token_id=tokenizer.eos_token_id,  # Important for proper padding
          )
+
+         # Extract *only* the generated text (remove the prompt)
+         generated_text = outputs[0]["generated_text"][len(prompt):].strip()
+         return generated_text

      except Exception as e:
+         print(f"ERROR: Failed to generate response: {e}")
+         return "Sorry, I encountered an error while generating a response."
+
+
+ # --- 5. Main Interaction Loop (for command-line interaction) ---
+ def main():
+     """Main function for interactive command-line chat."""
+
+     messages = []  # Initialize the conversation history
+
+     while True:
+         user_input = input("You: ")
+         if user_input.lower() in ("exit", "quit", "bye"):
+             break
+
+         messages.append({"role": "user", "content": user_input})
+         response = generate_response(messages, model, tokenizer)
+         print(f"Model: {response}")
+         messages.append({"role": "model", "content": response})
+
+ if __name__ == "__main__":
+     main()
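
For reference, a minimal sketch of the prompt string that apply_chat_template() builds with the fallback template above, using an illustrative two-turn history (Gemma's built-in chat template may format turns slightly differently, e.g. by adding a leading <bos> token):

# Hypothetical history, shaped like the one main() accumulates:
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "model", "content": "Hi there!"},
    {"role": "user", "content": "What is Gemma?"},
]

# With the fallback template, apply_chat_template(messages, tokenizer) returns roughly:
#
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model
#   Hi there!<end_of_turn>
#   <start_of_turn>user
#   What is Gemma?<end_of_turn>
#   <start_of_turn>model
#
# Because the prompt is a plain string, generate_response() can slice
# outputs[0]["generated_text"][len(prompt):] to keep only the model's new turn.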