HaryaniAnjali commited on
Commit
ad9c083
·
verified ·
1 Parent(s): 911c6e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -123
app.py CHANGED
@@ -1,10 +1,6 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
  import gradio as gr
4
- import os
5
-
6
- # Setup cache directory for models
7
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
8
 
9
  class CustomChatDoctor:
10
  def __init__(self):
@@ -17,31 +13,13 @@ class CustomChatDoctor:
17
  # Model name
18
  model_name = "zl111/ChatDoctor"
19
 
20
- # Setup quantization for memory efficiency
21
- if self.device == "cuda":
22
- # Use 8-bit quantization if GPU is available
23
- quantization_config = BitsAndBytesConfig(
24
- load_in_8bit=True,
25
- llm_int8_threshold=6.0
26
- )
27
-
28
- # Load tokenizer and model
29
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
30
- self.model = AutoModelForCausalLM.from_pretrained(
31
- model_name,
32
- quantization_config=quantization_config,
33
- device_map="auto",
34
- trust_remote_code=True
35
- )
36
- else:
37
- # For CPU, use lighter settings
38
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
39
- self.model = AutoModelForCausalLM.from_pretrained(
40
- model_name,
41
- device_map="auto",
42
- low_cpu_mem_usage=True,
43
- trust_remote_code=True
44
- )
45
 
46
  print("Model loaded successfully!")
47
 
@@ -51,11 +29,10 @@ class CustomChatDoctor:
51
  Your answers should be based on verified medical information.
52
  If a question doesn't make any sense, or is not factually coherent, explain why instead of answering something incorrect.
53
  If you don't know the answer to a question, respond with "I don't have enough information to provide a reliable answer."
54
- Always include a disclaimer that you are an AI assistant and not a licensed medical professional.
55
  """
56
 
57
  # Initialize conversation history
58
- self.reset_conversation()
59
 
60
  def generate_response(self, user_input):
61
  try:
@@ -67,22 +44,18 @@ class CustomChatDoctor:
67
  prompt += "\n".join(self.conversation_history[-10:])
68
  prompt += "\nAI Assistant: "
69
 
70
- # Generate the response with appropriate parameters based on device
71
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
72
 
73
- with torch.no_grad(): # Disable gradient calculations for inference
74
- generation_config = {
75
- "max_new_tokens": 512,
76
- "temperature": 0.7,
77
- "top_p": 0.9,
78
- "do_sample": True,
79
- "pad_token_id": self.tokenizer.eos_token_id
80
- }
81
-
82
  outputs = self.model.generate(
83
  input_ids=inputs.input_ids,
84
  attention_mask=inputs.attention_mask,
85
- **generation_config
 
 
 
 
86
  )
87
 
88
  response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
@@ -93,115 +66,53 @@ class CustomChatDoctor:
93
  return response
94
 
95
  except Exception as e:
96
- # Handle any errors during generation
97
  error_message = f"An error occurred: {str(e)}"
98
  print(error_message)
99
- return "I'm sorry, I encountered an error processing your question. Please try again or ask a different question."
100
 
101
  def reset_conversation(self):
102
  self.conversation_history = []
103
  return None
104
 
105
- # Create a singleton instance to avoid reloading the model for each user
106
- chat_doctor = None
107
-
108
- def get_chat_doctor():
109
- global chat_doctor
110
- if chat_doctor is None:
111
- chat_doctor = CustomChatDoctor()
112
- return chat_doctor
113
 
114
- # Example inputs to help users get started
115
  examples = [
116
  "What are the symptoms of diabetes?",
117
  "How can I manage my hypertension?",
118
- "What should I do for a persistent headache?",
119
- "Can you explain what asthma is?",
120
- "What are the side effects of ibuprofen?"
121
  ]
122
 
123
- # Add CSS for styling
124
- css = """
125
- .gradio-container {
126
- font-family: 'Arial', sans-serif;
127
- }
128
- .disclaimer {
129
- margin-top: 20px;
130
- padding: 10px;
131
- background-color: #f8f9fa;
132
- border-left: 3px solid #f0ad4e;
133
- font-size: 14px;
134
- }
135
- """
136
-
137
- # Create welcome message function
138
- def welcome():
139
- return "Welcome to ChatDoctor! I'm an AI assistant trained to provide medical information. How can I help you today?"
140
-
141
- # Build the Gradio interface
142
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
143
  gr.Markdown("# Your Custom ChatDoctor")
144
  gr.Markdown("Ask medical questions and get AI-powered responses.")
145
 
146
- chatbot = gr.Chatbot(height=600, type="messages")
147
- msg = gr.Textbox(placeholder="Type your medical question here...", lines=2)
148
 
149
  with gr.Row():
150
- submit_btn = gr.Button("Send", variant="primary")
151
  clear_btn = gr.Button("Clear Conversation")
152
 
153
- # Display example queries that users can click on
154
- gr.Examples(
155
- examples=examples,
156
- inputs=msg
157
- )
158
 
159
- with gr.Accordion("About this AI", open=False):
160
- gr.Markdown("""
161
- **ChatDoctor** is a medical conversation model designed to provide general health information.
162
-
163
- This AI uses language models to generate responses based on patterns learned from medical texts and conversations.
164
-
165
- **Important Notes:**
166
- - This system is for informational purposes only
167
- - Not a substitute for professional medical advice
168
- - In emergencies, contact emergency services immediately
169
- """)
170
-
171
- # Add disclaimer at the bottom with custom styling
172
- gr.HTML("""
173
- <div class="disclaimer">
174
- <strong>Disclaimer:</strong> This AI assistant provides information for educational purposes only.
175
- Always consult with a qualified healthcare provider for medical advice, diagnosis, or treatment.
176
- This tool is not intended to replace professional medical consultation.
177
- </div>
178
  """)
179
 
180
  def respond(message, chat_history):
181
- # Lazy-load model on first request
182
- doctor = get_chat_doctor()
183
-
184
- if message.strip() == "":
185
- return "", chat_history
186
-
187
- bot_message = doctor.generate_response(message)
188
- chat_history.append({"role": "user", "content": message})
189
- chat_history.append({"role": "assistant", "content": bot_message})
190
  return "", chat_history
191
 
192
  def clear_history():
193
- doctor = get_chat_doctor()
194
- doctor.reset_conversation()
195
  return None
196
 
197
- # Show welcome message when the app starts
198
- demo.load(lambda: None, None, chatbot, js="""
199
- () => {
200
- const welcomeMsg = "Welcome to ChatDoctor! I'm an AI assistant trained to provide medical information. How can I help you today?";
201
- return [[null, welcomeMsg]];
202
- }
203
- """)
204
-
205
  submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
206
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
207
  clear_btn.click(clear_history, None, chatbot)
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
 
 
 
 
4
 
5
  class CustomChatDoctor:
6
  def __init__(self):
 
13
  # Model name
14
  model_name = "zl111/ChatDoctor"
15
 
16
+ # Load tokenizer and model
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ self.model = AutoModelForCausalLM.from_pretrained(
19
+ model_name,
20
+ device_map="auto",
21
+ trust_remote_code=True
22
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  print("Model loaded successfully!")
25
 
 
29
  Your answers should be based on verified medical information.
30
  If a question doesn't make any sense, or is not factually coherent, explain why instead of answering something incorrect.
31
  If you don't know the answer to a question, respond with "I don't have enough information to provide a reliable answer."
 
32
  """
33
 
34
  # Initialize conversation history
35
+ self.conversation_history = []
36
 
37
  def generate_response(self, user_input):
38
  try:
 
44
  prompt += "\n".join(self.conversation_history[-10:])
45
  prompt += "\nAI Assistant: "
46
 
47
+ # Generate the response
48
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
49
 
50
+ with torch.no_grad():
 
 
 
 
 
 
 
 
51
  outputs = self.model.generate(
52
  input_ids=inputs.input_ids,
53
  attention_mask=inputs.attention_mask,
54
+ max_new_tokens=512,
55
+ temperature=0.7,
56
+ top_p=0.9,
57
+ do_sample=True,
58
+ pad_token_id=self.tokenizer.eos_token_id
59
  )
60
 
61
  response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
66
  return response
67
 
68
  except Exception as e:
 
69
  error_message = f"An error occurred: {str(e)}"
70
  print(error_message)
71
+ return "I'm sorry, I encountered an error processing your question. Please try again."
72
 
73
  def reset_conversation(self):
74
  self.conversation_history = []
75
  return None
76
 
77
+ # Initialize the model
78
+ chat_doctor = CustomChatDoctor()
 
 
 
 
 
 
79
 
80
+ # Define example inputs
81
  examples = [
82
  "What are the symptoms of diabetes?",
83
  "How can I manage my hypertension?",
84
+ "What should I do for a persistent headache?"
 
 
85
  ]
86
 
87
+ # Create Gradio interface
88
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  gr.Markdown("# Your Custom ChatDoctor")
90
  gr.Markdown("Ask medical questions and get AI-powered responses.")
91
 
92
+ chatbot = gr.Chatbot()
93
+ msg = gr.Textbox(placeholder="Type your medical question here...")
94
 
95
  with gr.Row():
96
+ submit_btn = gr.Button("Send")
97
  clear_btn = gr.Button("Clear Conversation")
98
 
99
+ gr.Examples(examples=examples, inputs=msg)
 
 
 
 
100
 
101
+ gr.Markdown("""
102
+ ### Disclaimer
103
+ This AI assistant provides information for educational purposes only.
104
+ Always consult with a qualified healthcare provider for medical advice.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  """)
106
 
107
  def respond(message, chat_history):
108
+ bot_message = chat_doctor.generate_response(message)
109
+ chat_history.append((message, bot_message))
 
 
 
 
 
 
 
110
  return "", chat_history
111
 
112
  def clear_history():
113
+ chat_doctor.reset_conversation()
 
114
  return None
115
 
 
 
 
 
 
 
 
 
116
  submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
117
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
118
  clear_btn.click(clear_history, None, chatbot)