HaryaniAnjali committed on
Commit
5edb16c
·
verified ·
1 Parent(s): ad9c083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -34
app.py CHANGED
@@ -1,6 +1,10 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
 
 
 
 
4
 
5
  class CustomChatDoctor:
6
  def __init__(self):
@@ -13,13 +17,31 @@ class CustomChatDoctor:
13
  # Model name
14
  model_name = "zl111/ChatDoctor"
15
 
16
- # Load tokenizer and model
17
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
18
- self.model = AutoModelForCausalLM.from_pretrained(
19
- model_name,
20
- device_map="auto",
21
- trust_remote_code=True
22
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  print("Model loaded successfully!")
25
 
@@ -29,10 +51,11 @@ class CustomChatDoctor:
29
  Your answers should be based on verified medical information.
30
  If a question doesn't make any sense, or is not factually coherent, explain why instead of answering something incorrect.
31
  If you don't know the answer to a question, respond with "I don't have enough information to provide a reliable answer."
 
32
  """
33
 
34
  # Initialize conversation history
35
- self.conversation_history = []
36
 
37
  def generate_response(self, user_input):
38
  try:
@@ -44,18 +67,22 @@ class CustomChatDoctor:
44
  prompt += "\n".join(self.conversation_history[-10:])
45
  prompt += "\nAI Assistant: "
46
 
47
- # Generate the response
48
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
49
 
50
- with torch.no_grad():
 
 
 
 
 
 
 
 
51
  outputs = self.model.generate(
52
  input_ids=inputs.input_ids,
53
  attention_mask=inputs.attention_mask,
54
- max_new_tokens=512,
55
- temperature=0.7,
56
- top_p=0.9,
57
- do_sample=True,
58
- pad_token_id=self.tokenizer.eos_token_id
59
  )
60
 
61
  response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
@@ -66,53 +93,115 @@ class CustomChatDoctor:
66
  return response
67
 
68
  except Exception as e:
 
69
  error_message = f"An error occurred: {str(e)}"
70
  print(error_message)
71
- return "I'm sorry, I encountered an error processing your question. Please try again."
72
 
73
  def reset_conversation(self):
74
  self.conversation_history = []
75
  return None
76
 
77
- # Initialize the model
78
- chat_doctor = CustomChatDoctor()
 
 
 
 
 
 
79
 
80
- # Define example inputs
81
  examples = [
82
  "What are the symptoms of diabetes?",
83
  "How can I manage my hypertension?",
84
- "What should I do for a persistent headache?"
 
 
85
  ]
86
 
87
- # Create Gradio interface
88
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  gr.Markdown("# Your Custom ChatDoctor")
90
  gr.Markdown("Ask medical questions and get AI-powered responses.")
91
 
92
- chatbot = gr.Chatbot()
93
- msg = gr.Textbox(placeholder="Type your medical question here...")
94
 
95
  with gr.Row():
96
- submit_btn = gr.Button("Send")
97
  clear_btn = gr.Button("Clear Conversation")
98
 
99
- gr.Examples(examples=examples, inputs=msg)
 
 
 
 
100
 
101
- gr.Markdown("""
102
- ### Disclaimer
103
- This AI assistant provides information for educational purposes only.
104
- Always consult with a qualified healthcare provider for medical advice.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  """)
106
 
107
  def respond(message, chat_history):
108
- bot_message = chat_doctor.generate_response(message)
109
- chat_history.append((message, bot_message))
 
 
 
 
 
 
 
110
  return "", chat_history
111
 
112
  def clear_history():
113
- chat_doctor.reset_conversation()
 
114
  return None
115
 
 
 
 
 
 
 
 
 
116
  submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
117
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
118
  clear_btn.click(clear_history, None, chatbot)
 
import os

# Setup cache directory for models.
# This assignment has to run before `transformers` is imported: the library
# resolves its default cache location from the environment at import time,
# so setting the variable afterwards leaves the default cache path in use.
# NOTE(review): newer transformers releases steer users toward HF_HOME
# instead of TRANSFORMERS_CACHE -- confirm against the pinned version.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
8
 
9
  class CustomChatDoctor:
10
  def __init__(self):
 
17
  # Model name
18
  model_name = "zl111/ChatDoctor"
19
 
20
+ # Setup quantization for memory efficiency
21
+ if self.device == "cuda":
22
+ # Use 8-bit quantization if GPU is available
23
+ quantization_config = BitsAndBytesConfig(
24
+ load_in_8bit=True,
25
+ llm_int8_threshold=6.0
26
+ )
27
+
28
+ # Load tokenizer and model
29
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ self.model = AutoModelForCausalLM.from_pretrained(
31
+ model_name,
32
+ quantization_config=quantization_config,
33
+ device_map="auto",
34
+ trust_remote_code=True
35
+ )
36
+ else:
37
+ # For CPU, use lighter settings
38
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
39
+ self.model = AutoModelForCausalLM.from_pretrained(
40
+ model_name,
41
+ device_map="auto",
42
+ low_cpu_mem_usage=True,
43
+ trust_remote_code=True
44
+ )
45
 
46
  print("Model loaded successfully!")
47
 
 
51
  Your answers should be based on verified medical information.
52
  If a question doesn't make any sense, or is not factually coherent, explain why instead of answering something incorrect.
53
  If you don't know the answer to a question, respond with "I don't have enough information to provide a reliable answer."
54
+ Always include a disclaimer that you are an AI assistant and not a licensed medical professional.
55
  """
56
 
57
  # Initialize conversation history
58
+ self.reset_conversation()
59
 
60
  def generate_response(self, user_input):
61
  try:
 
67
  prompt += "\n".join(self.conversation_history[-10:])
68
  prompt += "\nAI Assistant: "
69
 
70
+ # Generate the response with appropriate parameters based on device
71
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
72
 
73
+ with torch.no_grad(): # Disable gradient calculations for inference
74
+ generation_config = {
75
+ "max_new_tokens": 512,
76
+ "temperature": 0.7,
77
+ "top_p": 0.9,
78
+ "do_sample": True,
79
+ "pad_token_id": self.tokenizer.eos_token_id
80
+ }
81
+
82
  outputs = self.model.generate(
83
  input_ids=inputs.input_ids,
84
  attention_mask=inputs.attention_mask,
85
+ **generation_config
 
 
 
 
86
  )
87
 
88
  response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
93
  return response
94
 
95
  except Exception as e:
96
+ # Handle any errors during generation
97
  error_message = f"An error occurred: {str(e)}"
98
  print(error_message)
99
+ return "I'm sorry, I encountered an error processing your question. Please try again or ask a different question."
100
 
101
  def reset_conversation(self):
102
  self.conversation_history = []
103
  return None
104
 
# Module-level slot holding the single shared CustomChatDoctor, so the model
# is not reloaded for each user. It stays None until the first request.
chat_doctor = None


def get_chat_doctor():
    """Return the shared CustomChatDoctor, constructing it on first use."""
    global chat_doctor
    if chat_doctor is not None:
        return chat_doctor
    chat_doctor = CustomChatDoctor()
    return chat_doctor
113
 
# Starter questions offered to users as clickable examples.
examples = [
    "What are the symptoms of diabetes?",
    "How can I manage my hypertension?",
    "What should I do for a persistent headache?",
    "Can you explain what asthma is?",
    "What are the side effects of ibuprofen?",
]
122
 
# Stylesheet handed to gr.Blocks: sets the container font and styles the
# element carrying the "disclaimer" class.
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.disclaimer {
    margin-top: 20px;
    padding: 10px;
    background-color: #f8f9fa;
    border-left: 3px solid #f0ad4e;
    font-size: 14px;
}
"""
136
+
# Create welcome message function
def welcome():
    """Return the fixed greeting text for the ChatDoctor interface."""
    greeting = (
        "Welcome to ChatDoctor! I'm an AI assistant trained to provide "
        "medical information. How can I help you today?"
    )
    return greeting
140
+
# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Your Custom ChatDoctor")
    gr.Markdown("Ask medical questions and get AI-powered responses.")

    # The chatbot uses the "messages" schema, so every history entry handed
    # to it must be a {"role": ..., "content": ...} dict.
    chatbot = gr.Chatbot(height=600, type="messages")
    msg = gr.Textbox(placeholder="Type your medical question here...", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear Conversation")

    # Display example queries that users can click on
    gr.Examples(
        examples=examples,
        inputs=msg
    )

    with gr.Accordion("About this AI", open=False):
        gr.Markdown("""
        **ChatDoctor** is a medical conversation model designed to provide general health information.

        This AI uses language models to generate responses based on patterns learned from medical texts and conversations.

        **Important Notes:**
        - This system is for informational purposes only
        - Not a substitute for professional medical advice
        - In emergencies, contact emergency services immediately
        """)

    # Add disclaimer at the bottom with custom styling
    gr.HTML("""
    <div class="disclaimer">
    <strong>Disclaimer:</strong> This AI assistant provides information for educational purposes only.
    Always consult with a qualified healthcare provider for medical advice, diagnosis, or treatment.
    This tool is not intended to replace professional medical consultation.
    </div>
    """)

    def respond(message, chat_history):
        """Answer one user message and extend the visible chat history.

        Args:
            message: Text from the input box.
            chat_history: Current chatbot value, a list of role/content
                dicts; None is treated as an empty history.

        Returns:
            A pair ("", updated_history): the empty string clears the input
            box and the list becomes the new chatbot value.
        """
        chat_history = chat_history or []

        # Reject blank input before touching the model, so an empty submit
        # never triggers the expensive first-time model load.
        if not message or not message.strip():
            return "", chat_history

        # Lazy-load model on first request
        doctor = get_chat_doctor()

        bot_message = doctor.generate_response(message)
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": bot_message})
        return "", chat_history

    def clear_history():
        """Reset the model-side conversation and empty the chat window.

        NOTE(review): get_chat_doctor() returns one shared instance, so this
        reset (and the history it clears) is shared by every session using
        the app -- confirm whether per-session history is wanted.
        """
        doctor = get_chat_doctor()
        doctor.reset_conversation()
        return []

    def show_welcome():
        """Return the initial chatbot value: one assistant greeting."""
        return [{"role": "assistant", "content": welcome()}]

    # Show welcome message when the app starts. It is seeded from Python so
    # the payload uses the messages schema the Chatbot was created with and
    # is not overwritten by a handler that returns None.
    demo.load(show_welcome, None, chatbot)

    submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear_btn.click(clear_history, None, chatbot)