# Hugging Face Spaces status when this file was captured: Runtime error
import re
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
# Configuration: which checkpoint to load and where/how to run it.
# Half precision is used only on a CUDA device; CPU inference stays float32.
MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"
DEVICE, TORCH_DTYPE = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
class RiverPollutionAnalyzer:
    """Detect and grade river pollution in images with InstructBLIP.

    The model is asked to name pollutants from a fixed vocabulary
    (``self.pollutants``) and to rate severity on a 1-10 scale. Its free-text
    answer is parsed back into structured data and rendered as Markdown.
    """

    # Closed vocabulary of pollutants the model may report.
    _POLLUTANTS = (
        "plastic waste", "chemical foam", "industrial discharge",
        "sewage water", "oil spill", "organic debris",
        "construction waste", "medical waste", "floating trash",
        "algal bloom", "toxic sludge", "agricultural runoff",
    )

    # Human-readable label for every severity level 1-10.
    _SEVERITY_DESCRIPTIONS = {
        1: "Minimal pollution - Slightly noticeable",
        2: "Minor pollution - Small amounts visible",
        3: "Moderate pollution - Clearly visible",
        4: "Significant pollution - Affecting water quality",
        5: "Heavy pollution - Obvious environmental impact",
        6: "Severe pollution - Large accumulation",
        7: "Very severe pollution - Major ecosystem impact",
        8: "Extreme pollution - Dangerous levels",
        9: "Critical pollution - Immediate action needed",
        10: "Disaster level - Ecological catastrophe",
    }

    # Per-pollutant weights used when the model gives no explicit severity.
    _SEVERITY_WEIGHTS = {
        "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
        "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
        "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
        "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1,
    }

    # Capture the text after "Pollutants:" up to a closing "]", a line break,
    # a following "Severity:" or the end of the answer. A bare lazy group
    # followed only by optional tokens would always capture "".
    _POLLUTANTS_RE = re.compile(
        r"Pollutants\s*:\s*\[?([^\]\n]*?)\]?\s*(?=Severity\s*:|\n|$)",
        re.IGNORECASE,
    )
    _SEVERITY_RE = re.compile(r"Severity\s*:\s*(\d{1,2})", re.IGNORECASE)
    # Items may be separated by commas, semicolons or the word "and".
    _SPLIT_RE = re.compile(r"[,;]|\s+and\s+", re.IGNORECASE)

    def __init__(self):
        """Load the InstructBLIP processor and model and set up the vocabularies."""
        self.processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            torch_dtype=TORCH_DTYPE,
        ).to(DEVICE)
        self.pollutants = list(self._POLLUTANTS)
        self.severity_descriptions = dict(self._SEVERITY_DESCRIPTIONS)

    def analyze_image(self, image):
        """Run the model on *image* and return a Markdown pollution report.

        Args:
            image: A PIL image or an array accepted by ``Image.fromarray``.
                ``None`` (nothing uploaded) returns a short hint instead of
                raising.

        Returns:
            The Markdown report built by ``_format_analysis``.
        """
        if image is None:
            return "Please upload a river image to analyze."
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize RGBA / greyscale uploads to plain RGB.
        image = image.convert("RGB")

        inputs = self.processor(
            images=image,
            text=self._build_prompt(),
            return_tensors="pt",
        ).to(DEVICE, TORCH_DTYPE)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                # Greedy decoding: the answer must follow a fixed format, and the
                # same image should always receive the same severity.
                do_sample=False,
            )

        analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        pollutants, severity = self._parse_response(analysis)
        return self._format_analysis(pollutants, severity)

    def analyze_chat(self, message):
        """Answer simple questions about severity levels and detectable pollutants.

        Matching is keyword-based and case-insensitive; ``None`` is treated as
        an empty message.
        """
        text = (message or "").lower()
        if "severity" in text:
            return "Severity levels range from 1 (minimal) to 10 (disaster). The analyzer automatically detects the appropriate level."
        if "pollutant" in text:
            return f"Detectable pollutants: {', '.join(self.pollutants)}"
        return "I can answer questions about pollution severity levels and detectable pollutants."

    def _build_prompt(self) -> str:
        """Return the instruction sent with every image.

        The allowed-pollutant list is generated from ``self.pollutants`` so the
        prompt and the parser always agree on the vocabulary.
        """
        return (
            "Analyze this river pollution scene and provide:\n"
            f"1. List ALL visible pollutants ONLY from: [{', '.join(self.pollutants)}]\n"
            "2. Estimate pollution severity from 1-10\n"
            "Respond EXACTLY in this format:\n"
            "Pollutants: [comma separated list]\n"
            "Severity: [number]"
        )

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Extract ``(pollutants, severity)`` from the model's free-text answer.

        Pollutants:
            - They come from the ``Pollutants:`` section, split on commas,
              semicolons or "and".
            - Surrounding punctuation is stripped, and an item is kept only if
              it is in ``self.pollutants``.
            - Duplicates are dropped and the model's order is kept.
            - If the answer has no ``Pollutants:`` label, every known
              pollutant name found anywhere in the text is used instead.

        Severity:
            - Read from ``Severity: N`` and clamped to 1-10.
            - Estimated with ``_calculate_severity`` when absent.
        """
        known = set(self.pollutants)
        section = self._POLLUTANTS_RE.search(analysis)
        if section:
            candidates = (
                part.strip(" \t.'\"[]").lower()
                for part in self._SPLIT_RE.split(section.group(1))
            )
        else:
            lowered = analysis.lower()
            candidates = (name for name in self.pollutants if name in lowered)
        # dict.fromkeys removes duplicates while preserving first-seen order.
        pollutants = list(dict.fromkeys(c for c in candidates if c in known))

        severity_match = self._SEVERITY_RE.search(analysis)
        if severity_match:
            severity = min(max(int(severity_match.group(1)), 1), 10)
        else:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Estimate severity (1-10) from the detected pollutants.

        Takes the mean per-pollutant weight (unknown names count as 1), times 3,
        rounded and clamped to 1-10. An empty list gives 1.
        """
        if not pollutants:
            return 1
        total = sum(self._SEVERITY_WEIGHTS.get(p, 1) for p in pollutants)
        return min(10, max(1, round(total / len(pollutants) * 3)))

    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Render pollutants and severity as a Markdown report.

        Sections are separated by blank lines because Markdown joins single
        newlines into one paragraph. At most 8 pollutants are listed. The bar
        shows ``severity`` filled cells out of 10.
        """
        if pollutants:
            pollutants_section = "\n".join(f"- {p.capitalize()}" for p in pollutants[:8])
        else:
            pollutants_section = "✅ No pollutants detected"

        filled = max(0, min(severity, 10))
        bar = "█" * filled + "░" * (10 - filled)

        parts = [
            "### 🌊 River Pollution Analysis 🌊",
            pollutants_section,
            f"📊 Severity: {severity}/10",
            bar,
            self.severity_descriptions.get(severity, ""),
        ]
        return "\n\n".join(part for part in parts if part)
# Initialize the analyzer once at import time. This loads the InstructBLIP
# weights, so the interface below is only built after loading succeeds.
analyzer = RiverPollutionAnalyzer()

# Gradio Interface
css = """
.header { text-align: center; margin-bottom: 20px; }
.header h1 { font-size: 2.2rem; margin-bottom: 0; }
.header h3 { font-size: 1.1rem; font-weight: normal; margin-top: 0.5rem; }
.side-by-side { display: flex; gap: 20px; }
.left-panel, .right-panel { flex: 1; }
.analysis-box { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin-top: 20px; }
.chat-container { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; height: 100%; }
"""

# Optional example images. Only files that exist are registered: with
# cache_examples=True the examples are processed while the app starts, so a
# missing file would abort startup.
EXAMPLE_IMAGES = [
    path
    for path in ("examples/pollution1.jpg", "examples/pollution2.jpg")
    if Path(path).is_file()
]


def respond(message, chat_history):
    """Answer one chat message and append the exchange to the history.

    Args:
        message: Text typed by the user.
        chat_history: Current list of (user, bot) pairs. It may be ``None``
            or empty after the chat was cleared.

    Returns:
        ``("", history)``: the empty string clears the textbox, and
        ``history`` is the updated conversation.
    """
    history = list(chat_history or [])
    # Ignore blank submissions instead of adding an empty exchange.
    if not message or not message.strip():
        return "", history
    history.append((message, analyzer.analyze_chat(message)))
    return "", history


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌊 River Pollution Analyzer")
        gr.Markdown("### AI-powered water quality assessment")

    with gr.Row(elem_classes="side-by-side"):
        # Image Analysis Panel
        with gr.Column(elem_classes="left-panel"):
            gr.Markdown("### 📸 Image Analysis")
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                analysis_output = gr.Markdown()

        # Chat Panel
        with gr.Column(elem_classes="right-panel"):
            gr.Markdown("### 💬 Pollution Q&A")
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(height=350)
                with gr.Row():
                    chat_input = gr.Textbox(placeholder="Ask about pollution...", show_label=False)
                    chat_btn = gr.Button("Send", variant="secondary")
                clear_btn = gr.Button("Clear Chat")

    # Event handlers
    analyze_btn.click(
        analyzer.analyze_image,
        inputs=image_input,
        outputs=analysis_output,
    )
    chat_input.submit(respond, [chat_input, chatbot], [chat_input, chatbot])
    chat_btn.click(respond, [chat_input, chatbot], [chat_input, chatbot])
    # Reset to an empty history; respond() also tolerates None.
    clear_btn.click(lambda: [], None, chatbot, queue=False)

    # Examples (only when the example files are present)
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[path] for path in EXAMPLE_IMAGES],
            inputs=image_input,
            outputs=analysis_output,
            fn=analyzer.analyze_image,
            cache_examples=True,
            label="Example Images",
        )

demo.launch()