atharwaah1work commited on
Commit
334ca28
Β·
verified Β·
1 Parent(s): 5219bc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -67
app.py CHANGED
@@ -1,5 +1,3 @@
1
-
2
-
3
  import torch
4
  from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
5
  import gradio as gr
@@ -7,16 +5,19 @@ from PIL import Image
7
  import re
8
  from typing import List, Tuple
9
 
 
 
 
 
 
10
  class RiverPollutionAnalyzer:
11
  def __init__(self):
12
- # Initialize model with 4-bit quantization
13
- self.processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
14
  self.model = InstructBlipForConditionalGeneration.from_pretrained(
15
- "Salesforce/instructblip-vicuna-7b",
16
- device_map="auto",
17
- torch_dtype=torch.float16,
18
- load_in_4bit=True
19
- )
20
 
21
  self.pollutants = [
22
  "plastic waste", "chemical foam", "industrial discharge",
@@ -55,14 +56,14 @@ Severity: [number]"""
55
  images=image,
56
  text=prompt,
57
  return_tensors="pt"
58
- ).to("cuda", torch.float16)
59
 
60
  with torch.no_grad():
61
  outputs = self.model.generate(
62
  **inputs,
63
  max_new_tokens=200,
64
- temperature=0.5,
65
- top_p=0.85,
66
  do_sample=True
67
  )
68
 
@@ -70,6 +71,15 @@ Severity: [number]"""
70
  pollutants, severity = self._parse_response(analysis)
71
  return self._format_analysis(pollutants, severity)
72
 
 
 
 
 
 
 
 
 
 
73
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
74
  """Robust parsing of model response"""
75
  pollutants = []
@@ -77,29 +87,24 @@ Severity: [number]"""
77
 
78
  # Extract pollutants
79
  pollutant_match = re.search(
80
- r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
81
- analysis
82
  )
83
-
84
  if pollutant_match:
85
- pollutants_str = pollutant_match.group(2).strip()
86
  pollutants = [
87
- p.strip().lower()
88
- for p in re.split(r'[,;]|\band\b', pollutants_str)
89
  if p.strip().lower() in self.pollutants
90
  ]
91
 
92
  # Extract severity
93
  severity_match = re.search(
94
- r'(?i)(severity|level)[:\s]*(\d{1,2})',
95
- analysis
96
  )
97
-
98
  if severity_match:
99
- try:
100
- severity = min(max(int(severity_match.group(2)), 1), 10)
101
- except:
102
- severity = self._calculate_severity(pollutants)
103
  else:
104
  severity = self._calculate_severity(pollutants)
105
 
@@ -116,7 +121,6 @@ Severity: [number]"""
116
  "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
117
  "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1
118
  }
119
-
120
  avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
121
  return min(10, max(1, round(avg_weight * 3)))
122
 
@@ -127,7 +131,7 @@ Severity: [number]"""
127
  {self.severity_descriptions.get(severity, '')}"""
128
 
129
  pollutants_list = "\nπŸ” No pollutants detected" if not pollutants else "\n".join(
130
- f"{i}. {p.capitalize()}" for i, p in enumerate(pollutants[:5], 1))
131
 
132
  return f"""🌊 River Pollution Analysis 🌊
133
  {pollutants_list}
@@ -136,75 +140,66 @@ Severity: [number]"""
136
  # Initialize analyzer
137
  analyzer = RiverPollutionAnalyzer()
138
 
139
- import gradio as gr
140
-
141
- # Import your actual analyzer
142
-
143
  css = """
144
- /* (Keep all your CSS styles) */
 
 
 
 
 
 
145
  """
146
 
147
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
148
  with gr.Column(elem_classes="header"):
149
  gr.Markdown("# 🌍 River Pollution Analyzer")
150
- gr.Markdown("### AI-powered water pollution detection")
151
 
152
  with gr.Row(elem_classes="side-by-side"):
153
- # Left Panel
154
  with gr.Column(elem_classes="left-panel"):
 
155
  with gr.Group():
156
  image_input = gr.Image(type="pil", label="Upload River Image", height=300)
157
- analyze_btn = gr.Button("πŸ” Analyze Pollution", variant="primary")
158
-
159
  with gr.Group(elem_classes="analysis-box"):
160
- gr.Markdown("### πŸ“Š Analysis report")
161
  analysis_output = gr.Markdown()
162
 
163
- # Right Panel
164
  with gr.Column(elem_classes="right-panel"):
 
165
  with gr.Group(elem_classes="chat-container"):
166
- chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
167
  with gr.Row():
168
- chat_input = gr.Textbox(placeholder="Ask about pollution sources...",
169
- label="Your Question", container=False, scale=5)
170
- chat_btn = gr.Button("πŸ’¬ Ask", variant="secondary", scale=1)
171
- clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
172
 
173
- # Connect to your actual analyzer functions
174
  analyze_btn.click(
175
  analyzer.analyze_image,
176
  inputs=image_input,
177
  outputs=analysis_output
178
  )
179
-
180
- chat_input.submit(
181
- lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
182
- inputs=[chat_input, chatbot],
183
- outputs=[chat_input, chatbot]
184
- )
185
-
186
- chat_btn.click(
187
- lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
188
- inputs=[chat_input, chatbot],
189
- outputs=[chat_input, chatbot]
190
- )
191
-
192
- clear_btn.click(
193
- lambda: None,
194
- outputs=[chatbot]
195
- )
196
 
197
- # Examples using your real analyzer
 
 
 
 
 
 
 
 
 
198
  gr.Examples(
199
- examples=[
200
- ["https://drive.google.com/uc?export=view&id=1sCxcpacS5WkV5qVrhj8mcdq1JHyVyaEb"],
201
- ["https://drive.google.com/uc?export=view&id=1WGcXwFhpbD1LrtbQ8E5IZZN3nEGfcwuN"]
202
- ],
203
  inputs=image_input,
204
  outputs=analysis_output,
205
  fn=analyzer.analyze_image,
206
  cache_examples=True,
207
- label="Try example images:"
208
  )
209
 
210
  demo.launch()
 
 
 
1
  import torch
2
  from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
3
  import gradio as gr
 
5
  import re
6
  from typing import List, Tuple
7
 
8
+ # Configuration
9
+ MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+ TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
12
+
13
  class RiverPollutionAnalyzer:
14
  def __init__(self):
15
+ # Initialize processor and model
16
+ self.processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
17
  self.model = InstructBlipForConditionalGeneration.from_pretrained(
18
+ MODEL_NAME,
19
+ torch_dtype=TORCH_DTYPE
20
+ ).to(DEVICE)
 
 
21
 
22
  self.pollutants = [
23
  "plastic waste", "chemical foam", "industrial discharge",
 
56
  images=image,
57
  text=prompt,
58
  return_tensors="pt"
59
+ ).to(DEVICE, TORCH_DTYPE)
60
 
61
  with torch.no_grad():
62
  outputs = self.model.generate(
63
  **inputs,
64
  max_new_tokens=200,
65
+ temperature=0.7,
66
+ top_p=0.9,
67
  do_sample=True
68
  )
69
 
 
71
  pollutants, severity = self._parse_response(analysis)
72
  return self._format_analysis(pollutants, severity)
73
 
74
+ def analyze_chat(self, message):
75
+ """Handle chat questions about pollution"""
76
+ if "severity" in message.lower():
77
+ return "Severity levels range from 1 (minimal) to 10 (disaster). The analyzer automatically detects the appropriate level."
78
+ elif "pollutant" in message.lower():
79
+ return f"Detectable pollutants: {', '.join(self.pollutants)}"
80
+ else:
81
+ return "I can answer questions about pollution severity levels and detectable pollutants."
82
+
83
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
84
  """Robust parsing of model response"""
85
  pollutants = []
 
87
 
88
  # Extract pollutants
89
  pollutant_match = re.search(
90
+ r'Pollutants:\s*\[?(.*?)\]?',
91
+ analysis, re.IGNORECASE
92
  )
 
93
  if pollutant_match:
94
+ pollutants_str = pollutant_match.group(1).strip()
95
  pollutants = [
96
+ p.strip().lower()
97
+ for p in re.split(r'[,;]', pollutants_str)
98
  if p.strip().lower() in self.pollutants
99
  ]
100
 
101
  # Extract severity
102
  severity_match = re.search(
103
+ r'Severity:\s*(\d{1,2})',
104
+ analysis, re.IGNORECASE
105
  )
 
106
  if severity_match:
107
+ severity = min(max(int(severity_match.group(1)), 1), 10)
 
 
 
108
  else:
109
  severity = self._calculate_severity(pollutants)
110
 
 
121
  "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
122
  "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1
123
  }
 
124
  avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
125
  return min(10, max(1, round(avg_weight * 3)))
126
 
 
131
  {self.severity_descriptions.get(severity, '')}"""
132
 
133
  pollutants_list = "\nπŸ” No pollutants detected" if not pollutants else "\n".join(
134
+ f"β€’ {p.capitalize()}" for p in pollutants[:8])
135
 
136
  return f"""🌊 River Pollution Analysis 🌊
137
  {pollutants_list}
 
140
  # Initialize analyzer
141
  analyzer = RiverPollutionAnalyzer()
142
 
143
+ # Gradio Interface
 
 
 
144
  css = """
145
+ .header { text-align: center; margin-bottom: 20px; }
146
+ .header h1 { font-size: 2.2rem; margin-bottom: 0; }
147
+ .header h3 { font-size: 1.1rem; font-weight: normal; margin-top: 0.5rem; }
148
+ .side-by-side { display: flex; gap: 20px; }
149
+ .left-panel, .right-panel { flex: 1; }
150
+ .analysis-box { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin-top: 20px; }
151
+ .chat-container { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; height: 100%; }
152
  """
153
 
154
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
155
  with gr.Column(elem_classes="header"):
156
  gr.Markdown("# 🌍 River Pollution Analyzer")
157
+ gr.Markdown("### AI-powered water quality assessment")
158
 
159
  with gr.Row(elem_classes="side-by-side"):
160
+ # Image Analysis Panel
161
  with gr.Column(elem_classes="left-panel"):
162
+ gr.Markdown("### πŸ“Έ Image Analysis")
163
  with gr.Group():
164
  image_input = gr.Image(type="pil", label="Upload River Image", height=300)
165
+ analyze_btn = gr.Button("πŸ” Analyze", variant="primary")
 
166
  with gr.Group(elem_classes="analysis-box"):
 
167
  analysis_output = gr.Markdown()
168
 
169
+ # Chat Panel
170
  with gr.Column(elem_classes="right-panel"):
171
+ gr.Markdown("### πŸ’¬ Pollution Q&A")
172
  with gr.Group(elem_classes="chat-container"):
173
+ chatbot = gr.Chatbot(height=350)
174
  with gr.Row():
175
+ chat_input = gr.Textbox(placeholder="Ask about pollution...", show_label=False)
176
+ chat_btn = gr.Button("Send", variant="secondary")
177
+ clear_btn = gr.Button("Clear Chat")
 
178
 
179
+ # Event handlers
180
  analyze_btn.click(
181
  analyzer.analyze_image,
182
  inputs=image_input,
183
  outputs=analysis_output
184
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ def respond(message, chat_history):
187
+ response = analyzer.analyze_chat(message)
188
+ chat_history.append((message, response))
189
+ return "", chat_history
190
+
191
+ chat_input.submit(respond, [chat_input, chatbot], [chat_input, chatbot])
192
+ chat_btn.click(respond, [chat_input, chatbot], [chat_input, chatbot])
193
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
194
+
195
+ # Examples
196
  gr.Examples(
197
+ examples=[["examples/pollution1.jpg"], ["examples/pollution2.jpg"]],
 
 
 
198
  inputs=image_input,
199
  outputs=analysis_output,
200
  fn=analyzer.analyze_image,
201
  cache_examples=True,
202
+ label="Example Images"
203
  )
204
 
205
  demo.launch()