|
|
|
|
|
""" |
|
Enhanced Single-View Gradio App for Semantic Segmentation, Depth Estimation, and 3D Point Cloud |
|
Processes one image and shows all outputs: segmentation, depth, and colored point cloud |
|
Now with precomputed examples for demonstration |
|
""" |
|
|
|
import sys |
|
import locale |
|
import os |
|
import datetime |
|
from pathlib import Path |
|
|
|
|
|
if sys.version_info >= (3, 7): |
|
sys.stdout.reconfigure(encoding='utf-8') |
|
sys.stderr.reconfigure(encoding='utf-8') |
|
|
|
|
|
try: |
|
locale.setlocale(locale.LC_ALL, 'en_US.UTF-8') |
|
except locale.Error: |
|
try: |
|
locale.setlocale(locale.LC_ALL, 'C.UTF-8') |
|
except locale.Error: |
|
pass |
|
|
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import io |
|
|
from dataclasses import dataclass |
|
from typing import Optional, List, Tuple, Dict
|
import requests |
|
import cv2 |
|
from abc import ABC, abstractmethod |
|
from collections import namedtuple |
|
import plotly.graph_objects as go |
|
|
import open3d as o3d |
|
|
import subprocess |
|
|
|
|
|
try: |
|
from metric_depth.depth_anything_v2.dpt import DepthAnythingV2 |
|
DEPTH_AVAILABLE = True |
|
except ImportError: |
|
print("DepthAnythingV2 not available. Using precomputed examples only.") |
|
DEPTH_AVAILABLE = False |
|
|
|
|
|
CUDA_AVAILABLE = torch.cuda.is_available() |
|
|
|
|
|
# Disable xformers; note these environment variables only take effect for
# libraries imported after this point.
os.environ['XFORMERS_DISABLED'] = '1'
os.environ['XFORMERS_MORE_DETAILS'] = '1'
|
|
|
|
|
OUTPUT_DIR = Path("outputs") |
|
|
|
def fix_lfs_on_startup(): |
|
"""Quick fix for LFS issues on HuggingFace startup.""" |
|
print("Checking for LFS issues...") |
|
|
|
try: |
|
|
|
result = subprocess.run(['git', 'lfs', 'pull'], |
|
capture_output=True, text=True, timeout=30) |
|
if result.returncode == 0: |
|
print("LFS files pulled successfully") |
|
else: |
|
print(f"LFS pull failed: {result.stderr}") |
|
|
|
subprocess.run(['git', 'lfs', 'checkout'], |
|
capture_output=True, timeout=20) |
|
except Exception as e: |
|
print(f"LFS operations failed: {e}") |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
class ModelConfig: |
|
"""Configuration for segmentation models.""" |
|
model_name: str |
|
processor_name: str |
|
device: str = "cuda" if torch.cuda.is_available() else "cpu" |
|
trust_remote_code: bool = True |
|
task_type: str = "semantic" |
|
|
|
@dataclass |
|
class DepthConfig: |
|
"""Configuration for depth estimation models.""" |
|
encoder: str = "vitl" |
|
dataset: str = "vkitti" |
|
max_depth: int = 80 |
|
weights_path: str = "depth_anything_v2_metric_vkitti_vitl.pth" |
|
device: str = "cuda" if torch.cuda.is_available() else "cpu" |
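
# Example (hedged): a lighter DepthConfig for constrained hardware. The "vits"
# checkpoint filename below is an assumption following the vitl naming above,
# not a file shipped with this repo.
# depth_config_small = DepthConfig(
#     encoder="vits",
#     weights_path="depth_anything_v2_metric_vkitti_vits.pth",
# )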
|
|
|
|
|
|
|
class BaseSegmentationModel(ABC): |
|
"""Abstract base class for segmentation models.""" |
|
|
|
def __init__(self, model_config): |
|
self.config = model_config |
|
self.model = None |
|
self.processor = None |
|
self.device = torch.device(model_config.device if torch.cuda.is_available() else "cpu") |
|
|
|
@abstractmethod |
|
def load_model(self): |
|
"""Load the model and processor.""" |
|
pass |
|
|
|
@abstractmethod |
|
def preprocess(self, image: Image.Image, **kwargs) -> Dict[str, torch.Tensor]: |
|
"""Preprocess the input image.""" |
|
pass |
|
|
|
@abstractmethod |
|
def predict(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
"""Run inference on preprocessed inputs.""" |
|
pass |
|
|
|
@abstractmethod |
|
def postprocess(self, outputs: Dict[str, torch.Tensor], target_size: Tuple[int, int]) -> np.ndarray: |
|
"""Postprocess model outputs to segmentation map.""" |
|
pass |
|
|
|
def segment_image(self, image: Image.Image, **kwargs) -> np.ndarray: |
|
"""End-to-end segmentation pipeline.""" |
|
if self.model is None: |
|
self.load_model() |
|
|
|
inputs = self.preprocess(image, **kwargs) |
|
outputs = self.predict(inputs) |
|
segmentation_map = self.postprocess(outputs, image.size[::-1]) |
|
|
|
return segmentation_map |
|
|
|
|
|
|
|
|
|
|
|
class OneFormerModel(BaseSegmentationModel): |
|
"""OneFormer model for universal segmentation.""" |
|
|
|
def __init__(self, model_config): |
|
super().__init__(model_config) |
|
|
|
def load_model(self): |
|
"""Load OneFormer model and processor.""" |
|
print(f"Loading OneFormer model: {self.config.model_name}") |
|
|
|
try: |
|
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation |
|
|
|
self.processor = OneFormerProcessor.from_pretrained( |
|
self.config.processor_name, |
|
trust_remote_code=self.config.trust_remote_code |
|
) |
|
|
|
self.model = OneFormerForUniversalSegmentation.from_pretrained( |
|
self.config.model_name, |
|
trust_remote_code=self.config.trust_remote_code |
|
) |
|
|
|
self.model.to(self.device) |
|
self.model.eval() |
|
|
|
print(f"OneFormer model loaded successfully on {self.device}") |
|
|
|
except Exception as e: |
|
print(f"Error loading OneFormer model: {e}") |
|
raise |
|
|
|
    def preprocess(self, image: Image.Image, task_inputs: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
|
"""Preprocess image for OneFormer.""" |
|
if task_inputs is None: |
|
task_inputs = [self.config.task_type] |
|
|
|
inputs = self.processor( |
|
images=image, |
|
task_inputs=task_inputs, |
|
return_tensors="pt" |
|
) |
|
|
|
|
|
inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v |
|
for k, v in inputs.items()} |
|
|
|
return inputs |
|
|
|
def predict(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
"""Run inference with OneFormer.""" |
|
with torch.no_grad(): |
|
outputs = self.model(**inputs) |
|
|
|
return outputs |
|
|
|
def postprocess(self, outputs: Dict[str, torch.Tensor], target_size: Tuple[int, int]) -> np.ndarray: |
|
"""Postprocess OneFormer outputs.""" |
|
predicted_semantic_map = self.processor.post_process_semantic_segmentation( |
|
outputs, |
|
target_sizes=[target_size] |
|
)[0] |
|
|
|
return predicted_semantic_map.cpu().numpy() |
|
|
|
|
|
|
|
|
|
|
|
class DepthAnythingV2Model: |
|
"""DepthAnythingV2 model for depth estimation.""" |
|
|
|
def __init__(self, depth_config: DepthConfig): |
|
self.config = depth_config |
|
self.model = None |
|
self.device = torch.device(depth_config.device if torch.cuda.is_available() else "cpu") |
|
|
|
def load_model(self): |
|
"""Load DepthAnythingV2 model.""" |
|
if not DEPTH_AVAILABLE: |
|
raise ImportError("DepthAnythingV2 is not available") |
|
|
|
print(f"Loading DepthAnythingV2 model: {self.config.encoder}") |
|
|
|
try: |
|
model_configs = { |
|
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, |
|
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, |
|
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]} |
|
} |
|
|
|
self.model = DepthAnythingV2(**{**model_configs[self.config.encoder], 'max_depth': self.config.max_depth}) |
|
|
|
|
|
if os.path.exists(self.config.weights_path): |
|
self.model.load_state_dict(torch.load(self.config.weights_path, map_location='cpu')) |
|
print(f"Loaded weights from {self.config.weights_path}") |
|
else: |
|
print(f"Warning: Weights file {self.config.weights_path} not found") |
|
|
|
self.model.to(self.device) |
|
self.model.eval() |
|
|
|
print(f"DepthAnythingV2 model loaded successfully on {self.device}") |
|
|
|
except Exception as e: |
|
print(f"Error loading DepthAnythingV2 model: {e}") |
|
raise |
|
|
|
def estimate_depth(self, image: Image.Image) -> np.ndarray: |
|
"""Estimate depth from image.""" |
|
if self.model is None: |
|
self.load_model() |
|
|
|
|
|
img_array = np.array(image) |
|
if len(img_array.shape) == 3: |
|
img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) |
|
|
|
|
|
depth_map = self.model.infer_image(img_array) |
|
|
|
return depth_map |
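
# Optional sketch (an assumption, not part of the original pipeline): the
# colorized depth PNGs saved for the precomputed examples are lossy for
# downstream use. Raw metric depth can instead be stored losslessly as a 16-bit
# PNG in millimeters (about a 65.5 m range at 1 mm resolution, so values are
# clipped).
def save_depth_png16(depth_m: np.ndarray, path: str) -> None:
    """Quantize metric depth in meters to uint16 millimeters and save as PNG."""
    depth_mm = np.clip(depth_m * 1000.0, 0, 65535).astype(np.uint16)
    cv2.imwrite(path, depth_mm)

def load_depth_png16(path: str) -> np.ndarray:
    """Load a 16-bit depth PNG back to metric depth in meters."""
    depth_mm = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return depth_mm.astype(np.float32) / 1000.0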
|
|
|
|
|
|
|
|
|
|
|
Label = namedtuple('Label', [ |
|
'name', 'id', 'trainId', 'category', 'categoryId', |
|
'hasInstances', 'ignoreInEval', 'color' |
|
]) |
|
|
|
labels = [ |
|
Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), |
|
Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), |
|
Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), |
|
Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), |
|
Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), |
|
Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), |
|
Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), |
|
Label('road', 7, 0, 'flat', 1, False, False, (128, 64,128)), |
|
Label('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35,232)), |
|
Label('parking', 9, 255, 'flat', 1, False, True, (250,170,160)), |
|
Label('rail track', 10, 255, 'flat', 1, False, True, (230,150,140)), |
|
Label('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), |
|
Label('wall', 12, 3, 'construction', 2, False, False, (102,102,156)), |
|
Label('fence', 13, 4, 'construction', 2, False, False, (190,153,153)), |
|
Label('guard rail', 14, 255, 'construction', 2, False, True, (180,165,180)), |
|
Label('bridge', 15, 255, 'construction', 2, False, True, (150,100,100)), |
|
Label('tunnel', 16, 255, 'construction', 2, False, True, (150,120, 90)), |
|
Label('pole', 17, 5, 'object', 3, False, False, (153,153,153)), |
|
Label('polegroup', 18, 255, 'object', 3, False, True, (153,153,153)), |
|
Label('traffic light', 19, 6, 'object', 3, False, False, (250,170, 30)), |
|
Label('traffic sign', 20, 7, 'object', 3, False, False, (220,220, 0)), |
|
Label('vegetation', 21, 8, 'nature', 4, False, False, (107,142, 35)), |
|
Label('terrain', 22, 9, 'nature', 4, False, False, (152,251,152)), |
|
Label('sky', 23, 10, 'sky', 5, False, False, (70,130,180)), |
|
Label('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), |
|
Label('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), |
|
Label('car', 26, 13, 'vehicle', 7, True, False, (0, 0,142)), |
|
Label('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), |
|
Label('bus', 28, 15, 'vehicle', 7, True, False, (0, 60,100)), |
|
Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), |
|
Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0,110)), |
|
Label('train', 31, 16, 'vehicle', 7, True, False, (0, 80,100)), |
|
Label('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0,230)), |
|
Label('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), |
|
Label('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0,142)), |
|
] |
|
|
|
|
|
SKY_TRAIN_ID = 10 |
|
|
|
|
|
|
|
|
|
|
|
def get_color_map(labels): |
|
"""Returns a color map dictionary for the given labels.""" |
|
color_map = {label.trainId: label.color for label in labels if label.trainId != 255} |
|
return color_map |
|
|
|
def apply_color_map(semantic_map, color_map): |
|
"""Applies a color map to a semantic map.""" |
|
height, width = semantic_map.shape |
|
color_mapped_image = np.zeros((height, width, 3), dtype=np.uint8) |
|
|
|
for trainId, color in color_map.items(): |
|
mask = semantic_map == trainId |
|
color_mapped_image[mask] = color |
|
|
|
return color_mapped_image |
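
# Optional vectorized alternative to the loop above (a sketch; the loop version
# remains the one actually used): build a (256, 3) lookup table indexed by
# trainId and colorize the whole map with one NumPy indexing operation.
def apply_color_map_lut(semantic_map: np.ndarray, color_map: Dict[int, Tuple[int, int, int]]) -> np.ndarray:
    """Colorize a trainId map via a lookup table; unknown ids stay black."""
    lut = np.zeros((256, 3), dtype=np.uint8)
    for train_id, color in color_map.items():
        if 0 <= train_id < 256:
            lut[train_id] = color
    return lut[semantic_map.astype(np.uint8)]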
|
|
|
def create_depth_visualization(depth_map: np.ndarray, colormap: str = 'magma') -> Image.Image: |
|
"""Create a colored depth map visualization with exact dimensions.""" |
|
|
|
    normalized_depth = depth_map / (np.max(depth_map) + 1e-8)  # epsilon guards an all-zero map
|
|
|
|
|
cmap = plt.get_cmap(colormap) |
|
colored_depth = cmap(normalized_depth) |
|
|
|
|
|
colored_depth_8bit = (colored_depth[:, :, :3] * 255).astype(np.uint8) |
|
|
|
return Image.fromarray(colored_depth_8bit) |
|
|
|
def depth_to_point_cloud_with_segmentation(depth_map: np.ndarray, rgb_image: Image.Image,
                                           semantic_map: np.ndarray,
                                           fx: float = 525.0, fy: float = 525.0,
                                           cx: Optional[float] = None, cy: Optional[float] = None) -> o3d.geometry.PointCloud:
    """Convert a depth map to a 3D point cloud colored by segmentation class, excluding sky.

    rgb_image is accepted for API symmetry but unused: point colors come from
    the segmentation map, not the photograph.
    """
|
height, width = depth_map.shape |
|
|
|
if cx is None: |
|
cx = width / 2.0 |
|
if cy is None: |
|
cy = height / 2.0 |
|
|
|
|
|
u, v = np.meshgrid(np.arange(width), np.arange(height)) |
|
|
|
|
|
z = depth_map |
|
x = (u - cx) * z / fx |
|
y = (v - cy) * z / fy |
|
|
|
|
|
points = np.stack([x, y, z], axis=-1).reshape(-1, 3) |
|
|
|
|
|
flat_semantic = semantic_map.flatten() |
|
flat_depth = z.flatten() |
|
|
|
|
|
valid_mask = (flat_depth > 0) & (flat_depth < 1000) & (flat_semantic != SKY_TRAIN_ID) |
|
points = points[valid_mask] |
|
|
|
|
|
color_map = get_color_map(labels) |
|
seg_colors = np.zeros((len(flat_semantic), 3)) |
|
|
|
for trainId, color in color_map.items(): |
|
mask = flat_semantic == trainId |
|
seg_colors[mask] = color |
|
|
|
|
|
colors = seg_colors[valid_mask] / 255.0 |
|
|
|
|
|
pcd = o3d.geometry.PointCloud() |
|
pcd.points = o3d.utility.Vector3dVector(points) |
|
pcd.colors = o3d.utility.Vector3dVector(colors) |
|
|
|
return pcd |
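
# The fx/fy defaults above are a generic guess. When the camera's horizontal
# field of view is known, the pinhole focal length in pixels follows from
# fx = (width / 2) / tan(hfov / 2); a small hedged helper (hypothetical, not
# used by the pipeline):
def intrinsics_from_hfov(width: int, height: int, hfov_deg: float = 60.0) -> Tuple[float, float, float, float]:
    """Return (fx, fy, cx, cy) for a pinhole camera with the given horizontal FOV."""
    fx = (width / 2.0) / np.tan(np.radians(hfov_deg) / 2.0)
    fy = fx  # assumes square pixels
    return fx, fy, width / 2.0, height / 2.0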
|
|
|
def create_plotly_pointcloud(pcd: o3d.geometry.PointCloud, downsample_factor: float = 0.1) -> go.Figure: |
|
"""Create interactive Plotly 3D point cloud visualization.""" |
|
|
|
if downsample_factor < 1.0: |
|
num_points = len(pcd.points) |
|
indices = np.random.choice(num_points, int(num_points * downsample_factor), replace=False) |
|
points = np.asarray(pcd.points)[indices] |
|
colors = np.asarray(pcd.colors)[indices] |
|
else: |
|
points = np.asarray(pcd.points) |
|
colors = np.asarray(pcd.colors) |
|
|
|
|
|
    # Plotly expects per-point CSS color strings (or scalars plus a colorscale),
    # not an N x 3 float array, so format the RGB values explicitly. The unused
    # per-point text labels are dropped: the hovertemplate never references them.
    color_strings = [f'rgb({int(r * 255)},{int(g * 255)},{int(b * 255)})'
                     for r, g, b in colors]

    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=1,
            color=color_strings,
            opacity=0.8
        ),
        hovertemplate='X: %{x:.2f}<br>Y: %{y:.2f}<br>Z: %{z:.2f}<extra></extra>'
    )])
|
|
|
|
|
fig.update_layout( |
|
scene=dict( |
|
xaxis_title='X (Horizontal)', |
|
yaxis_title='Y (Vertical)', |
|
zaxis_title='Z (Depth)', |
|
aspectmode='data' |
|
), |
|
title={ |
|
'text': 'Interactive 3D Point Cloud (Colored by Segmentation, Sky Excluded)', |
|
'x': 0.5, |
|
'xanchor': 'center' |
|
}, |
|
width=None, |
|
height=600, |
|
margin=dict(l=0, r=0, t=40, b=0), |
|
autosize=True |
|
) |
|
|
|
|
|
fig.update_layout(scene_camera=dict( |
|
up=dict(x=0, y=0, z=1), |
|
center=dict(x=0, y=0, z=0), |
|
eye=dict(x=0.5, y=-2.5, z=1.5) |
|
)) |
|
|
|
return fig |
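
# Random index subsampling (above) is fast but thins dense regions unevenly.
# Open3D's voxel-grid filter keeps one averaged point per voxel, which is
# spatially uniform; a sketch (the 0.25 m voxel size is a tuning assumption):
def downsample_voxel(pcd: o3d.geometry.PointCloud, voxel_size: float = 0.25) -> o3d.geometry.PointCloud:
    """Spatially uniform downsampling via Open3D's voxel-grid filter."""
    return pcd.voxel_down_sample(voxel_size=voxel_size)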
|
|
|
def create_overlay_plot(rgb_image: Image.Image, semantic_map: np.ndarray, alpha: float = 0.5): |
|
"""Create segmentation overlay plot without title and borders.""" |
|
rgb_array = np.array(rgb_image) |
|
color_map = get_color_map(labels) |
|
colored_semantic_map = apply_color_map(semantic_map, color_map) |
|
|
|
|
|
height, width = rgb_array.shape[:2] |
|
dpi = 100 |
|
fig, ax = plt.subplots(1, 1, figsize=(width/dpi, height/dpi), dpi=dpi) |
|
|
|
|
|
fig.subplots_adjust(left=0, right=1, top=1, bottom=0) |
|
|
|
ax.imshow(rgb_array) |
|
ax.imshow(colored_semantic_map, alpha=alpha) |
|
ax.axis('off') |
|
|
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi) |
|
buf.seek(0) |
|
plt.close(fig) |
|
|
|
return Image.open(buf) |
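
# create_overlay_plot round-trips the overlay through a Matplotlib figure and a
# PNG buffer. Plain NumPy alpha blending yields an equivalent composite without
# that overhead; a minimal alternative sketch:
def create_overlay_numpy(rgb_image: Image.Image, semantic_map: np.ndarray, alpha: float = 0.5) -> Image.Image:
    """Alpha-blend the colorized segmentation onto the RGB image in NumPy."""
    rgb = np.asarray(rgb_image.convert('RGB')).astype(np.float32)
    seg = apply_color_map(semantic_map, get_color_map(labels)).astype(np.float32)
    blended = (1.0 - alpha) * rgb + alpha * seg
    return Image.fromarray(blended.astype(np.uint8))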
|
|
|
class PrecomputedExamplesManager: |
|
"""Manages precomputed examples from output folder structure.""" |
|
|
|
def __init__(self, output_dir: Path): |
|
self.output_dir = output_dir |
|
self.rgb_dir = output_dir / "rgb" |
|
self.segmentation_dir = output_dir / "segmentation" |
|
self.depth_dir = output_dir / "depth" |
|
self.pointclouds_dir = output_dir / "pointclouds" |
|
self.examples = self._load_examples() |
|
|
|
def _load_examples(self) -> Dict[str, Dict]: |
|
"""Load all available precomputed examples from output structure.""" |
|
examples = {} |
|
|
|
if not self.output_dir.exists(): |
|
print(f"Output directory {self.output_dir} not found.") |
|
return {} |
|
|
|
|
|
if not self.rgb_dir.exists(): |
|
print(f"RGB directory {self.rgb_dir} not found.") |
|
return {} |
|
|
|
|
|
timestamps = set() |
|
for rgb_file in self.rgb_dir.glob("rgb_*.png"): |
|
|
|
filename = rgb_file.stem |
|
if filename.startswith("rgb_"): |
|
timestamp = filename.replace("rgb_", "") |
|
timestamps.add(timestamp) |
|
|
|
print(f"Found {len(timestamps)} RGB input images") |
|
|
|
|
|
for timestamp in sorted(timestamps, reverse=True): |
|
example_data = self._load_single_example(timestamp) |
|
if example_data: |
|
examples[timestamp] = example_data |
|
|
|
print(f"Loaded {len(examples)} precomputed examples from output directory") |
|
return examples |
|
|
|
def _load_single_example(self, timestamp: str) -> Optional[Dict]: |
|
"""Load a single precomputed example by timestamp.""" |
|
try: |
|
|
|
rgb_path = self.rgb_dir / f"rgb_{timestamp}.png" |
|
|
|
|
|
seg_path = self.segmentation_dir / f"segmentation_{timestamp}.png" |
|
depth_path = self.depth_dir / f"depth_{timestamp}.png" |
|
ply_path = self.pointclouds_dir / f"pointcloud_{timestamp}.ply" |
|
html_path = self.pointclouds_dir / f"pointcloud_{timestamp}.html" |
|
|
|
|
|
if not rgb_path.exists(): |
|
print(f"RGB input file missing for timestamp {timestamp}: {rgb_path}") |
|
return None |
|
|
|
|
|
if not seg_path.exists(): |
|
print(f"Segmentation output missing for timestamp {timestamp}: {seg_path}") |
|
return None |
|
|
|
|
|
            # Timestamps follow the filename pattern YYYYMMDD_HHMMSS; fall back
            # to the raw string if parsing fails.
            try:
                dt = datetime.datetime.strptime(timestamp[:15], "%Y%m%d_%H%M%S")
                display_name = f"{dt:%b} {dt.day}, {dt.year} {dt:%H:%M}"
            except ValueError:
                display_name = timestamp
|
|
|
return { |
|
'name': display_name, |
|
'timestamp': timestamp, |
|
'rgb_path': rgb_path, |
|
'segmentation_path': seg_path, |
|
'depth_path': depth_path if depth_path.exists() else None, |
|
'pointcloud_ply_path': ply_path if ply_path.exists() else None, |
|
'pointcloud_html_path': html_path if html_path.exists() else None, |
|
'preview_image': self._create_preview_image(rgb_path, timestamp) |
|
} |
|
|
|
except Exception as e: |
|
print(f"Error loading example {timestamp}: {e}") |
|
return None |
|
|
|
def _create_preview_image(self, rgb_path: Path, timestamp: str) -> Image.Image: |
|
"""Create a preview thumbnail from RGB input image.""" |
|
try: |
|
image = Image.open(rgb_path) |
|
image.thumbnail((600, 450), Image.Resampling.LANCZOS) |
|
return image |
|
|
|
except Exception as e: |
|
print(f"Error creating preview for {timestamp}: {e}") |
|
return Image.new('RGB', (200, 150), color=(128, 128, 128)) |
|
|
|
def get_example_names(self) -> List[str]: |
|
"""Get list of available example names.""" |
|
return [data['name'] for data in self.examples.values()] |
|
|
|
def get_example_previews(self) -> List[Tuple[Image.Image, str]]: |
|
"""Get preview images for all examples.""" |
|
previews = [] |
|
for timestamp, data in self.examples.items(): |
|
previews.append((data['preview_image'], data['name'])) |
|
return previews |
|
|
|
def get_timestamp_by_name(self, name: str) -> Optional[str]: |
|
"""Get timestamp by display name.""" |
|
for timestamp, data in self.examples.items(): |
|
if data['name'] == name: |
|
return timestamp |
|
return None |
|
|
|
def load_example_results(self, example_name: str) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[go.Figure], str]: |
|
"""Load precomputed results for an example.""" |
|
if not example_name: |
|
return None, None, None, "Please select an example." |
|
|
|
|
|
timestamp = self.get_timestamp_by_name(example_name) |
|
if not timestamp or timestamp not in self.examples: |
|
return None, None, None, f"Example '{example_name}' not found." |
|
|
|
example_data = self.examples[timestamp] |
|
|
|
try: |
|
|
|
segmentation_image = Image.open(example_data['segmentation_path']) |
|
|
|
depth_image = None |
|
if example_data['depth_path'] and example_data['depth_path'].exists(): |
|
depth_image = Image.open(example_data['depth_path']) |
|
|
|
|
|
point_cloud_fig = None |
|
if example_data['pointcloud_ply_path'] and example_data['pointcloud_ply_path'].exists(): |
|
try: |
|
pcd = o3d.io.read_point_cloud(str(example_data['pointcloud_ply_path'])) |
|
if len(pcd.points) > 0: |
|
point_cloud_fig = create_plotly_pointcloud(pcd, downsample_factor=1) |
|
else: |
|
print(f"Point cloud file {example_data['pointcloud_ply_path']} is empty") |
|
except Exception as e: |
|
print(f"Error loading point cloud: {e}") |
|
|
|
return segmentation_image, depth_image, point_cloud_fig, "" |
|
|
|
except Exception as e: |
|
return None, None, None, f"Error loading example results: {str(e)}" |
|
|
|
|
|
|
|
|
|
|
|
class EnhancedSingleViewApp: |
|
def __init__(self): |
|
|
|
self.oneformer_config = ModelConfig( |
|
model_name="shi-labs/oneformer_cityscapes_swin_large", |
|
processor_name="shi-labs/oneformer_cityscapes_swin_large", |
|
task_type="semantic" |
|
) |
|
|
|
self.depth_config = DepthConfig( |
|
encoder="vitl", |
|
dataset="vkitti", |
|
max_depth=80, |
|
weights_path="depth_anything_v2_metric_vkitti_vitl.pth" |
|
) |
|
|
|
|
|
self.oneformer_model = None |
|
self.depth_model = None |
|
self.segmentation_loaded = False |
|
self.depth_loaded = False |
|
|
|
|
|
self.examples_manager = PrecomputedExamplesManager(OUTPUT_DIR) |
|
|
|
|
|
self.sample_images = { |
|
"Street Scene 1": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=800", |
|
"Street Scene 2": "https://images.unsplash.com/photo-1502920917128-1aa500764cbd?w=800", |
|
"Urban Road": "https://images.unsplash.com/photo-1516738901171-8eb4fc13bd20?w=800", |
|
"City View": "https://images.unsplash.com/photo-1477959858617-67f85cf4f1df?w=800", |
|
"Highway": "https://images.unsplash.com/photo-1544620347-c4fd4a3d5957?w=800", |
|
} |
|
|
|
def download_sample_image(self, image_url: str) -> Image.Image: |
|
"""Download a sample image from URL.""" |
|
try: |
|
response = requests.get(image_url, timeout=10) |
|
response.raise_for_status() |
|
return Image.open(io.BytesIO(response.content)).convert('RGB') |
|
except Exception as e: |
|
print(f"Error downloading image: {e}") |
|
return Image.new('RGB', (800, 600), color=(128, 128, 128)) |
|
|
|
    def create_overlay_plot(self, rgb_image: Image.Image, semantic_map: np.ndarray, alpha: float = 0.5):
        """Create segmentation overlay plot; delegates to the module-level helper of the same name."""
        return create_overlay_plot(rgb_image, semantic_map, alpha)
|
|
|
def process_complete_pipeline(self, image: Image.Image): |
|
"""Process image through complete pipeline: segmentation + depth + point cloud.""" |
|
if image is None: |
|
return None, None, None, "Please upload an image." |
|
|
|
|
|
overlay_alpha = 0.5 |
|
depth_colormap = "magma" |
|
downsample_factor = 0.1 |
|
|
|
try: |
|
|
|
if not self.segmentation_loaded: |
|
if self.oneformer_model is None: |
|
self.oneformer_model = OneFormerModel(self.oneformer_config) |
|
self.oneformer_model.load_model() |
|
self.segmentation_loaded = True |
|
|
|
if not self.depth_loaded and DEPTH_AVAILABLE: |
|
if self.depth_model is None: |
|
self.depth_model = DepthAnythingV2Model(self.depth_config) |
|
self.depth_model.load_model() |
|
self.depth_loaded = True |
|
|
|
|
|
original_size = image.size |
|
if max(image.size) > 1024: |
|
image.thumbnail((1024, 1024), Image.Resampling.LANCZOS) |
|
|
|
|
|
task_inputs = ["semantic"] |
|
semantic_map = self.oneformer_model.segment_image(image, task_inputs=task_inputs) |
|
segmentation_overlay = self.create_overlay_plot(image, semantic_map, overlay_alpha) |
|
|
|
|
|
depth_vis = None |
|
point_cloud_fig = None |
|
pcd = None |
|
|
|
if DEPTH_AVAILABLE and self.depth_loaded: |
|
depth_map = self.depth_model.estimate_depth(image) |
|
depth_vis = create_depth_visualization(depth_map, depth_colormap) |
|
|
|
|
|
pcd = depth_to_point_cloud_with_segmentation(depth_map, image, semantic_map) |
|
point_cloud_fig = create_plotly_pointcloud(pcd, downsample_factor) |
|
|
|
|
|
unique_classes = np.unique(semantic_map) |
|
class_info = [] |
|
total_pixels = semantic_map.size |
|
|
|
            # semantic_map holds Cityscapes trainIds, so look labels up by
            # trainId rather than by position in the labels list.
            trainid_to_label = {label.trainId: label for label in labels if label.trainId != 255}
            for class_id in unique_classes:
                label = trainid_to_label.get(int(class_id))
                if label is None:
                    continue
                pixel_count = np.sum(semantic_map == class_id)
                percentage = (pixel_count / total_pixels) * 100
                if percentage > 0.1:
                    class_info.append(f"- {label.name}: {percentage:.1f}%")
|
|
|
|
|
if point_cloud_fig is not None: |
|
num_points = len(pcd.points) |
|
downsampled_points = int(num_points * downsample_factor) |
|
point_cloud_info = f""" |
|
3D Point Cloud: |
|
- Total points: {num_points:,} |
|
- Displayed points: {downsampled_points:,} ({downsample_factor*100:.0f}%) |
|
- Sky points excluded |
|
- Colors match segmentation classes""" |
|
else: |
|
point_cloud_info = "Point cloud not available (DepthAnythingV2 required)" |
|
|
|
|
|
if depth_vis is not None and DEPTH_AVAILABLE: |
|
depth_stats = { |
|
'min': np.min(depth_map), |
|
'max': np.max(depth_map), |
|
'mean': np.mean(depth_map), |
|
'std': np.std(depth_map) |
|
} |
|
depth_info = f""" |
|
Depth Estimation: |
|
- Min depth: {depth_stats['min']:.2f}m |
|
- Max depth: {depth_stats['max']:.2f}m |
|
- Mean depth: {depth_stats['mean']:.2f}m |
|
- Std deviation: {depth_stats['std']:.2f}m |
|
- Colormap: {depth_colormap}""" |
|
else: |
|
depth_info = "Depth estimation not available" |
|
|
|
info_text = f"""Complete vision pipeline processed successfully! |
|
|
|
Models Used: |
|
- OneFormer (Semantic Segmentation) |
|
{f"- DepthAnythingV2 ({self.depth_config.encoder.upper()})" if DEPTH_AVAILABLE else "- DepthAnythingV2 (Not Available)"} |
|
|
|
Image Processing: |
|
- Original size: {original_size[0]}x{original_size[1]} |
|
- Processed size: {image.size[0]}x{image.size[1]} |
|
- Overlay transparency: {overlay_alpha:.1f} |
|
|
|
Detected Classes: |
|
{chr(10).join(class_info)} |
|
{depth_info} |
|
{point_cloud_info} |
|
|
|
The point cloud shows 3D structure with each point colored according to its segmentation class. Sky points are excluded for better visualization.""" |
|
|
|
return segmentation_overlay, depth_vis, point_cloud_fig, info_text |
|
|
|
except Exception as e: |
|
return None, None, None, f"Error processing pipeline: {str(e)}" |
|
|
|
|
|
app = EnhancedSingleViewApp() |
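
# Hedged usage sketch (hypothetical path; not wired into the UI): the pipeline
# can also be driven headlessly, e.g. to batch-precompute the outputs/ folders
# that PrecomputedExamplesManager reads.
def run_pipeline_headless(image_path: str):
    """Run the full pipeline on one image file and return its four outputs."""
    image = Image.open(image_path).convert('RGB')
    return app.process_complete_pipeline(image)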
|
|
|
def process_uploaded_image(image):
    """Process an uploaded image; the UI wires three outputs, so the info text is dropped."""
    try:
        seg_image, depth_image, pc_fig, _info = app.process_complete_pipeline(image)
        return seg_image, depth_image, pc_fig
    except Exception as e:
        print(f"Error processing uploaded image: {e}")
        return None, None, None
|
|
|
def process_sample_image(sample_choice): |
|
"""Process sample image through complete pipeline.""" |
|
if sample_choice and sample_choice in app.sample_images: |
|
image_url = app.sample_images[sample_choice] |
|
image = app.download_sample_image(image_url) |
|
return app.process_complete_pipeline(image) |
|
return None, None, None, "Please select a sample image." |
|
|
|
def load_precomputed_example(evt: gr.SelectData): |
|
"""Load precomputed example results from gallery selection.""" |
|
if evt.index is not None: |
|
example_names = app.examples_manager.get_example_names() |
|
if evt.index < len(example_names): |
|
example_name = example_names[evt.index] |
|
seg_image, depth_image, pc_fig, info_text = app.examples_manager.load_example_results(example_name) |
|
return seg_image, depth_image, pc_fig |
|
return None, None, None |
|
|
|
def get_example_previews(): |
|
"""Get preview images for the gallery.""" |
|
previews = app.examples_manager.get_example_previews() |
|
if not previews: |
|
return [] |
|
return previews |
|
|
|
|
|
|
|
|
|
|
|
def create_gradio_interface(): |
|
"""Create and return the enhanced single-view Gradio interface.""" |
|
|
|
with gr.Blocks( |
|
title="Enhanced Computer Vision Pipeline", |
|
theme=gr.themes.Default() |
|
) as demo: |
|
|
|
gr.Markdown(""" |
|
# Street Scene 3D Reconstruction |
|
|
|
Upload an image or select an example to see: |
|
- **Semantic Segmentation** - Identify roads, buildings, vehicles, people, and other scene elements |
|
- **Depth Estimation** - Generate metric depth maps showing distance to objects |
|
    - **3D Point Cloud** - Interactive 3D reconstruction with semantic colors
|
""") |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=1): |
|
if CUDA_AVAILABLE: |
|
gr.Markdown("### Upload Image") |
|
|
|
uploaded_image = gr.Image( |
|
type="pil", |
|
label="Upload Image" |
|
) |
|
upload_btn = gr.Button("Process Image", variant="primary", size="lg") |
|
else: |
|
uploaded_image = gr.Image(visible=False) |
|
upload_btn = gr.Button(visible=False) |
|
|
|
gr.Markdown("### CPU Mode") |
|
gr.Markdown("⚠️ **Upload disabled**: DepthAnythingV2 requires CUDA. Using precomputed examples only.") |
|
|
|
gr.Markdown("### Examples") |
|
gr.Markdown("Click on an image to load the example:") |
|
|
|
|
|
example_gallery = gr.Gallery( |
|
value=get_example_previews(), |
|
label="Example Images", |
|
show_label=False, |
|
elem_id="example_gallery", |
|
columns=2, |
|
rows=3, |
|
height="auto", |
|
object_fit="cover" |
|
) |
|
|
|
|
|
with gr.Column(scale=2): |
|
gr.Markdown("### Results") |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("#### Semantic Segmentation") |
|
segmentation_output = gr.Image(label="Segmentation Overlay") |
|
|
|
with gr.Column(): |
|
gr.Markdown("#### Depth Estimation") |
|
depth_output = gr.Image(label="Depth Map") |
|
|
|
|
|
gr.Markdown("#### 3D Point Cloud") |
|
pointcloud_output = gr.Plot(label="Interactive 3D Point Cloud (Colored by Segmentation)") |
|
|
|
if CUDA_AVAILABLE: |
|
upload_btn.click( |
|
fn=process_uploaded_image, |
|
inputs=[uploaded_image], |
|
outputs=[segmentation_output, depth_output, pointcloud_output] |
|
) |
|
|
|
|
|
example_gallery.select( |
|
fn=load_precomputed_example, |
|
outputs=[segmentation_output, depth_output, pointcloud_output] |
|
) |
|
|
|
return demo |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
fix_lfs_on_startup() |
|
|
|
demo = create_gradio_interface() |
|
|
|
print("Starting Enhanced Single-View Computer Vision App...") |
|
print("Complete Pipeline: Segmentation + Depth + 3D Point Cloud") |
|
print("Device:", "CUDA" if torch.cuda.is_available() else "CPU") |
|
print("Depth Available:", "YES" if DEPTH_AVAILABLE else "NO") |
|
print("Point Cloud Colors: Segmentation-based (Sky Excluded)") |
|
print(f"Output Directory: {OUTPUT_DIR.absolute()}") |
|
print(f"Available Examples: {len(app.examples_manager.examples)}") |
|
|
|
|
|
demo.launch( |
|
share=True, |
|
debug=True, |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
show_error=True, |
|
quiet=False |
|
) |