slcr-hack / app.py
atharwaah1work's picture
Update app.py
334ca28 verified
import torch
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import gradio as gr
from PIL import Image
import re
from typing import List, Tuple
# Configuration
# Hub checkpoint loaded by RiverPollutionAnalyzer.
MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"
# Use the GPU with half precision when CUDA is present; otherwise fall back to
# full-precision CPU inference.
if torch.cuda.is_available():
    DEVICE, TORCH_DTYPE = "cuda", torch.float16
else:
    DEVICE, TORCH_DTYPE = "cpu", torch.float32
class RiverPollutionAnalyzer:
    """Detect river pollutants and estimate severity with InstructBLIP.

    ``analyze_image`` asks the vision-language model for a pollutant list and
    a 1-10 severity score, parses its free-text answer, and renders a Markdown
    report. ``analyze_chat`` answers simple keyword-matched questions without
    calling the model.
    """

    # Relative danger of each pollutant. Used only when the model's answer has
    # no parseable "Severity:" value (see _calculate_severity).
    SEVERITY_WEIGHTS = {
        "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
        "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
        "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
        "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1,
    }

    # Maximum number of pollutant lines shown in the report.
    MAX_LISTED_POLLUTANTS = 8

    # "Pollutants:", an optional "[", then the shortest run of text ending at a
    # closing "]", a line break, an inline "Severity:" label, or end of text.
    # (The previous pattern ended in a lazy group plus an optional "]", so it
    # always matched the empty string and never extracted anything.)
    _POLLUTANTS_RE = re.compile(
        r"Pollutants:\s*\[?([^\]\n]*?)\s*(?:\]|\n|Severity:|$)",
        re.IGNORECASE,
    )
    # "Severity:" followed by an optional "[" (the prompt shows "[number]")
    # and a one- or two-digit score.
    _SEVERITY_RE = re.compile(r"Severity:\s*\[?\s*(\d{1,2})", re.IGNORECASE)

    def __init__(self):
        """Load the InstructBLIP processor/model and the label vocabularies."""
        self.processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            torch_dtype=TORCH_DTYPE,
        ).to(DEVICE)
        # Closed vocabulary: pollutant names outside this list are discarded
        # when parsing the model's answer.
        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff",
        ]
        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe",
        }

    def analyze_image(self, image):
        """Run the model on one image and return a Markdown report.

        Args:
            image: A PIL image, an array accepted by ``Image.fromarray``, or
                ``None`` (Gradio passes ``None`` when nothing is uploaded).

        Returns:
            The formatted report string, or a short prompt to upload an image.
        """
        if image is None:
            return "⚠️ Please upload a river image to analyze."
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalise RGBA / palette / greyscale uploads to RGB before preprocessing.
        image = image.convert("RGB")

        # Build the allowed-pollutant list from self.pollutants so the prompt
        # and the parser share a single vocabulary.
        prompt = (
            "Analyze this river pollution scene and provide:\n"
            "1. List ALL visible pollutants ONLY from: "
            f"[{', '.join(self.pollutants)}]\n"
            "2. Estimate pollution severity from 1-10\n"
            "Respond EXACTLY in this format:\n"
            "Pollutants: [comma separated list]\n"
            "Severity: [number]"
        )
        inputs = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt",
        ).to(DEVICE, TORCH_DTYPE)
        with torch.no_grad():
            # do_sample=True: the answer (and so the report) can vary between runs.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        pollutants, severity = self._parse_response(analysis)
        return self._format_analysis(pollutants, severity)

    def analyze_chat(self, message):
        """Answer a chat question by keyword matching (the model is not used).

        Args:
            message: The user's question; ``None`` is treated as empty.

        Returns:
            A canned answer about severity levels, the pollutant vocabulary,
            or a generic help message.
        """
        text = (message or "").lower()
        if "severity" in text:
            return "Severity levels range from 1 (minimal) to 10 (disaster). The analyzer automatically detects the appropriate level."
        if "pollutant" in text:
            return f"Detectable pollutants: {', '.join(self.pollutants)}"
        return "I can answer questions about pollution severity levels and detectable pollutants."

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Extract (pollutants, severity) from the model's free-text answer.

        Pollutants are taken from the text after "Pollutants:", split on commas
        or semicolons, normalised (whitespace, surrounding quotes/periods, case),
        filtered to ``self.pollutants``, and de-duplicated in order.
        Severity is the number after "Severity:", clamped to 1-10. If none is
        found, it is estimated from the pollutants via ``_calculate_severity``.
        """
        pollutants: List[str] = []
        pollutant_match = self._POLLUTANTS_RE.search(analysis)
        if pollutant_match:
            known = set(self.pollutants)
            candidates = (
                " ".join(item.split()).strip(" '\"`.").lower()
                for item in re.split(r"[,;]", pollutant_match.group(1))
            )
            # dict.fromkeys drops repeats while keeping the model's order.
            pollutants = list(dict.fromkeys(c for c in candidates if c in known))

        severity_match = self._SEVERITY_RE.search(analysis)
        if severity_match:
            severity = min(max(int(severity_match.group(1)), 1), 10)
        else:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Estimate a 1-10 severity from the mean pollutant weight.

        Returns 1 for an empty list. Unknown names weigh 1. The mean weight is
        scaled by 3, rounded with Python's ``round``, and clamped to 1-10.
        """
        if not pollutants:
            return 1
        total = sum(self.SEVERITY_WEIGHTS.get(p, 1) for p in pollutants)
        avg_weight = total / len(pollutants)
        return min(10, max(1, round(avg_weight * 3)))

    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Render the report as Markdown.

        Sections are separated by blank lines and pollutants are a Markdown
        list, so the layout survives ``gr.Markdown`` rendering. Single newlines
        would otherwise be merged into one paragraph. Severity is clamped to
        1-10 so the bar and description always line up.
        """
        severity = min(max(int(severity), 1), 10)
        if pollutants:
            pollutant_section = "\n".join(
                f"- {p.capitalize()}"
                for p in pollutants[: self.MAX_LISTED_POLLUTANTS]
            )
        else:
            pollutant_section = "🔍 No pollutants detected"
        severity_section = "\n\n".join((
            f"📊 Severity: {severity}/10",
            "█" * severity + "░" * (10 - severity),
            self.severity_descriptions.get(severity, ""),
        ))
        return "\n\n".join((
            "🌊 River Pollution Analysis 🌊",
            pollutant_section,
            severity_section,
        ))
# Initialize analyzer
# Module-level instance used by the Gradio handlers (analyze button, chat, examples).
# Constructing it runs InstructBLIP from_pretrained for the processor and model,
# so this line is where the model gets loaded (presumably downloading it on first run).
analyzer = RiverPollutionAnalyzer()
# Gradio Interface
# Custom stylesheet passed to gr.Blocks. The selectors match the names given
# to components via ``elem_classes`` (header, side-by-side, panels, boxes).
css = "\n" + "\n".join([
    ".header { text-align: center; margin-bottom: 20px; }",
    ".header h1 { font-size: 2.2rem; margin-bottom: 0; }",
    ".header h3 { font-size: 1.1rem; font-weight: normal; margin-top: 0.5rem; }",
    ".side-by-side { display: flex; gap: 20px; }",
    ".left-panel, .right-panel { flex: 1; }",
    ".analysis-box { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin-top: 20px; }",
    ".chat-container { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; height: 100%; }",
]) + "\n"
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Page header
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌍 River Pollution Analyzer")
        gr.Markdown("### AI-powered water quality assessment")

    with gr.Row(elem_classes="side-by-side"):
        # Image Analysis Panel
        with gr.Column(elem_classes="left-panel"):
            gr.Markdown("### 📸 Image Analysis")
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                analysis_output = gr.Markdown()

        # Chat Panel
        with gr.Column(elem_classes="right-panel"):
            gr.Markdown("### 💬 Pollution Q&A")
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(height=350)
                with gr.Row():
                    chat_input = gr.Textbox(placeholder="Ask about pollution...", show_label=False)
                    chat_btn = gr.Button("Send", variant="secondary")
                clear_btn = gr.Button("Clear Chat")

    # Event handlers: registered inside the Blocks context so they attach to `demo`.
    analyze_btn.click(
        analyzer.analyze_image,
        inputs=image_input,
        outputs=analysis_output,
    )

    def respond(message, chat_history):
        """Append one (question, answer) turn and clear the textbox.

        The history may be None (e.g. right after "Clear Chat"), so it is
        normalised to a fresh list. Blank messages leave the chat unchanged.
        """
        chat_history = list(chat_history or [])
        if not message or not message.strip():
            return "", chat_history
        response = analyzer.analyze_chat(message)
        chat_history.append((message, response))
        return "", chat_history

    chat_input.submit(respond, [chat_input, chatbot], [chat_input, chatbot])
    chat_btn.click(respond, [chat_input, chatbot], [chat_input, chatbot])
    # Returning None empties the Chatbot component.
    clear_btn.click(lambda: None, None, chatbot, queue=False)

    # Examples (cache_examples=True runs analyze_image on these files when the app starts)
    gr.Examples(
        examples=[["examples/pollution1.jpg"], ["examples/pollution2.jpg"]],
        inputs=image_input,
        outputs=analysis_output,
        fn=analyzer.analyze_image,
        cache_examples=True,
        label="Example Images",
    )

demo.launch()