atharwaah1work commited on
Commit
5219bc0
Β·
verified Β·
1 Parent(s): 65ea401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -136
app.py CHANGED
@@ -1,55 +1,22 @@
1
 
2
 
3
  import torch
4
- from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, BitsAndBytesConfig
5
  import gradio as gr
6
  from PIL import Image
7
  import re
8
- import os
9
  from typing import List, Tuple
10
 
11
- # Configuration for 4-bit quantization (if GPU available)
12
- quant_config = BitsAndBytesConfig(
13
- load_in_4bit=True,
14
- bnb_4bit_compute_dtype=torch.float16,
15
- bnb_4bit_quant_type="nf4",
16
- bnb_4bit_use_double_quant=True
17
- )
18
-
19
  class RiverPollutionAnalyzer:
20
  def __init__(self):
21
- try:
22
- # Initialize model with fallback for CPU
23
- self.processor = InstructBlipProcessor.from_pretrained(
24
- "Salesforce/instructblip-flan-t5-xl",
25
- cache_dir="model_cache"
26
- )
27
-
28
- if torch.cuda.is_available():
29
- self.model = InstructBlipForConditionalGeneration.from_pretrained(
30
- "Salesforce/instructblip-flan-t5-xl",
31
- device_map="auto",
32
- quantization_config=quant_config,
33
- torch_dtype=torch.float16,
34
- cache_dir="model_cache"
35
- )
36
- self.device = "cuda"
37
- self.status = "βœ… Model loaded (4-bit GPU)"
38
- else:
39
- self.model = InstructBlipForConditionalGeneration.from_pretrained(
40
- "Salesforce/instructblip-flan-t5-xl",
41
- device_map="auto",
42
- torch_dtype=torch.float32,
43
- cache_dir="model_cache",
44
- low_cpu_mem_usage=True
45
- )
46
- self.device = "cpu"
47
- self.status = "⚠️ Model loaded (CPU mode - slower)"
48
-
49
- except Exception as e:
50
- self.model = None
51
- self.status = f"❌ Model loading failed: {str(e)}"
52
- print(self.status)
53
 
54
  self.pollutants = [
55
  "plastic waste", "chemical foam", "industrial discharge",
@@ -72,15 +39,9 @@ class RiverPollutionAnalyzer:
72
  }
73
 
74
  def analyze_image(self, image):
75
- """Analyze river pollution with device-aware processing"""
76
- if not self.model:
77
- return "Model not loaded. Please check logs."
78
-
79
  if not isinstance(image, Image.Image):
80
  image = Image.fromarray(image)
81
-
82
- # Resize for efficiency
83
- image = image.resize((512, 512))
84
 
85
  prompt = """Analyze this river pollution scene and provide:
86
  1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
@@ -90,31 +51,27 @@ Respond EXACTLY in this format:
90
  Pollutants: [comma separated list]
91
  Severity: [number]"""
92
 
93
- try:
94
- inputs = self.processor(
95
- images=image,
96
- text=prompt,
97
- return_tensors="pt"
98
- ).to(self.model.device)
99
-
100
- with torch.no_grad():
101
- outputs = self.model.generate(
102
- **inputs,
103
- max_new_tokens=150, # Reduced for stability
104
- temperature=0.5,
105
- top_p=0.85,
106
- do_sample=True
107
- )
108
-
109
- analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
110
- pollutants, severity = self._parse_response(analysis)
111
- return self._format_analysis(pollutants, severity)
112
- except Exception as e:
113
- return f"⚠️ Analysis error: {str(e)}"
114
-
115
- # [Keep all existing helper methods unchanged]
116
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
117
- """Same parsing logic as before"""
118
  pollutants = []
119
  severity = 3
120
 
@@ -123,6 +80,7 @@ Severity: [number]"""
123
  r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
124
  analysis
125
  )
 
126
  if pollutant_match:
127
  pollutants_str = pollutant_match.group(2).strip()
128
  pollutants = [
@@ -132,7 +90,11 @@ Severity: [number]"""
132
  ]
133
 
134
  # Extract severity
135
- severity_match = re.search(r'(?i)(severity|level)[:\s]*(\d{1,2})', analysis)
 
 
 
 
136
  if severity_match:
137
  try:
138
  severity = min(max(int(severity_match.group(2)), 1), 10)
@@ -144,7 +106,7 @@ Severity: [number]"""
144
  return pollutants, severity
145
 
146
  def _calculate_severity(self, pollutants: List[str]) -> int:
147
- """Same severity calculation"""
148
  if not pollutants:
149
  return 1
150
 
@@ -159,7 +121,7 @@ Severity: [number]"""
159
  return min(10, max(1, round(avg_weight * 3)))
160
 
161
  def _format_analysis(self, pollutants: List[str], severity: int) -> str:
162
- """Same formatting"""
163
  severity_bar = f"""πŸ“Š Severity: {severity}/10
164
  {"β–ˆ" * severity}{"β–‘" * (10 - severity)}
165
  {self.severity_descriptions.get(severity, '')}"""
@@ -171,57 +133,21 @@ Severity: [number]"""
171
  {pollutants_list}
172
  {severity_bar}"""
173
 
174
- def analyze_chat(self, message: str) -> str:
175
- """Handle chat questions"""
176
- if any(word in message.lower() for word in ["hello", "hi", "hey"]):
177
- return "Hello! I'm a river pollution analyzer. Ask me about pollution types."
178
- elif "pollution" in message.lower():
179
- return "Common river pollutants: plastic waste, chemical foam, industrial discharge, sewage water, oil spills."
180
- else:
181
- return "I can answer questions about river pollution. Try asking about pollution types."
182
-
183
  # Initialize analyzer
184
  analyzer = RiverPollutionAnalyzer()
185
 
186
- # Gradio Interface
 
 
 
187
  css = """
188
- .header {
189
- text-align: center;
190
- padding: 20px;
191
- background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
192
- border-radius: 10px;
193
- margin-bottom: 20px;
194
- }
195
- .side-by-side {
196
- display: flex;
197
- gap: 20px;
198
- }
199
- .left-panel, .right-panel {
200
- flex: 1;
201
- }
202
- .analysis-box {
203
- padding: 20px;
204
- background: #f8f9fa;
205
- border-radius: 10px;
206
- margin-top: 20px;
207
- border: 1px solid #dee2e6;
208
- }
209
- .chat-container {
210
- background: #f8f9fa;
211
- padding: 20px;
212
- border-radius: 10px;
213
- height: 100%;
214
- }
215
- .dark .analysis-box, .dark .chat-container {
216
- background: #2a2a2a;
217
- border-color: #444;
218
- }
219
  """
220
 
221
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
222
  with gr.Column(elem_classes="header"):
223
  gr.Markdown("# 🌍 River Pollution Analyzer")
224
- gr.Markdown(f"### {analyzer.status}")
225
 
226
  with gr.Row(elem_classes="side-by-side"):
227
  # Left Panel
@@ -231,53 +157,54 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
231
  analyze_btn = gr.Button("πŸ” Analyze Pollution", variant="primary")
232
 
233
  with gr.Group(elem_classes="analysis-box"):
234
- gr.Markdown("### πŸ“Š Analysis Report")
235
  analysis_output = gr.Markdown()
236
 
237
  # Right Panel
238
  with gr.Column(elem_classes="right-panel"):
239
  with gr.Group(elem_classes="chat-container"):
240
- chatbot = gr.Chatbot(label="Pollution Q&A", height=400)
241
  with gr.Row():
242
- chat_input = gr.Textbox(
243
- placeholder="Ask about pollution sources...",
244
- label="Your Question",
245
- container=False,
246
- scale=5
247
- )
248
  chat_btn = gr.Button("πŸ’¬ Ask", variant="secondary", scale=1)
249
- clear_btn = gr.Button("🧹 Clear Chat", size="sm")
250
 
 
251
  analyze_btn.click(
252
  analyzer.analyze_image,
253
  inputs=image_input,
254
  outputs=analysis_output
255
  )
256
-
257
  chat_input.submit(
258
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
259
  inputs=[chat_input, chatbot],
260
  outputs=[chat_input, chatbot]
261
  )
262
-
263
  chat_btn.click(
264
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
265
  inputs=[chat_input, chatbot],
266
  outputs=[chat_input, chatbot]
267
  )
 
 
 
 
 
268
 
269
- clear_btn.click(lambda: None, outputs=[chatbot])
270
-
271
  gr.Examples(
272
  examples=[
273
- ["examples/polluted_river1.jpg"],
274
- ["examples/polluted_river2.jpg"]
275
  ],
276
  inputs=image_input,
277
  outputs=analysis_output,
278
  fn=analyzer.analyze_image,
279
- cache_examples=torch.cuda.is_available(), # Cache only if GPU available
280
- label="Example Images"
281
  )
282
 
283
- demo.queue(max_size=2).launch()
 
1
 
2
 
3
  import torch
4
+ from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
5
  import gradio as gr
6
  from PIL import Image
7
  import re
 
8
  from typing import List, Tuple
9
 
 
 
 
 
 
 
 
 
10
  class RiverPollutionAnalyzer:
11
  def __init__(self):
12
+ # Initialize model with 4-bit quantization
13
+ self.processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
14
+ self.model = InstructBlipForConditionalGeneration.from_pretrained(
15
+ "Salesforce/instructblip-vicuna-7b",
16
+ device_map="auto",
17
+ torch_dtype=torch.float16,
18
+ load_in_4bit=True
19
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  self.pollutants = [
22
  "plastic waste", "chemical foam", "industrial discharge",
 
39
  }
40
 
41
  def analyze_image(self, image):
42
+ """Analyze river pollution with robust parsing"""
 
 
 
43
  if not isinstance(image, Image.Image):
44
  image = Image.fromarray(image)
 
 
 
45
 
46
  prompt = """Analyze this river pollution scene and provide:
47
  1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
 
51
  Pollutants: [comma separated list]
52
  Severity: [number]"""
53
 
54
+ inputs = self.processor(
55
+ images=image,
56
+ text=prompt,
57
+ return_tensors="pt"
58
+ ).to("cuda", torch.float16)
59
+
60
+ with torch.no_grad():
61
+ outputs = self.model.generate(
62
+ **inputs,
63
+ max_new_tokens=200,
64
+ temperature=0.5,
65
+ top_p=0.85,
66
+ do_sample=True
67
+ )
68
+
69
+ analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
70
+ pollutants, severity = self._parse_response(analysis)
71
+ return self._format_analysis(pollutants, severity)
72
+
 
 
 
 
73
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
74
+ """Robust parsing of model response"""
75
  pollutants = []
76
  severity = 3
77
 
 
80
  r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
81
  analysis
82
  )
83
+
84
  if pollutant_match:
85
  pollutants_str = pollutant_match.group(2).strip()
86
  pollutants = [
 
90
  ]
91
 
92
  # Extract severity
93
+ severity_match = re.search(
94
+ r'(?i)(severity|level)[:\s]*(\d{1,2})',
95
+ analysis
96
+ )
97
+
98
  if severity_match:
99
  try:
100
  severity = min(max(int(severity_match.group(2)), 1), 10)
 
106
  return pollutants, severity
107
 
108
  def _calculate_severity(self, pollutants: List[str]) -> int:
109
+ """Weighted severity calculation"""
110
  if not pollutants:
111
  return 1
112
 
 
121
  return min(10, max(1, round(avg_weight * 3)))
122
 
123
  def _format_analysis(self, pollutants: List[str], severity: int) -> str:
124
+ """Generate formatted report"""
125
  severity_bar = f"""πŸ“Š Severity: {severity}/10
126
  {"β–ˆ" * severity}{"β–‘" * (10 - severity)}
127
  {self.severity_descriptions.get(severity, '')}"""
 
133
  {pollutants_list}
134
  {severity_bar}"""
135
 
 
 
 
 
 
 
 
 
 
136
  # Initialize analyzer
137
  analyzer = RiverPollutionAnalyzer()
138
 
139
+ import gradio as gr
140
+
141
+ # Import your actual analyzer
142
+
143
  css = """
144
+ /* (Keep all your CSS styles) */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  """
146
 
147
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
148
  with gr.Column(elem_classes="header"):
149
  gr.Markdown("# 🌍 River Pollution Analyzer")
150
+ gr.Markdown("### AI-powered water pollution detection")
151
 
152
  with gr.Row(elem_classes="side-by-side"):
153
  # Left Panel
 
157
  analyze_btn = gr.Button("πŸ” Analyze Pollution", variant="primary")
158
 
159
  with gr.Group(elem_classes="analysis-box"):
160
+ gr.Markdown("### πŸ“Š Analysis report")
161
  analysis_output = gr.Markdown()
162
 
163
  # Right Panel
164
  with gr.Column(elem_classes="right-panel"):
165
  with gr.Group(elem_classes="chat-container"):
166
+ chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
167
  with gr.Row():
168
+ chat_input = gr.Textbox(placeholder="Ask about pollution sources...",
169
+ label="Your Question", container=False, scale=5)
 
 
 
 
170
  chat_btn = gr.Button("πŸ’¬ Ask", variant="secondary", scale=1)
171
+ clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
172
 
173
+ # Connect to your actual analyzer functions
174
  analyze_btn.click(
175
  analyzer.analyze_image,
176
  inputs=image_input,
177
  outputs=analysis_output
178
  )
179
+
180
  chat_input.submit(
181
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
182
  inputs=[chat_input, chatbot],
183
  outputs=[chat_input, chatbot]
184
  )
185
+
186
  chat_btn.click(
187
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
188
  inputs=[chat_input, chatbot],
189
  outputs=[chat_input, chatbot]
190
  )
191
+
192
+ clear_btn.click(
193
+ lambda: None,
194
+ outputs=[chatbot]
195
+ )
196
 
197
+ # Examples using your real analyzer
 
198
  gr.Examples(
199
  examples=[
200
+ ["https://drive.google.com/uc?export=view&id=1sCxcpacS5WkV5qVrhj8mcdq1JHyVyaEb"],
201
+ ["https://drive.google.com/uc?export=view&id=1WGcXwFhpbD1LrtbQ8E5IZZN3nEGfcwuN"]
202
  ],
203
  inputs=image_input,
204
  outputs=analysis_output,
205
  fn=analyzer.analyze_image,
206
+ cache_examples=True,
207
+ label="Try example images:"
208
  )
209
 
210
+ demo.launch()