diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2e95aae603a79908ef5e0cca65da45d41c2bc975
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*.pth
+.venv/
+outputs/pointclouds/*.html
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..574d4fcc7348b46e667d92fc1ecdbe6505a01947
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,64 @@
+# Use NVIDIA CUDA base image for GPU support
+FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
+
+# Set environment variables
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX"
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+ python3 \
+ python3-pip \
+ python3-dev \
+ git \
+ wget \
+ curl \
+ build-essential \
+ cmake \
+ libgl1-mesa-glx \
+ libglib2.0-0 \
+ libsm6 \
+ libxext6 \
+ libxrender-dev \
+ libgomp1 \
+ libgcc-s1 \
+ && rm -rf /var/lib/apt/lists/*
+
+# Create working directory
+WORKDIR /app
+
+# Copy requirements first (for better Docker layer caching)
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Copy the application code
+COPY . .
+
+# Set up DepthAnythingV2
+#WORKDIR /app/Depth-Anything-V2
+#RUN pip3 install -e .
+#WORKDIR /app
+
+# Create directories for models and cache
+RUN mkdir -p /app/models /root/.cache
+
+# Download DepthAnythingV2 weights (you can add this step or mount as volume)
+# Uncomment the line below if you want to download weights during build
+# RUN wget -O depth_anything_v2_metric_vkitti_vitl.pth https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Large/resolve/main/depth_anything_v2_metric_vkitti_vitl.pth
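+#
+# Alternatively (a sketch, assuming the built image is tagged "segmentation-app" and the
+# weights file sits next to the Dockerfile), mount the weights at run time instead of baking them in:
+#   docker run --gpus all -p 7860:7860 \
+#     -v $(pwd)/depth_anything_v2_metric_vkitti_vitl.pth:/app/depth_anything_v2_metric_vkitti_vitl.pth \
+#     segmentation-app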
+
+# Expose the port
+EXPOSE 7860
+
+# Set the entry point
+CMD ["python3", "enhanced_app.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index fcc9d862d8f1c2079d1b5e5ad361953cf343e1c7..749e40d7ff802d807e5640c35c586b8e952d18a2 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 5.36.2
-app_file: app.py
+app_file: enhanced_app.py
pinned: false
license: apache-2.0
short_description: 'Road Scene Sensing: Semantic Segmentation & Depth Estimation'
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..79c29b328c2f0ec0714096357f7a3ae09ef38f48
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,23 @@
+version: '3.8'
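+
+# One way to run this stack (assuming Docker Compose v2 and the NVIDIA Container
+# Toolkit are installed on the host):
+#   docker compose up --build
+# The Gradio UI is then served on http://localhost:7861 (host port 7861 maps to container port 7860).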
+
+services:
+ segmentation-app:
+ build: .
+ ports:
+ - "7861:7860"
+ volumes:
+ - models_cache:/root/.cache
+ - ./outputs:/app/outputs # Map host outputs directory to container outputs
+ environment:
+ - NVIDIA_VISIBLE_DEVICES=all
+ - NVIDIA_DRIVER_CAPABILITIES=compute,utility
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [gpu]
+
+volumes:
+ models_cache:
\ No newline at end of file
diff --git a/enhanced_app.py b/enhanced_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d556d4ba7e0dc2f96ddb456463aad44db0af2d9
--- /dev/null
+++ b/enhanced_app.py
@@ -0,0 +1,960 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Enhanced Single-View Gradio App for Semantic Segmentation, Depth Estimation, and 3D Point Cloud
+Processes one image and shows all outputs: segmentation, depth, and colored point cloud
+Now with precomputed examples for demonstration
+"""
+
+import sys
+import locale
+import os
+import datetime
+from pathlib import Path
+
+# Set UTF-8 encoding
+if sys.version_info >= (3, 7):
+ sys.stdout.reconfigure(encoding='utf-8')
+ sys.stderr.reconfigure(encoding='utf-8')
+
+# Set locale for proper Unicode support
+try:
+ locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')
+except locale.Error:
+ try:
+ locale.setlocale(locale.LC_ALL, 'C.UTF-8')
+ except locale.Error:
+ pass # Use system default
+
+import gradio as gr
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+import io
+import base64
+from dataclasses import dataclass
+from typing import Optional, List, Tuple, Dict, Any
+import requests
+import cv2
+from abc import ABC, abstractmethod
+from collections import namedtuple
+import plotly.graph_objects as go
+import plotly.io as pio
+import open3d as o3d
+import json
+
+# Disable xFormers before importing DepthAnythingV2 so that any import-time
+# xFormers checks in its DINOv2 layers see these variables.
+os.environ['XFORMERS_DISABLED'] = '1'
+os.environ['XFORMERS_MORE_DETAILS'] = '1'
+
+# Import DepthAnythingV2 (assuming it's in the same directory or installed)
+try:
+    from metric_depth.depth_anything_v2.dpt import DepthAnythingV2
+    DEPTH_AVAILABLE = True
+except ImportError:
+    print("DepthAnythingV2 not available. Using precomputed examples only.")
+    DEPTH_AVAILABLE = False
+
+# Output directory structure (mounted volume)
+OUTPUT_DIR = Path("outputs")
+
+# =============================================================================
+# Model Base Classes and Configurations
+# =============================================================================
+
+@dataclass
+class ModelConfig:
+ """Configuration for segmentation models."""
+ model_name: str
+ processor_name: str
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
+ trust_remote_code: bool = True
+ task_type: str = "semantic"
+
+@dataclass
+class DepthConfig:
+ """Configuration for depth estimation models."""
+ encoder: str = "vitl" # 'vits', 'vitb', 'vitl'
+ dataset: str = "vkitti" # 'hypersim' for indoor, 'vkitti' for outdoor
+ max_depth: int = 80 # 20 for indoor, 80 for outdoor
+ weights_path: str = "depth_anything_v2_metric_vkitti_vitl.pth"
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+class BaseSegmentationModel(ABC):
+ """Abstract base class for segmentation models."""
+
+ def __init__(self, model_config):
+ self.config = model_config
+ self.model = None
+ self.processor = None
+ self.device = torch.device(model_config.device if torch.cuda.is_available() else "cpu")
+
+ @abstractmethod
+ def load_model(self):
+ """Load the model and processor."""
+ pass
+
+ @abstractmethod
+ def preprocess(self, image: Image.Image, **kwargs) -> Dict[str, torch.Tensor]:
+ """Preprocess the input image."""
+ pass
+
+ @abstractmethod
+ def predict(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Run inference on preprocessed inputs."""
+ pass
+
+ @abstractmethod
+ def postprocess(self, outputs: Dict[str, torch.Tensor], target_size: Tuple[int, int]) -> np.ndarray:
+ """Postprocess model outputs to segmentation map."""
+ pass
+
+ def segment_image(self, image: Image.Image, **kwargs) -> np.ndarray:
+ """End-to-end segmentation pipeline."""
+ if self.model is None:
+ self.load_model()
+
+ inputs = self.preprocess(image, **kwargs)
+ outputs = self.predict(inputs)
+ segmentation_map = self.postprocess(outputs, image.size[::-1])
+
+ return segmentation_map
+
+# =============================================================================
+# OneFormer Model Implementation
+# =============================================================================
+
+class OneFormerModel(BaseSegmentationModel):
+ """OneFormer model for universal segmentation."""
+
+ def __init__(self, model_config):
+ super().__init__(model_config)
+
+ def load_model(self):
+ """Load OneFormer model and processor."""
+ print(f"Loading OneFormer model: {self.config.model_name}")
+
+ try:
+ from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
+
+ self.processor = OneFormerProcessor.from_pretrained(
+ self.config.processor_name,
+ trust_remote_code=self.config.trust_remote_code
+ )
+
+ self.model = OneFormerForUniversalSegmentation.from_pretrained(
+ self.config.model_name,
+ trust_remote_code=self.config.trust_remote_code
+ )
+
+ self.model.to(self.device)
+ self.model.eval()
+
+ print(f"OneFormer model loaded successfully on {self.device}")
+
+ except Exception as e:
+ print(f"Error loading OneFormer model: {e}")
+ raise
+
+ def preprocess(self, image: Image.Image, task_inputs: List[str] = None) -> Dict[str, torch.Tensor]:
+ """Preprocess image for OneFormer."""
+ if task_inputs is None:
+ task_inputs = [self.config.task_type]
+
+ inputs = self.processor(
+ images=image,
+ task_inputs=task_inputs,
+ return_tensors="pt"
+ )
+
+ # Move inputs to device
+ inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+ for k, v in inputs.items()}
+
+ return inputs
+
+ def predict(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Run inference with OneFormer."""
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ return outputs
+
+ def postprocess(self, outputs: Dict[str, torch.Tensor], target_size: Tuple[int, int]) -> np.ndarray:
+ """Postprocess OneFormer outputs."""
+ predicted_semantic_map = self.processor.post_process_semantic_segmentation(
+ outputs,
+ target_sizes=[target_size]
+ )[0]
+
+ return predicted_semantic_map.cpu().numpy()
+
+# =============================================================================
+# DepthAnythingV2 Model Implementation
+# =============================================================================
+
+class DepthAnythingV2Model:
+ """DepthAnythingV2 model for depth estimation."""
+
+ def __init__(self, depth_config: DepthConfig):
+ self.config = depth_config
+ self.model = None
+ self.device = torch.device(depth_config.device if torch.cuda.is_available() else "cpu")
+
+ def load_model(self):
+ """Load DepthAnythingV2 model."""
+ if not DEPTH_AVAILABLE:
+ raise ImportError("DepthAnythingV2 is not available")
+
+ print(f"Loading DepthAnythingV2 model: {self.config.encoder}")
+
+ try:
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
+ }
+
+ self.model = DepthAnythingV2(**{**model_configs[self.config.encoder], 'max_depth': self.config.max_depth})
+
+ # Load weights
+ if os.path.exists(self.config.weights_path):
+ self.model.load_state_dict(torch.load(self.config.weights_path, map_location='cpu'))
+ print(f"Loaded weights from {self.config.weights_path}")
+ else:
+ print(f"Warning: Weights file {self.config.weights_path} not found")
+
+ self.model.to(self.device)
+ self.model.eval()
+
+ print(f"DepthAnythingV2 model loaded successfully on {self.device}")
+
+ except Exception as e:
+ print(f"Error loading DepthAnythingV2 model: {e}")
+ raise
+
+ def estimate_depth(self, image: Image.Image) -> np.ndarray:
+ """Estimate depth from image."""
+ if self.model is None:
+ self.load_model()
+
+ # Convert PIL to OpenCV format
+ img_array = np.array(image)
+ if len(img_array.shape) == 3:
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+
+ # Infer depth
+ depth_map = self.model.infer_image(img_array)
+
+ return depth_map
+
+# =============================================================================
+# Cityscapes Label Definitions
+# =============================================================================
+
+Label = namedtuple('Label', [
+ 'name', 'id', 'trainId', 'category', 'categoryId',
+ 'hasInstances', 'ignoreInEval', 'color'
+])
+
+labels = [
+ Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
+ Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
+ Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
+ Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
+ Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
+ Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
+ Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
+ Label('road', 7, 0, 'flat', 1, False, False, (128, 64,128)),
+ Label('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35,232)),
+ Label('parking', 9, 255, 'flat', 1, False, True, (250,170,160)),
+ Label('rail track', 10, 255, 'flat', 1, False, True, (230,150,140)),
+ Label('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
+ Label('wall', 12, 3, 'construction', 2, False, False, (102,102,156)),
+ Label('fence', 13, 4, 'construction', 2, False, False, (190,153,153)),
+ Label('guard rail', 14, 255, 'construction', 2, False, True, (180,165,180)),
+ Label('bridge', 15, 255, 'construction', 2, False, True, (150,100,100)),
+ Label('tunnel', 16, 255, 'construction', 2, False, True, (150,120, 90)),
+ Label('pole', 17, 5, 'object', 3, False, False, (153,153,153)),
+ Label('polegroup', 18, 255, 'object', 3, False, True, (153,153,153)),
+ Label('traffic light', 19, 6, 'object', 3, False, False, (250,170, 30)),
+ Label('traffic sign', 20, 7, 'object', 3, False, False, (220,220, 0)),
+ Label('vegetation', 21, 8, 'nature', 4, False, False, (107,142, 35)),
+ Label('terrain', 22, 9, 'nature', 4, False, False, (152,251,152)),
+ Label('sky', 23, 10, 'sky', 5, False, False, (70,130,180)),
+ Label('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
+ Label('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
+ Label('car', 26, 13, 'vehicle', 7, True, False, (0, 0,142)),
+ Label('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
+ Label('bus', 28, 15, 'vehicle', 7, True, False, (0, 60,100)),
+ Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
+ Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0,110)),
+ Label('train', 31, 16, 'vehicle', 7, True, False, (0, 80,100)),
+ Label('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0,230)),
+ Label('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
+ Label('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0,142)),
+]
+
+# Sky trainId is 10
+SKY_TRAIN_ID = 10
+
+# =============================================================================
+# Utility Functions
+# =============================================================================
+
+def get_color_map(labels):
+ """Returns a color map dictionary for the given labels."""
+ color_map = {label.trainId: label.color for label in labels if label.trainId != 255}
+ return color_map
+
+def apply_color_map(semantic_map, color_map):
+ """Applies a color map to a semantic map."""
+ height, width = semantic_map.shape
+ color_mapped_image = np.zeros((height, width, 3), dtype=np.uint8)
+
+ for trainId, color in color_map.items():
+ mask = semantic_map == trainId
+ color_mapped_image[mask] = color
+
+ return color_mapped_image
+
+def create_depth_visualization(depth_map: np.ndarray, colormap: str = 'magma') -> Image.Image:
+ """Create a colored depth map visualization with exact dimensions."""
+ # Normalize depth map to [0, 1]
+ normalized_depth = depth_map / np.max(depth_map)
+
+ # Apply colormap
+ cmap = plt.get_cmap(colormap)
+ colored_depth = cmap(normalized_depth)
+
+ # Convert to 8-bit RGB (remove alpha channel)
+ colored_depth_8bit = (colored_depth[:, :, :3] * 255).astype(np.uint8)
+
+ return Image.fromarray(colored_depth_8bit)
+
+def depth_to_point_cloud_with_segmentation(depth_map: np.ndarray, rgb_image: Image.Image,
+ semantic_map: np.ndarray,
+ fx: float = 525.0, fy: float = 525.0,
+ cx: float = None, cy: float = None) -> o3d.geometry.PointCloud:
+ """Convert depth map and RGB image to 3D point cloud with segmentation colors, excluding sky."""
+ height, width = depth_map.shape
+
+ if cx is None:
+ cx = width / 2.0
+ if cy is None:
+ cy = height / 2.0
+
+ # Create coordinate matrices
+ u, v = np.meshgrid(np.arange(width), np.arange(height))
+
+ # Convert to 3D coordinates
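+    # Pinhole back-projection: x = (u - cx) * z / fx, y = (v - cy) * z / fy, with z the metric depth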
+ z = depth_map
+ x = (u - cx) * z / fx
+ y = (v - cy) * z / fy
+
+ # Stack coordinates
+ points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
+
+ # Create mask to exclude sky points and invalid depths
+ flat_semantic = semantic_map.flatten()
+ flat_depth = z.flatten()
+
+ # Filter out invalid points and sky points
+ valid_mask = (flat_depth > 0) & (flat_depth < 1000) & (flat_semantic != SKY_TRAIN_ID)
+ points = points[valid_mask]
+
+ # Get segmentation colors for each point
+ color_map = get_color_map(labels)
+ seg_colors = np.zeros((len(flat_semantic), 3))
+
+ for trainId, color in color_map.items():
+ mask = flat_semantic == trainId
+ seg_colors[mask] = color
+
+ # Filter colors to match valid points
+ colors = seg_colors[valid_mask] / 255.0 # Normalize to [0, 1]
+
+ # Create Open3D point cloud
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+
+ return pcd
+
+def create_plotly_pointcloud(pcd: o3d.geometry.PointCloud, downsample_factor: float = 0.1) -> go.Figure:
+ """Create interactive Plotly 3D point cloud visualization."""
+ # Downsample for performance
+ if downsample_factor < 1.0:
+ num_points = len(pcd.points)
+ indices = np.random.choice(num_points, int(num_points * downsample_factor), replace=False)
+ points = np.asarray(pcd.points)[indices]
+ colors = np.asarray(pcd.colors)[indices]
+ else:
+ points = np.asarray(pcd.points)
+ colors = np.asarray(pcd.colors)
+
+ # Create 3D scatter plot
+ fig = go.Figure(data=[go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode='markers',
+ marker=dict(
+ size=1,
+ color=colors,
+ opacity=0.8
+ ),
+ text=[f'Point {i}' for i in range(len(points))],
+        hovertemplate='X: %{x:.2f}<br>Y: %{y:.2f}<br>Z: %{z:.2f}'
+ )])
+
+ # Update layout for centered display
+ fig.update_layout(
+ scene=dict(
+ xaxis_title='X (Horizontal)',
+ yaxis_title='Y (Vertical)',
+ zaxis_title='Z (Depth)',
+ aspectmode='data'
+ ),
+ title={
+ 'text': 'Interactive 3D Point Cloud (Colored by Segmentation, Sky Excluded)',
+ 'x': 0.5,
+ 'xanchor': 'center'
+ },
+ width=None, # Let it auto-size to container
+ height=600,
+ margin=dict(l=0, r=0, t=40, b=0), # Minimal margins
+ autosize=True # Enable auto-sizing to container
+ )
+
+ # Set camera for bird's eye view that clearly shows 3D structure
+ fig.update_layout(scene_camera=dict(
+ up=dict(x=0, y=0, z=1), # Z-axis points up
+ center=dict(x=0, y=0, z=0), # Center at origin
+ eye=dict(x=0.5, y=-2.5, z=1.5) # View from above-back position
+ ))
+
+ return fig
+
+def create_overlay_plot(rgb_image: Image.Image, semantic_map: np.ndarray, alpha: float = 0.5):
+ """Create segmentation overlay plot without title and borders."""
+ rgb_array = np.array(rgb_image)
+ color_map = get_color_map(labels)
+ colored_semantic_map = apply_color_map(semantic_map, color_map)
+
+ # Create figure with exact image dimensions
+ height, width = rgb_array.shape[:2]
+ dpi = 100
+ fig, ax = plt.subplots(1, 1, figsize=(width/dpi, height/dpi), dpi=dpi)
+
+ # Remove all margins and padding
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
+
+ ax.imshow(rgb_array)
+ ax.imshow(colored_semantic_map, alpha=alpha)
+ ax.axis('off')
+
+ buf = io.BytesIO()
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
+ buf.seek(0)
+ plt.close(fig)
+
+ return Image.open(buf)
+
+class PrecomputedExamplesManager:
+ """Manages precomputed examples from output folder structure."""
+
+ def __init__(self, output_dir: Path):
+ self.output_dir = output_dir
+ self.rgb_dir = output_dir / "rgb"
+ self.segmentation_dir = output_dir / "segmentation"
+ self.depth_dir = output_dir / "depth"
+ self.pointclouds_dir = output_dir / "pointclouds"
+ self.examples = self._load_examples()
+
+ def _load_examples(self) -> Dict[str, Dict]:
+ """Load all available precomputed examples from output structure."""
+ examples = {}
+
+ if not self.output_dir.exists():
+ print(f"Output directory {self.output_dir} not found.")
+ return {}
+
+ # Find all timestamps by looking at RGB files (the inputs)
+ if not self.rgb_dir.exists():
+ print(f"RGB directory {self.rgb_dir} not found.")
+ return {}
+
+ # Get all RGB files and extract timestamps
+ timestamps = set()
+ for rgb_file in self.rgb_dir.glob("rgb_*.png"):
+ # Extract timestamp from filename like "rgb_20241215_143022.png"
+ filename = rgb_file.stem
+ if filename.startswith("rgb_"):
+ timestamp = filename.replace("rgb_", "")
+ timestamps.add(timestamp)
+
+ print(f"Found {len(timestamps)} RGB input images")
+
+ # For each timestamp, try to load the complete example
+ for timestamp in sorted(timestamps, reverse=True): # Most recent first
+ example_data = self._load_single_example(timestamp)
+ if example_data:
+ examples[timestamp] = example_data
+
+ print(f"Loaded {len(examples)} precomputed examples from output directory")
+ return examples
+
+ def _load_single_example(self, timestamp: str) -> Optional[Dict]:
+ """Load a single precomputed example by timestamp."""
+ try:
+ # Input file (required)
+ rgb_path = self.rgb_dir / f"rgb_{timestamp}.png"
+
+ # Output files (some may be optional)
+ seg_path = self.segmentation_dir / f"segmentation_{timestamp}.png"
+ depth_path = self.depth_dir / f"depth_{timestamp}.png"
+ ply_path = self.pointclouds_dir / f"pointcloud_{timestamp}.ply"
+ html_path = self.pointclouds_dir / f"pointcloud_{timestamp}.html"
+
+ # Check if RGB input exists (required)
+ if not rgb_path.exists():
+ print(f"RGB input file missing for timestamp {timestamp}: {rgb_path}")
+ return None
+
+ # Check if at least segmentation output exists
+ if not seg_path.exists():
+ print(f"Segmentation output missing for timestamp {timestamp}: {seg_path}")
+ return None
+
+ # Create a display name from timestamp
+ try:
+ # Parse timestamp like "20241215_143022"
+ if len(timestamp) >= 13 and "_" in timestamp:
+ date_part = timestamp[:8]
+ time_part = timestamp[9:15]
+ # Format as "Dec 15, 2024 14:30"
+ year = date_part[:4]
+ month = date_part[4:6]
+ day = date_part[6:8]
+ hour = time_part[:2]
+ minute = time_part[2:4]
+
+ month_names = ["", "Jan", "Feb", "Mar", "Apr", "May", "Jun",
+ "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
+ month_name = month_names[int(month)] if 1 <= int(month) <= 12 else month
+
+ display_name = f"{month_name} {int(day)}, {year} {hour}:{minute}"
+ else:
+ display_name = timestamp
+        except Exception:
+ display_name = timestamp
+
+ return {
+ 'name': display_name,
+ 'timestamp': timestamp,
+ 'rgb_path': rgb_path, # Input image
+ 'segmentation_path': seg_path, # Output
+ 'depth_path': depth_path if depth_path.exists() else None, # Output (optional)
+ 'pointcloud_ply_path': ply_path if ply_path.exists() else None, # Output (optional)
+ 'pointcloud_html_path': html_path if html_path.exists() else None, # Output (optional)
+ 'preview_image': self._create_preview_image(rgb_path, timestamp)
+ }
+
+ except Exception as e:
+ print(f"Error loading example {timestamp}: {e}")
+ return None
+
+ def _create_preview_image(self, rgb_path: Path, timestamp: str) -> Image.Image:
+ """Create a preview thumbnail from RGB input image."""
+ try:
+ image = Image.open(rgb_path)
+ image.thumbnail((200, 150), Image.Resampling.LANCZOS)
+ return image
+
+ except Exception as e:
+ print(f"Error creating preview for {timestamp}: {e}")
+ return Image.new('RGB', (200, 150), color=(128, 128, 128))
+
+ def get_example_names(self) -> List[str]:
+ """Get list of available example names."""
+ return [data['name'] for data in self.examples.values()]
+
+ def get_example_previews(self) -> List[Tuple[Image.Image, str]]:
+ """Get preview images for all examples."""
+ previews = []
+ for timestamp, data in self.examples.items():
+ previews.append((data['preview_image'], data['name']))
+ return previews
+
+ def get_timestamp_by_name(self, name: str) -> Optional[str]:
+ """Get timestamp by display name."""
+ for timestamp, data in self.examples.items():
+ if data['name'] == name:
+ return timestamp
+ return None
+
+ def load_example_results(self, example_name: str) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[go.Figure], str]:
+ """Load precomputed results for an example."""
+ if not example_name:
+ return None, None, None, "Please select an example."
+
+ # Find the timestamp for this example name
+ timestamp = self.get_timestamp_by_name(example_name)
+ if not timestamp or timestamp not in self.examples:
+ return None, None, None, f"Example '{example_name}' not found."
+
+ example_data = self.examples[timestamp]
+
+ try:
+ # Load output images
+ segmentation_image = Image.open(example_data['segmentation_path'])
+
+ depth_image = None
+ if example_data['depth_path'] and example_data['depth_path'].exists():
+ depth_image = Image.open(example_data['depth_path'])
+
+ # Load point cloud if available
+ point_cloud_fig = None
+ if example_data['pointcloud_ply_path'] and example_data['pointcloud_ply_path'].exists():
+ try:
+ pcd = o3d.io.read_point_cloud(str(example_data['pointcloud_ply_path']))
+ if len(pcd.points) > 0:
+ point_cloud_fig = create_plotly_pointcloud(pcd, downsample_factor=1)
+ else:
+ print(f"Point cloud file {example_data['pointcloud_ply_path']} is empty")
+ except Exception as e:
+ print(f"Error loading point cloud: {e}")
+
+ return segmentation_image, depth_image, point_cloud_fig, ""
+
+ except Exception as e:
+ return None, None, None, f"Error loading example results: {str(e)}"
+
+# =============================================================================
+# Main Application Class
+# =============================================================================
+
+class EnhancedSingleViewApp:
+ def __init__(self):
+ # Model configurations
+ self.oneformer_config = ModelConfig(
+ model_name="shi-labs/oneformer_cityscapes_swin_large",
+ processor_name="shi-labs/oneformer_cityscapes_swin_large",
+ task_type="semantic"
+ )
+
+ self.depth_config = DepthConfig(
+ encoder="vitl",
+ dataset="vkitti",
+ max_depth=80,
+ weights_path="depth_anything_v2_metric_vkitti_vitl.pth"
+ )
+
+ # Models
+ self.oneformer_model = None
+ self.depth_model = None
+ self.segmentation_loaded = False
+ self.depth_loaded = False
+
+ # Precomputed examples manager
+ self.examples_manager = PrecomputedExamplesManager(OUTPUT_DIR)
+
+ # Online sample images (fallback)
+ self.sample_images = {
+ "Street Scene 1": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=800",
+ "Street Scene 2": "https://images.unsplash.com/photo-1502920917128-1aa500764cbd?w=800",
+ "Urban Road": "https://images.unsplash.com/photo-1516738901171-8eb4fc13bd20?w=800",
+ "City View": "https://images.unsplash.com/photo-1477959858617-67f85cf4f1df?w=800",
+ "Highway": "https://images.unsplash.com/photo-1544620347-c4fd4a3d5957?w=800",
+ }
+
+ def download_sample_image(self, image_url: str) -> Image.Image:
+ """Download a sample image from URL."""
+ try:
+ response = requests.get(image_url, timeout=10)
+ response.raise_for_status()
+ return Image.open(io.BytesIO(response.content)).convert('RGB')
+ except Exception as e:
+ print(f"Error downloading image: {e}")
+ return Image.new('RGB', (800, 600), color=(128, 128, 128))
+
+ def create_overlay_plot(self, rgb_image: Image.Image, semantic_map: np.ndarray, alpha: float = 0.5):
+ """Create segmentation overlay plot without title and borders."""
+ rgb_array = np.array(rgb_image)
+ color_map = get_color_map(labels)
+ colored_semantic_map = apply_color_map(semantic_map, color_map)
+
+ # Create figure with exact image dimensions
+ height, width = rgb_array.shape[:2]
+ dpi = 100
+ fig, ax = plt.subplots(1, 1, figsize=(width/dpi, height/dpi), dpi=dpi)
+
+ # Remove all margins and padding
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
+
+ ax.imshow(rgb_array)
+ ax.imshow(colored_semantic_map, alpha=alpha)
+ ax.axis('off')
+
+ buf = io.BytesIO()
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
+ buf.seek(0)
+ plt.close(fig)
+
+ return Image.open(buf)
+
+ def process_complete_pipeline(self, image: Image.Image):
+ """Process image through complete pipeline: segmentation + depth + point cloud."""
+ if image is None:
+ return None, None, None, "Please upload an image."
+
+ # Default values
+ overlay_alpha = 0.5
+ depth_colormap = "magma"
+ downsample_factor = 0.1
+
+ try:
+ # Auto-load models if not loaded
+ if not self.segmentation_loaded:
+ if self.oneformer_model is None:
+ self.oneformer_model = OneFormerModel(self.oneformer_config)
+ self.oneformer_model.load_model()
+ self.segmentation_loaded = True
+
+ if not self.depth_loaded and DEPTH_AVAILABLE:
+ if self.depth_model is None:
+ self.depth_model = DepthAnythingV2Model(self.depth_config)
+ self.depth_model.load_model()
+ self.depth_loaded = True
+
+ # Resize if too large
+ original_size = image.size
+ if max(image.size) > 1024:
+ image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
+
+ # Step 1: Semantic Segmentation
+ task_inputs = ["semantic"]
+ semantic_map = self.oneformer_model.segment_image(image, task_inputs=task_inputs)
+ segmentation_overlay = self.create_overlay_plot(image, semantic_map, overlay_alpha)
+
+ # Step 2: Depth Estimation (if available)
+ depth_vis = None
+ point_cloud_fig = None
+ pcd = None
+
+ if DEPTH_AVAILABLE and self.depth_loaded:
+ depth_map = self.depth_model.estimate_depth(image)
+ depth_vis = create_depth_visualization(depth_map, depth_colormap)
+
+ # Step 3: Point Cloud with Segmentation Colors
+ pcd = depth_to_point_cloud_with_segmentation(depth_map, image, semantic_map)
+ point_cloud_fig = create_plotly_pointcloud(pcd, downsample_factor)
+
+ # Generate comprehensive info
+ unique_classes = np.unique(semantic_map)
+ class_info = []
+ total_pixels = semantic_map.size
+
+            # Map trainIds (the values in semantic_map) to label names; the `labels`
+            # list is ordered by Cityscapes id, not trainId, so it cannot be indexed directly.
+            trainid_to_name = {label.trainId: label.name for label in labels if label.trainId != 255}
+            for class_id in unique_classes:
+                if class_id in trainid_to_name:
+                    pixel_count = np.sum(semantic_map == class_id)
+                    percentage = (pixel_count / total_pixels) * 100
+                    if percentage > 0.1:
+                        class_info.append(f"- {trainid_to_name[class_id]}: {percentage:.1f}%")
+
+ # Point cloud statistics
+ if point_cloud_fig is not None:
+ num_points = len(pcd.points)
+ downsampled_points = int(num_points * downsample_factor)
+ point_cloud_info = f"""
+3D Point Cloud:
+- Total points: {num_points:,}
+- Displayed points: {downsampled_points:,} ({downsample_factor*100:.0f}%)
+- Sky points excluded
+- Colors match segmentation classes"""
+ else:
+ point_cloud_info = "Point cloud not available (DepthAnythingV2 required)"
+
+ # Depth statistics
+ if depth_vis is not None and DEPTH_AVAILABLE:
+ depth_stats = {
+ 'min': np.min(depth_map),
+ 'max': np.max(depth_map),
+ 'mean': np.mean(depth_map),
+ 'std': np.std(depth_map)
+ }
+ depth_info = f"""
+Depth Estimation:
+- Min depth: {depth_stats['min']:.2f}m
+- Max depth: {depth_stats['max']:.2f}m
+- Mean depth: {depth_stats['mean']:.2f}m
+- Std deviation: {depth_stats['std']:.2f}m
+- Colormap: {depth_colormap}"""
+ else:
+ depth_info = "Depth estimation not available"
+
+ info_text = f"""Complete vision pipeline processed successfully!
+
+Models Used:
+- OneFormer (Semantic Segmentation)
+{f"- DepthAnythingV2 ({self.depth_config.encoder.upper()})" if DEPTH_AVAILABLE else "- DepthAnythingV2 (Not Available)"}
+
+Image Processing:
+- Original size: {original_size[0]}x{original_size[1]}
+- Processed size: {image.size[0]}x{image.size[1]}
+- Overlay transparency: {overlay_alpha:.1f}
+
+Detected Classes:
+{chr(10).join(class_info)}
+{depth_info}
+{point_cloud_info}
+
+The point cloud shows 3D structure with each point colored according to its segmentation class. Sky points are excluded for better visualization."""
+
+ return segmentation_overlay, depth_vis, point_cloud_fig, info_text
+
+ except Exception as e:
+ return None, None, None, f"Error processing pipeline: {str(e)}"
+
+# Initialize the app
+app = EnhancedSingleViewApp()
+
+def process_uploaded_image(image):
+    """Process an uploaded image and return the three outputs shown in the UI."""
+    try:
+        seg_image, depth_image, pc_fig, _info = app.process_complete_pipeline(image)
+        return seg_image, depth_image, pc_fig
+    except Exception as e:
+        print(f"Error processing uploaded image: {e}")
+        return None, None, None
+
+def process_sample_image(sample_choice):
+ """Process sample image through complete pipeline."""
+ if sample_choice and sample_choice in app.sample_images:
+ image_url = app.sample_images[sample_choice]
+ image = app.download_sample_image(image_url)
+ return app.process_complete_pipeline(image)
+ return None, None, None, "Please select a sample image."
+
+def load_precomputed_example(evt: gr.SelectData):
+ """Load precomputed example results from gallery selection."""
+ if evt.index is not None:
+ example_names = app.examples_manager.get_example_names()
+ if evt.index < len(example_names):
+ example_name = example_names[evt.index]
+ seg_image, depth_image, pc_fig, info_text = app.examples_manager.load_example_results(example_name)
+ return seg_image, depth_image, pc_fig
+ return None, None, None
+
+def get_example_previews():
+ """Get preview images for the gallery."""
+ previews = app.examples_manager.get_example_previews()
+ if not previews:
+ return []
+ return previews
+
+# =============================================================================
+# Create Gradio Interface
+# =============================================================================
+
+def create_gradio_interface():
+ """Create and return the enhanced single-view Gradio interface."""
+
+ with gr.Blocks(
+ title="Enhanced Computer Vision Pipeline",
+ theme=gr.themes.Default()
+ ) as demo:
+
+ gr.Markdown("""
+ # Street Scene 3D Reconstruction
+
+ Upload an image or select an example to see:
+ - **Semantic Segmentation** - Identify roads, buildings, vehicles, people, and other scene elements
+ - **Depth Estimation** - Generate metric depth maps showing distance to objects
+    - **3D Point Cloud** - Interactive 3D reconstruction with semantic colors
+ """)
+
+ with gr.Row():
+ # Left Column: Controls and Input
+ with gr.Column(scale=1):
+ gr.Markdown("### Upload Image")
+
+ uploaded_image = gr.Image(
+ type="pil",
+ label="Upload Image"
+ )
+ upload_btn = gr.Button("Process Image", variant="primary", size="lg")
+
+ gr.Markdown("### Examples")
+ gr.Markdown("Click on an image to load the example:")
+
+ # Example gallery
+ example_gallery = gr.Gallery(
+ value=get_example_previews(),
+ label="Example Images",
+ show_label=False,
+ elem_id="example_gallery",
+ columns=2,
+ rows=3,
+ height="auto",
+ object_fit="cover"
+ )
+
+ # Right Column: Results
+ with gr.Column(scale=2):
+ gr.Markdown("### Results")
+
+ # Segmentation and Depth side by side
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("#### Semantic Segmentation")
+ segmentation_output = gr.Image(label="Segmentation Overlay")
+
+ with gr.Column():
+ gr.Markdown("#### Depth Estimation")
+ depth_output = gr.Image(label="Depth Map")
+
+ # Point Cloud below
+ gr.Markdown("#### 3D Point Cloud")
+ pointcloud_output = gr.Plot(label="Interactive 3D Point Cloud (Colored by Segmentation)")
+
+ # Event handlers
+ upload_btn.click(
+ fn=process_uploaded_image,
+ inputs=[uploaded_image],
+ outputs=[segmentation_output, depth_output, pointcloud_output]
+ )
+
+ # Gallery selection loads example directly
+ example_gallery.select(
+ fn=load_precomputed_example,
+ outputs=[segmentation_output, depth_output, pointcloud_output]
+ )
+
+ return demo
+
+# =============================================================================
+# Main Execution
+# =============================================================================
+
+if __name__ == "__main__":
+ # Create and launch the interface
+ demo = create_gradio_interface()
+
+ print("Starting Enhanced Single-View Computer Vision App...")
+ print("Complete Pipeline: Segmentation + Depth + 3D Point Cloud")
+ print("Device:", "CUDA" if torch.cuda.is_available() else "CPU")
+ print("Depth Available:", "YES" if DEPTH_AVAILABLE else "NO")
+ print("Point Cloud Colors: Segmentation-based (Sky Excluded)")
+ print(f"Output Directory: {OUTPUT_DIR.absolute()}")
+ print(f"Available Examples: {len(app.examples_manager.examples)}")
+
+ # Launch the app
+ demo.launch(
+ share=True, # Creates a public link
+ debug=True, # Enable debugging
+ server_name="0.0.0.0", # Allow external connections
+ server_port=7860, # Default port
+ show_error=True, # Show errors in the interface
+ quiet=False # Show startup logs
+ )
\ No newline at end of file
diff --git a/metric_depth/README.md b/metric_depth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fc84a2f050a63219dd1193b12b0651bd0ab0ced5
--- /dev/null
+++ b/metric_depth/README.md
@@ -0,0 +1,114 @@
+# Depth Anything V2 for Metric Depth Estimation
+
+
+
+We here provide a simple codebase to fine-tune our Depth Anything V2 pre-trained encoder for metric depth estimation. Built on our powerful encoder, we use a simple DPT head to regress the depth. We fine-tune our pre-trained encoder on synthetic Hypersim / Virtual KITTI datasets for indoor / outdoor metric depth estimation, respectively.
+
+
+# Pre-trained Models
+
+We provide **six metric depth models** of three scales for indoor and outdoor scenes, respectively.
+
+| Base Model | Params | Indoor (Hypersim) | Outdoor (Virtual KITTI 2) |
+|:-|-:|:-:|:-:|
+| Depth-Anything-V2-Small | 24.8M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Small/resolve/main/depth_anything_v2_metric_hypersim_vits.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Small/resolve/main/depth_anything_v2_metric_vkitti_vits.pth?download=true) |
+| Depth-Anything-V2-Base | 97.5M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Base/resolve/main/depth_anything_v2_metric_hypersim_vitb.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Base/resolve/main/depth_anything_v2_metric_vkitti_vitb.pth?download=true) |
+| Depth-Anything-V2-Large | 335.3M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Large/resolve/main/depth_anything_v2_metric_hypersim_vitl.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Large/resolve/main/depth_anything_v2_metric_vkitti_vitl.pth?download=true) |
+
+*We recommend first trying our larger models (if the computational cost is affordable) and the indoor version.*
+
+## Usage
+
+### Preparation
+
+```bash
+git clone https://github.com/DepthAnything/Depth-Anything-V2
+cd Depth-Anything-V2/metric_depth
+pip install -r requirements.txt
+```
+
+Download the checkpoints listed [here](#pre-trained-models) and put them under the `checkpoints` directory.
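+
+For example, the outdoor `vitl` checkpoint from the table above can be fetched like this (any of the listed download links works the same way):
+
+```bash
+mkdir -p checkpoints
+wget -O checkpoints/depth_anything_v2_metric_vkitti_vitl.pth \
+  "https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Large/resolve/main/depth_anything_v2_metric_vkitti_vitl.pth?download=true"
+```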
+
+### Use our models
+```python
+import cv2
+import torch
+
+from depth_anything_v2.dpt import DepthAnythingV2
+
+model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
+}
+
+encoder = 'vitl' # or 'vits', 'vitb'
+dataset = 'hypersim' # 'hypersim' for indoor model, 'vkitti' for outdoor model
+max_depth = 20 # 20 for indoor model, 80 for outdoor model
+
+model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})
+model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_metric_{dataset}_{encoder}.pth', map_location='cpu'))
+model.eval()
+
+raw_img = cv2.imread('your/image/path')
+depth = model.infer_image(raw_img) # HxW depth map in meters in numpy
+```
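+
+If you just want to eyeball the result, a minimal sketch such as the one below (continuing from the snippet above and assuming `matplotlib` is installed; `depth_vis.png` is an arbitrary output name) saves a colorized visualization:
+
+```python
+import matplotlib.pyplot as plt
+import numpy as np
+
+# Normalize the metric depth to [0, 1] and apply a colormap for quick inspection.
+normalized = depth / max(float(depth.max()), 1e-6)
+colored = (plt.get_cmap('magma')(normalized)[:, :, :3] * 255).astype(np.uint8)
+cv2.imwrite('depth_vis.png', cv2.cvtColor(colored, cv2.COLOR_RGB2BGR))
+```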
+
+### Running script on images
+
+Here, we take the `vitl` encoder as an example. You can also use `vitb` or `vits` encoders.
+
+```bash
+# indoor scenes
+python run.py \
+ --encoder vitl \
+ --load-from checkpoints/depth_anything_v2_metric_hypersim_vitl.pth \
+ --max-depth 20 \
+    --img-path <path> --outdir <outdir> [--input-size <size>] [--save-numpy]
+
+# outdoor scenes
+python run.py \
+ --encoder vitl \
+ --load-from checkpoints/depth_anything_v2_metric_vkitti_vitl.pth \
+ --max-depth 80 \
+    --img-path <path> --outdir <outdir> [--input-size <size>] [--save-numpy]
+```
+
+### Project 2D images to point clouds:
+
+```bash
+python depth_to_pointcloud.py \
+ --encoder vitl \
+ --load-from checkpoints/depth_anything_v2_metric_hypersim_vitl.pth \
+ --max-depth 20 \
+    --img-path <path> --outdir <outdir>
+```
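+
+Under the hood this is the standard pinhole back-projection: a pixel `(u, v)` with metric depth `z` maps to `x = (u - cx) * z / fx`, `y = (v - cy) * z / fy`. A rough sketch of that step (the intrinsics `fx`, `fy`, `cx`, `cy` below are placeholders you would supply, not values shipped with the script):
+
+```python
+import numpy as np
+import open3d as o3d
+
+def depth_to_pointcloud(depth, fx, fy, cx, cy):
+    """Back-project an HxW metric depth map into an Open3D point cloud."""
+    h, w = depth.shape
+    u, v = np.meshgrid(np.arange(w), np.arange(h))
+    x = (u - cx) * depth / fx
+    y = (v - cy) * depth / fy
+    points = np.stack([x, y, depth], axis=-1).reshape(-1, 3)
+    valid = depth.reshape(-1) > 0  # drop pixels with no depth
+    pcd = o3d.geometry.PointCloud()
+    pcd.points = o3d.utility.Vector3dVector(points[valid])
+    return pcd
+```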
+
+### Reproduce training
+
+Please first prepare the [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets. Then:
+
+```bash
+bash dist_train.sh
+```
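+
+The loaders in `dataset/` (`hypersim.py`, `vkitti2.py`) read plain-text file lists in which each line contains an image path and a depth path separated by a single space, for example (illustrative paths only):
+
+```
+/path/to/image.png /path/to/depth.hdf5
+```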
+
+
+## Citation
+
+If you find this project useful, please consider citing:
+
+```bibtex
+@article{depth_anything_v2,
+ title={Depth Anything V2},
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Zhao, Zhen and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
+ journal={arXiv:2406.09414},
+ year={2024}
+}
+
+@inproceedings{depth_anything_v1,
+ title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
+ booktitle={CVPR},
+ year={2024}
+}
+```
diff --git a/metric_depth/dataset/hypersim.py b/metric_depth/dataset/hypersim.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d8dd3e0873ed736c1888fe1f639cc4439b5ab09
--- /dev/null
+++ b/metric_depth/dataset/hypersim.py
@@ -0,0 +1,74 @@
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
+
+
+def hypersim_distance_to_depth(npyDistance):
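+    # Hypersim stores per-pixel Euclidean distance to the camera center; convert it to
+    # planar depth (z along the optical axis) by scaling with focal / ||ray to image plane||.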
+ intWidth, intHeight, fltFocal = 1024, 768, 886.81
+
+ npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
+ 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
+ npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
+ intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
+ npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
+ npyImageplane = np.concatenate(
+ [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)
+
+ npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
+ return npyDepth
+
+
+class Hypersim(Dataset):
+ def __init__(self, filelist_path, mode, size=(518, 518)):
+
+ self.mode = mode
+ self.size = size
+
+ with open(filelist_path, 'r') as f:
+ self.filelist = f.read().splitlines()
+
+ net_w, net_h = size
+ self.transform = Compose([
+ Resize(
+ width=net_w,
+ height=net_h,
+ resize_target=True if mode == 'train' else False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ] + ([Crop(size[0])] if self.mode == 'train' else []))
+
+ def __getitem__(self, item):
+ img_path = self.filelist[item].split(' ')[0]
+ depth_path = self.filelist[item].split(' ')[1]
+
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
+
+ depth_fd = h5py.File(depth_path, "r")
+ distance_meters = np.array(depth_fd['dataset'])
+ depth = hypersim_distance_to_depth(distance_meters)
+
+ sample = self.transform({'image': image, 'depth': depth})
+
+ sample['image'] = torch.from_numpy(sample['image'])
+ sample['depth'] = torch.from_numpy(sample['depth'])
+
+ sample['valid_mask'] = (torch.isnan(sample['depth']) == 0)
+ sample['depth'][sample['valid_mask'] == 0] = 0
+
+ sample['image_path'] = self.filelist[item].split(' ')[0]
+
+ return sample
+
+ def __len__(self):
+ return len(self.filelist)
\ No newline at end of file
diff --git a/metric_depth/dataset/kitti.py b/metric_depth/dataset/kitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..4be6828ad52720c64e3296fa81bb262e95ec1bbe
--- /dev/null
+++ b/metric_depth/dataset/kitti.py
@@ -0,0 +1,57 @@
+import cv2
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from dataset.transform import Resize, NormalizeImage, PrepareForNet
+
+
+class KITTI(Dataset):
+ def __init__(self, filelist_path, mode, size=(518, 518)):
+ if mode != 'val':
+ raise NotImplementedError
+
+ self.mode = mode
+ self.size = size
+
+ with open(filelist_path, 'r') as f:
+ self.filelist = f.read().splitlines()
+
+ net_w, net_h = size
+ self.transform = Compose([
+ Resize(
+ width=net_w,
+ height=net_h,
+ resize_target=True if mode == 'train' else False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ])
+
+ def __getitem__(self, item):
+ img_path = self.filelist[item].split(' ')[0]
+ depth_path = self.filelist[item].split(' ')[1]
+
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
+
+ depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype('float32')
+
+ sample = self.transform({'image': image, 'depth': depth})
+
+ sample['image'] = torch.from_numpy(sample['image'])
+ sample['depth'] = torch.from_numpy(sample['depth'])
+        sample['depth'] = sample['depth'] / 256.0 # convert to meters
+
+ sample['valid_mask'] = sample['depth'] > 0
+
+ sample['image_path'] = self.filelist[item].split(' ')[0]
+
+ return sample
+
+ def __len__(self):
+ return len(self.filelist)
\ No newline at end of file
diff --git a/metric_depth/dataset/transform.py b/metric_depth/dataset/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..41fd0dc9b270081b1c7bdfe3434777e4170ca4a3
--- /dev/null
+++ b/metric_depth/dataset/transform.py
@@ -0,0 +1,277 @@
+import cv2
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ if "semseg_mask" in sample:
+ # sample["semseg_mask"] = cv2.resize(
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
+ # )
+ sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ # sample["mask"] = sample["mask"].astype(bool)
+
+ # print(sample['image'].shape, sample['depth'].shape)
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "semseg_mask" in sample:
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
+
+ return sample
+
+
+class Crop(object):
+ """Crop sample for batch-wise training. Image is of shape CxHxW
+ """
+
+ def __init__(self, size):
+ if isinstance(size, int):
+ self.size = (size, size)
+ else:
+ self.size = size
+
+ def __call__(self, sample):
+ h, w = sample['image'].shape[-2:]
+ assert h >= self.size[0] and w >= self.size[1], 'Wrong size'
+
+ h_start = np.random.randint(0, h - self.size[0] + 1)
+ w_start = np.random.randint(0, w - self.size[1] + 1)
+ h_end = h_start + self.size[0]
+ w_end = w_start + self.size[1]
+
+ sample['image'] = sample['image'][:, h_start: h_end, w_start: w_end]
+
+ if "depth" in sample:
+ sample["depth"] = sample["depth"][h_start: h_end, w_start: w_end]
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"][h_start: h_end, w_start: w_end]
+
+ if "semseg_mask" in sample:
+ sample["semseg_mask"] = sample["semseg_mask"][h_start: h_end, w_start: w_end]
+
+ return sample
\ No newline at end of file
diff --git a/metric_depth/dataset/vkitti2.py b/metric_depth/dataset/vkitti2.py
new file mode 100644
index 0000000000000000000000000000000000000000..48cb03112b8861fe7862cce3b25158a0b0a5ff25
--- /dev/null
+++ b/metric_depth/dataset/vkitti2.py
@@ -0,0 +1,54 @@
+import cv2
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
+
+
+class VKITTI2(Dataset):
+ def __init__(self, filelist_path, mode, size=(518, 518)):
+
+ self.mode = mode
+ self.size = size
+
+ with open(filelist_path, 'r') as f:
+ self.filelist = f.read().splitlines()
+
+ net_w, net_h = size
+ self.transform = Compose([
+ Resize(
+ width=net_w,
+ height=net_h,
+ resize_target=True if mode == 'train' else False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ] + ([Crop(size[0])] if self.mode == 'train' else []))
+
+ def __getitem__(self, item):
+ img_path = self.filelist[item].split(' ')[0]
+ depth_path = self.filelist[item].split(' ')[1]
+
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
+
+ depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m
+
+ sample = self.transform({'image': image, 'depth': depth})
+
+ sample['image'] = torch.from_numpy(sample['image'])
+ sample['depth'] = torch.from_numpy(sample['depth'])
+
+ sample['valid_mask'] = (sample['depth'] <= 80)
+
+ sample['image_path'] = self.filelist[item].split(' ')[0]
+
+ return sample
+
+ def __len__(self):
+ return len(self.filelist)
\ No newline at end of file
diff --git a/metric_depth/depth_anything_v2/dinov2.py b/metric_depth/depth_anything_v2/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4499a18330523aa3564b16be70e813de000c94
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2.py
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+ # w0, h0 = w0 + 0.1, h0 + 0.1
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
+ mode="bicubic",
+ antialias=self.interpolate_antialias
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+        # If n is an int, take the last n blocks. If it's a list, take the blocks at those indices
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+        # If n is an int, take the last n blocks. If it's a list, take the blocks at those indices
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def DINOv2(model_name):
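+    # Factory for the DINOv2 backbones used here: 518x518 inputs, 14x14 patches,
+    # LayerScale enabled, no block chunking and no register tokens (see the kwargs below).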
+ model_zoo = {
+ "vits": vit_small,
+ "vitb": vit_base,
+ "vitl": vit_large,
+ "vitg": vit_giant2
+ }
+
+ return model_zoo[model_name](
+ img_size=518,
+ patch_size=14,
+ init_values=1.0,
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+ block_chunks=0,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1
+ )
\ No newline at end of file
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/__init__.py b/metric_depth/depth_anything_v2/dinov2_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/attention.py b/metric_depth/depth_anything_v2/dinov2_layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
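+        # Falls back to the plain attention above when xFormers is missing;
+        # attn_bias (used for nested/variable-length batches) is only supported with xFormers.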
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
\ No newline at end of file
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/block.py b/metric_depth/depth_anything_v2/dinov2_layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/drop_path.py b/metric_depth/depth_anything_v2/dinov2_layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/layer_scale.py b/metric_depth/depth_anything_v2/dinov2_layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/mlp.py b/metric_depth/depth_anything_v2/dinov2_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/patch_embed.py b/metric_depth/depth_anything_v2/dinov2_layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/metric_depth/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/metric_depth/depth_anything_v2/dinov2_layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dinov2_layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
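+        # SwiGLU: a single linear layer produces two halves; one half gates the other
+        # via SiLU(x1) * x2 before the output projection.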
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/metric_depth/depth_anything_v2/dpt.py b/metric_depth/depth_anything_v2/dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..65413046b5fb4f65b90978bcdc9cc94bf8cdcae9
--- /dev/null
+++ b/metric_depth/depth_anything_v2/dpt.py
@@ -0,0 +1,222 @@
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+from .dinov2 import DINOv2
+from .util.blocks import FeatureFusionBlock, _make_scratch
+from .util.transform import Resize, NormalizeImage, PrepareForNet
+
+
+def _make_fusion_block(features, use_bn, size=None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_feature, out_feature):
+ super().__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_feature),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ return self.conv_block(x)
+
+
+class DPTHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False
+ ):
+ super(DPTHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid()
+ )
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
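+            # (B, N, C) patch tokens -> (B, C, patch_h, patch_w) feature maps for the convolutional decoder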
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ out = self.scratch.output_conv2(out)
+
+ return out
+
+
+class DepthAnythingV2(nn.Module):
+ def __init__(
+ self,
+ encoder='vitl',
+ features=256,
+ out_channels=[256, 512, 1024, 1024],
+ use_bn=False,
+ use_clstoken=False,
+ max_depth=20.0
+ ):
+ super(DepthAnythingV2, self).__init__()
+
+ self.intermediate_layer_idx = {
+ 'vits': [2, 5, 8, 11],
+ 'vitb': [2, 5, 8, 11],
+ 'vitl': [4, 11, 17, 23],
+ 'vitg': [9, 19, 29, 39]
+ }
+
+ self.max_depth = max_depth
+
+ self.encoder = encoder
+ self.pretrained = DINOv2(model_name=encoder)
+
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+ def forward(self, x):
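+        # DINOv2 uses 14x14 patches, so intermediate features live on a patch_h x patch_w grid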
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
+
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
+
+ depth = self.depth_head(features, patch_h, patch_w) * self.max_depth
+
+ return depth.squeeze(1)
+
+ @torch.no_grad()
+ def infer_image(self, raw_image, input_size=518):
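+        # Convenience wrapper: BGR uint8 image in, HxW depth map (numpy) out, in the units of max_depth.
+        # Rough usage sketch (paths and settings are placeholders, not part of this repo):
+        #   model = DepthAnythingV2(max_depth=20).to('cuda').eval()   # 'vitl' defaults, weights loaded separately
+        #   depth = model.infer_image(cv2.imread('frame.png'))        # HxW numpy array at input resolution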
+ image, (h, w) = self.image2tensor(raw_image, input_size)
+
+ depth = self.forward(image)
+
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
+
+ return depth.cpu().numpy()
+
+ def image2tensor(self, raw_image, input_size=518):
+ transform = Compose([
+ Resize(
+ width=input_size,
+ height=input_size,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ])
+
+ h, w = raw_image.shape[:2]
+
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
+
+ image = transform({'image': image})['image']
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ image = image.to(DEVICE)
+
+ return image, (h, w)
diff --git a/metric_depth/depth_anything_v2/util/blocks.py b/metric_depth/depth_anything_v2/util/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37
--- /dev/null
+++ b/metric_depth/depth_anything_v2/util/blocks.py
@@ -0,0 +1,148 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+        self.groups = 1
+
+        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+        if self.bn:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+        if self.bn:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+        self.groups = 1
+
+        self.expand = expand
+        out_features = features
+        if self.expand:
+            out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+        self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/metric_depth/depth_anything_v2/util/transform.py b/metric_depth/depth_anything_v2/util/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73
--- /dev/null
+++ b/metric_depth/depth_anything_v2/util/transform.py
@@ -0,0 +1,158 @@
+import numpy as np
+import cv2
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+ # resize sample
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
+
+ if self.__resize_target:
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ return sample
\ No newline at end of file
diff --git a/metric_depth/depth_to_pointcloud.py b/metric_depth/depth_to_pointcloud.py
new file mode 100644
index 0000000000000000000000000000000000000000..770fe60698724327f1071c66d685b4a3d8ce7ca8
--- /dev/null
+++ b/metric_depth/depth_to_pointcloud.py
@@ -0,0 +1,114 @@
+"""
+Born out of Depth Anything V1 Issue 36
+Make sure you have the necessary libraries installed.
+Code by @1ssb
+
+This script processes a set of images to generate depth maps and corresponding point clouds.
+The resulting point clouds are saved in the specified output directory.
+
+Usage:
+    python depth_to_pointcloud.py --encoder vitl --load-from path_to_model --max-depth 20 --img-path path_to_images --outdir output_directory --focal-length-x 470.4 --focal-length-y 470.4
+
+Arguments:
+ --encoder: Model encoder to use. Choices are ['vits', 'vitb', 'vitl', 'vitg'].
+ --load-from: Path to the pre-trained model weights.
+ --max-depth: Maximum depth value for the depth map.
+ --img-path: Path to the input image or directory containing images.
+ --outdir: Directory to save the output point clouds.
+ --focal-length-x: Focal length along the x-axis.
+ --focal-length-y: Focal length along the y-axis.
+"""
+
+import argparse
+import cv2
+import glob
+import numpy as np
+import open3d as o3d
+import os
+from PIL import Image
+import torch
+
+from depth_anything_v2.dpt import DepthAnythingV2
+
+
+def main():
+ # Parse command-line arguments
+ parser = argparse.ArgumentParser(description='Generate depth maps and point clouds from images.')
+ parser.add_argument('--encoder', default='vitl', type=str, choices=['vits', 'vitb', 'vitl', 'vitg'],
+ help='Model encoder to use.')
+ parser.add_argument('--load-from', default='', type=str, required=True,
+ help='Path to the pre-trained model weights.')
+ parser.add_argument('--max-depth', default=20, type=float,
+ help='Maximum depth value for the depth map.')
+ parser.add_argument('--img-path', type=str, required=True,
+ help='Path to the input image or directory containing images.')
+ parser.add_argument('--outdir', type=str, default='./vis_pointcloud',
+ help='Directory to save the output point clouds.')
+ parser.add_argument('--focal-length-x', default=470.4, type=float,
+ help='Focal length along the x-axis.')
+ parser.add_argument('--focal-length-y', default=470.4, type=float,
+ help='Focal length along the y-axis.')
+
+ args = parser.parse_args()
+
+ # Determine the device to use (CUDA, MPS, or CPU)
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+
+ # Model configuration based on the chosen encoder
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+ }
+
+ # Initialize the DepthAnythingV2 model with the specified configuration
+ depth_anything = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth})
+ depth_anything.load_state_dict(torch.load(args.load_from, map_location='cpu'))
+ depth_anything = depth_anything.to(DEVICE).eval()
+
+ # Get the list of image files to process
+ if os.path.isfile(args.img_path):
+ if args.img_path.endswith('txt'):
+ with open(args.img_path, 'r') as f:
+ filenames = f.read().splitlines()
+ else:
+ filenames = [args.img_path]
+ else:
+ filenames = glob.glob(os.path.join(args.img_path, '**/*'), recursive=True)
+
+ # Create the output directory if it doesn't exist
+ os.makedirs(args.outdir, exist_ok=True)
+
+ # Process each image file
+ for k, filename in enumerate(filenames):
+ print(f'Processing {k+1}/{len(filenames)}: {filename}')
+
+ # Load the image
+ color_image = Image.open(filename).convert('RGB')
+ width, height = color_image.size
+
+ # Read the image using OpenCV
+ image = cv2.imread(filename)
+ pred = depth_anything.infer_image(image, height)
+
+ # Resize depth prediction to match the original image size
+ resized_pred = Image.fromarray(pred).resize((width, height), Image.NEAREST)
+
+ # Generate mesh grid and calculate point cloud coordinates
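+        # Pinhole back-projection with the principal point assumed at the image centre:
+        #   X = (u - W/2) * Z / fx,   Y = (v - H/2) * Z / fy,   Z = predicted depth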
+ x, y = np.meshgrid(np.arange(width), np.arange(height))
+ x = (x - width / 2) / args.focal_length_x
+ y = (y - height / 2) / args.focal_length_y
+ z = np.array(resized_pred)
+ points = np.stack((np.multiply(x, z), np.multiply(y, z), z), axis=-1).reshape(-1, 3)
+ colors = np.array(color_image).reshape(-1, 3) / 255.0
+
+ # Create the point cloud and save it to the output directory
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+ o3d.io.write_point_cloud(os.path.join(args.outdir, os.path.splitext(os.path.basename(filename))[0] + ".ply"), pcd)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/metric_depth/dist_train.sh b/metric_depth/dist_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..afa750ad3c3e05fca741bcd66a5f6fea1dab46ac
--- /dev/null
+++ b/metric_depth/dist_train.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
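+# Single-node launch: edit the variables below for your setup, then run `bash dist_train.sh`
+# (defaults assume 8 GPUs and a Hypersim file list under dataset/splits/).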
+now=$(date +"%Y%m%d_%H%M%S")
+
+epoch=120
+bs=4
+gpus=8
+lr=0.000005
+encoder=vitl
+dataset=hypersim # vkitti
+img_size=518
+min_depth=0.001
+max_depth=20 # 80 for virtual kitti
+pretrained_from=../checkpoints/depth_anything_v2_${encoder}.pth
+save_path=exp/hypersim # exp/vkitti
+
+mkdir -p $save_path
+
+python3 -m torch.distributed.launch \
+ --nproc_per_node=$gpus \
+ --nnodes 1 \
+ --node_rank=0 \
+ --master_addr=localhost \
+ --master_port=20596 \
+    train.py --epochs $epoch --encoder $encoder --bs $bs --lr $lr --save-path $save_path --dataset $dataset \
+ --img-size $img_size --min-depth $min_depth --max-depth $max_depth --pretrained-from $pretrained_from \
+ --port 20596 2>&1 | tee -a $save_path/$now.log
diff --git a/metric_depth/requirements.txt b/metric_depth/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..082330f721c6e704a360cd7706b968d5a93b312f
--- /dev/null
+++ b/metric_depth/requirements.txt
@@ -0,0 +1,5 @@
+matplotlib
+opencv-python
+open3d
+torch
+torchvision
diff --git a/metric_depth/run.py b/metric_depth/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d46f54099d2039f80379fec694a717e2965bed
--- /dev/null
+++ b/metric_depth/run.py
@@ -0,0 +1,81 @@
+import argparse
+import cv2
+import glob
+import matplotlib
+import numpy as np
+import os
+import torch
+
+from depth_anything_v2.dpt import DepthAnythingV2
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Depth Anything V2 Metric Depth Estimation')
+
+ parser.add_argument('--img-path', type=str)
+ parser.add_argument('--input-size', type=int, default=518)
+ parser.add_argument('--outdir', type=str, default='./vis_depth')
+
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
+ parser.add_argument('--load-from', type=str, default='checkpoints/depth_anything_v2_metric_hypersim_vitl.pth')
+ parser.add_argument('--max-depth', type=float, default=20)
+
+ parser.add_argument('--save-numpy', dest='save_numpy', action='store_true', help='save the model raw output')
+ parser.add_argument('--pred-only', dest='pred_only', action='store_true', help='only display the prediction')
+ parser.add_argument('--grayscale', dest='grayscale', action='store_true', help='do not apply colorful palette')
+
+ args = parser.parse_args()
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+ }
+
+ depth_anything = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth})
+ depth_anything.load_state_dict(torch.load(args.load_from, map_location='cpu'))
+ depth_anything = depth_anything.to(DEVICE).eval()
+
+ if os.path.isfile(args.img_path):
+ if args.img_path.endswith('txt'):
+ with open(args.img_path, 'r') as f:
+ filenames = f.read().splitlines()
+ else:
+ filenames = [args.img_path]
+ else:
+ filenames = glob.glob(os.path.join(args.img_path, '**/*'), recursive=True)
+
+ os.makedirs(args.outdir, exist_ok=True)
+
+ cmap = matplotlib.colormaps.get_cmap('Spectral')
+
+ for k, filename in enumerate(filenames):
+ print(f'Progress {k+1}/{len(filenames)}: {filename}')
+
+ raw_image = cv2.imread(filename)
+
+ depth = depth_anything.infer_image(raw_image, args.input_size)
+
+ if args.save_numpy:
+ output_path = os.path.join(args.outdir, os.path.splitext(os.path.basename(filename))[0] + '_raw_depth_meter.npy')
+ np.save(output_path, depth)
+
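+        # Min-max normalise per image for visualisation only; raw metric values are kept via --save-numpy above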
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ depth = depth.astype(np.uint8)
+
+ if args.grayscale:
+ depth = np.repeat(depth[..., np.newaxis], 3, axis=-1)
+ else:
+ depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
+
+ output_path = os.path.join(args.outdir, os.path.splitext(os.path.basename(filename))[0] + '.png')
+ if args.pred_only:
+ cv2.imwrite(output_path, depth)
+ else:
+ split_region = np.ones((raw_image.shape[0], 50, 3), dtype=np.uint8) * 255
+ combined_result = cv2.hconcat([raw_image, split_region, depth])
+
+ cv2.imwrite(output_path, combined_result)
\ No newline at end of file
diff --git a/metric_depth/train.py b/metric_depth/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b304275c5bd22f63fa19e54c427b3da6a730ea9
--- /dev/null
+++ b/metric_depth/train.py
@@ -0,0 +1,212 @@
+import argparse
+import logging
+import os
+import pprint
+import random
+
+import warnings
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.optim import AdamW
+import torch.nn.functional as F
+from torch.utils.tensorboard import SummaryWriter
+
+from dataset.hypersim import Hypersim
+from dataset.kitti import KITTI
+from dataset.vkitti2 import VKITTI2
+from depth_anything_v2.dpt import DepthAnythingV2
+from util.dist_helper import setup_distributed
+from util.loss import SiLogLoss
+from util.metric import eval_depth
+from util.utils import init_log
+
+
+parser = argparse.ArgumentParser(description='Depth Anything V2 for Metric Depth Estimation')
+
+parser.add_argument('--encoder', default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
+parser.add_argument('--dataset', default='hypersim', choices=['hypersim', 'vkitti'])
+parser.add_argument('--img-size', default=518, type=int)
+parser.add_argument('--min-depth', default=0.001, type=float)
+parser.add_argument('--max-depth', default=20, type=float)
+parser.add_argument('--epochs', default=40, type=int)
+parser.add_argument('--bs', default=2, type=int)
+parser.add_argument('--lr', default=0.000005, type=float)
+parser.add_argument('--pretrained-from', type=str)
+parser.add_argument('--save-path', type=str, required=True)
+parser.add_argument('--local-rank', default=0, type=int)
+parser.add_argument('--port', default=None, type=int)
+
+
+def main():
+ args = parser.parse_args()
+
+ warnings.simplefilter('ignore', np.RankWarning)
+
+ logger = init_log('global', logging.INFO)
+ logger.propagate = 0
+
+ rank, world_size = setup_distributed(port=args.port)
+
+ if rank == 0:
+ all_args = {**vars(args), 'ngpus': world_size}
+ logger.info('{}\n'.format(pprint.pformat(all_args)))
+ writer = SummaryWriter(args.save_path)
+
+ cudnn.enabled = True
+ cudnn.benchmark = True
+
+ size = (args.img_size, args.img_size)
+ if args.dataset == 'hypersim':
+ trainset = Hypersim('dataset/splits/hypersim/train.txt', 'train', size=size)
+ elif args.dataset == 'vkitti':
+ trainset = VKITTI2('dataset/splits/vkitti2/train.txt', 'train', size=size)
+ else:
+ raise NotImplementedError
+ trainsampler = torch.utils.data.distributed.DistributedSampler(trainset)
+ trainloader = DataLoader(trainset, batch_size=args.bs, pin_memory=True, num_workers=4, drop_last=True, sampler=trainsampler)
+
+ if args.dataset == 'hypersim':
+ valset = Hypersim('dataset/splits/hypersim/val.txt', 'val', size=size)
+ elif args.dataset == 'vkitti':
+ valset = KITTI('dataset/splits/kitti/val.txt', 'val', size=size)
+ else:
+ raise NotImplementedError
+ valsampler = torch.utils.data.distributed.DistributedSampler(valset)
+ valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=4, drop_last=True, sampler=valsampler)
+
+ local_rank = int(os.environ["LOCAL_RANK"])
+
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+ }
+ model = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth})
+
+ if args.pretrained_from:
+ model.load_state_dict({k: v for k, v in torch.load(args.pretrained_from, map_location='cpu').items() if 'pretrained' in k}, strict=False)
+
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model.cuda(local_rank)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False,
+ output_device=local_rank, find_unused_parameters=True)
+
+ criterion = SiLogLoss().cuda(local_rank)
+
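+ # Two parameter groups: the pretrained encoder ('pretrained' in the parameter name) uses the base LR, the randomly initialized decoder head uses 10x that LR.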
+ optimizer = AdamW([{'params': [param for name, param in model.named_parameters() if 'pretrained' in name], 'lr': args.lr},
+ {'params': [param for name, param in model.named_parameters() if 'pretrained' not in name], 'lr': args.lr * 10.0}],
+ lr=args.lr, betas=(0.9, 0.999), weight_decay=0.01)
+
+ total_iters = args.epochs * len(trainloader)
+
+ previous_best = {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100, 'rmse_log': 100, 'log10': 100, 'silog': 100}
+
+ for epoch in range(args.epochs):
+ if rank == 0:
+ logger.info('===========> Epoch: {:}/{:}, d1: {:.3f}, d2: {:.3f}, d3: {:.3f}'.format(epoch, args.epochs, previous_best['d1'], previous_best['d2'], previous_best['d3']))
+ logger.info('===========> Epoch: {:}/{:}, abs_rel: {:.3f}, sq_rel: {:.3f}, rmse: {:.3f}, rmse_log: {:.3f}, '
+ 'log10: {:.3f}, silog: {:.3f}'.format(
+ epoch, args.epochs, previous_best['abs_rel'], previous_best['sq_rel'], previous_best['rmse'],
+ previous_best['rmse_log'], previous_best['log10'], previous_best['silog']))
+
+ trainloader.sampler.set_epoch(epoch + 1)
+
+ model.train()
+ total_loss = 0
+
+ for i, sample in enumerate(trainloader):
+ optimizer.zero_grad()
+
+ img, depth, valid_mask = sample['image'].cuda(), sample['depth'].cuda(), sample['valid_mask'].cuda()
+
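+ # Random horizontal flip augmentation, applied jointly to image, depth, and valid mask.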
+ if random.random() < 0.5:
+ img = img.flip(-1)
+ depth = depth.flip(-1)
+ valid_mask = valid_mask.flip(-1)
+
+ pred = model(img)
+
+ loss = criterion(pred, depth, (valid_mask == 1) & (depth >= args.min_depth) & (depth <= args.max_depth))
+
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item()
+
+ iters = epoch * len(trainloader) + i
+
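+ # Polynomial LR decay over all training iterations: lr = base_lr * (1 - iters/total_iters)^0.9, preserving the 10x head/encoder ratio.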
+ lr = args.lr * (1 - iters / total_iters) ** 0.9
+
+ optimizer.param_groups[0]["lr"] = lr
+ optimizer.param_groups[1]["lr"] = lr * 10.0
+
+ if rank == 0:
+ writer.add_scalar('train/loss', loss.item(), iters)
+
+ if rank == 0 and i % 100 == 0:
+ logger.info('Iter: {}/{}, LR: {:.7f}, Loss: {:.3f}'.format(i, len(trainloader), optimizer.param_groups[0]['lr'], loss.item()))
+
+ model.eval()
+
+ results = {'d1': torch.tensor([0.0]).cuda(), 'd2': torch.tensor([0.0]).cuda(), 'd3': torch.tensor([0.0]).cuda(),
+ 'abs_rel': torch.tensor([0.0]).cuda(), 'sq_rel': torch.tensor([0.0]).cuda(), 'rmse': torch.tensor([0.0]).cuda(),
+ 'rmse_log': torch.tensor([0.0]).cuda(), 'log10': torch.tensor([0.0]).cuda(), 'silog': torch.tensor([0.0]).cuda()}
+ nsamples = torch.tensor([0.0]).cuda()
+
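+ # Accumulate metrics over this rank's validation samples; the sums are reduced onto rank 0 after the loop.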
+ for i, sample in enumerate(valloader):
+
+ img, depth, valid_mask = sample['image'].cuda().float(), sample['depth'].cuda()[0], sample['valid_mask'].cuda()[0]
+
+ with torch.no_grad():
+ pred = model(img)
+ pred = F.interpolate(pred[:, None], depth.shape[-2:], mode='bilinear', align_corners=True)[0, 0]
+
+ valid_mask = (valid_mask == 1) & (depth >= args.min_depth) & (depth <= args.max_depth)
+
+ if valid_mask.sum() < 10:
+ continue
+
+ cur_results = eval_depth(pred[valid_mask], depth[valid_mask])
+
+ for k in results.keys():
+ results[k] += cur_results[k]
+ nsamples += 1
+
+ torch.distributed.barrier()
+
+ for k in results.keys():
+ dist.reduce(results[k], dst=0)
+ dist.reduce(nsamples, dst=0)
+
+ if rank == 0:
+ logger.info('==========================================================================================')
+ logger.info('{:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}'.format(*tuple(results.keys())))
+ logger.info('{:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}'.format(*tuple([(v / nsamples).item() for v in results.values()])))
+ logger.info('==========================================================================================')
+ print()
+
+ for name, metric in results.items():
+ writer.add_scalar(f'eval/{name}', (metric / nsamples).item(), epoch)
+
+ for k in results.keys():
+ if k in ['d1', 'd2', 'd3']:
+ previous_best[k] = max(previous_best[k], (results[k] / nsamples).item())
+ else:
+ previous_best[k] = min(previous_best[k], (results[k] / nsamples).item())
+
+ if rank == 0:
+ checkpoint = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'previous_best': previous_best,
+ }
+ torch.save(checkpoint, os.path.join(args.save_path, 'latest.pth'))
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/metric_depth/util/dist_helper.py b/metric_depth/util/dist_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b6eb432b4988638ac9549a82fbaebf968fe9c61
--- /dev/null
+++ b/metric_depth/util/dist_helper.py
@@ -0,0 +1,41 @@
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+
+
+def setup_distributed(backend="nccl", port=None):
+ """AdaHessian Optimizer
+ Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
+ Originally licensed MIT, Copyright (c) 2020 Wei Li
+ """
+ num_gpus = torch.cuda.device_count()
+
+ if "SLURM_JOB_ID" in os.environ:
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+ # specify master port
+ if port is not None:
+ os.environ["MASTER_PORT"] = str(port)
+ elif "MASTER_PORT" not in os.environ:
+ os.environ["MASTER_PORT"] = "10685"
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["LOCAL_RANK"] = str(rank % num_gpus)
+ os.environ["RANK"] = str(rank)
+ else:
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ torch.cuda.set_device(rank % num_gpus)
+
+ dist.init_process_group(
+ backend=backend,
+ world_size=world_size,
+ rank=rank,
+ )
+ return rank, world_size
diff --git a/metric_depth/util/loss.py b/metric_depth/util/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae5b304effd46661e93ea23127d1115c36b5265
--- /dev/null
+++ b/metric_depth/util/loss.py
@@ -0,0 +1,16 @@
+import torch
+from torch import nn
+
+
+class SiLogLoss(nn.Module):
+ def __init__(self, lambd=0.5):
+ super().__init__()
+ self.lambd = lambd
+
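+ # Scale-invariant log (SiLog) loss: sqrt(mean(d^2) - lambd * mean(d)^2), with d = log(target) - log(pred), computed over valid pixels only.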
+ def forward(self, pred, target, valid_mask):
+ valid_mask = valid_mask.detach()
+ diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
+ loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
+ self.lambd * torch.pow(diff_log.mean(), 2))
+
+ return loss
diff --git a/metric_depth/util/metric.py b/metric_depth/util/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..8638cf25875c753cb62c3977af1417c221237dce
--- /dev/null
+++ b/metric_depth/util/metric.py
@@ -0,0 +1,26 @@
+import torch
+
+
+def eval_depth(pred, target):
+ assert pred.shape == target.shape
+
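+ # Threshold accuracies d1/d2/d3: fraction of pixels where max(pred/target, target/pred) < 1.25, 1.25^2, 1.25^3.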
+ thresh = torch.max((target / pred), (pred / target))
+
+ d1 = torch.sum(thresh < 1.25).float() / len(thresh)
+ d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
+ d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
+
+ diff = pred - target
+ diff_log = torch.log(pred) - torch.log(target)
+
+ abs_rel = torch.mean(torch.abs(diff) / target)
+ sq_rel = torch.mean(torch.pow(diff, 2) / target)
+
+ rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
+ rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2)))
+
+ log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
+ silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
+
+ return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(), 'sq_rel': sq_rel.item(),
+ 'rmse': rmse.item(), 'rmse_log': rmse_log.item(), 'log10':log10.item(), 'silog':silog.item()}
\ No newline at end of file
diff --git a/metric_depth/util/utils.py b/metric_depth/util/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e89b994538c5123075605fb6130022867f37c99b
--- /dev/null
+++ b/metric_depth/util/utils.py
@@ -0,0 +1,26 @@
+import os
+import re
+import numpy as np
+import logging
+
+logs = set()
+
+
+def init_log(name, level=logging.INFO):
+ if (name, level) in logs:
+ return
+ logs.add((name, level))
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+ ch = logging.StreamHandler()
+ ch.setLevel(level)
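+ # Under SLURM, add a filter to the logger so that only rank 0 emits log records.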
+ if "SLURM_PROCID" in os.environ:
+ rank = int(os.environ["SLURM_PROCID"])
+ logger.addFilter(lambda record: rank == 0)
+ else:
+ rank = 0
+ format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
+ formatter = logging.Formatter(format_str)
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+ return logger
diff --git a/outputs/depth/depth_20250714_125057.png b/outputs/depth/depth_20250714_125057.png
new file mode 100644
index 0000000000000000000000000000000000000000..f5972bcb3e3369cd5e2d4971f5b84088c31b3053
--- /dev/null
+++ b/outputs/depth/depth_20250714_125057.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:401c45a38bce8552b7633c3606182c22ed132fb9f66d381a5b8b5bb1d00f85fc
+size 200509
diff --git a/outputs/depth/depth_20250714_125314.png b/outputs/depth/depth_20250714_125314.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee96f08b0789feb27d04b438eb73f0ecf2aa9f1f
--- /dev/null
+++ b/outputs/depth/depth_20250714_125314.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f2771574fd95dfd9e34a6f098460a2dee71bc2acbfd2dd3c539b2ca6f2c6cb1
+size 233715
diff --git a/outputs/depth/depth_20250714_125405.png b/outputs/depth/depth_20250714_125405.png
new file mode 100644
index 0000000000000000000000000000000000000000..125e3e63bc7470da5c3ee70660d1420aa9a8d468
--- /dev/null
+++ b/outputs/depth/depth_20250714_125405.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3171880bd122688ebf3676bcbc7a628dd2f723347cd40c6c01f593861fe346f4
+size 269925
diff --git a/outputs/depth/depth_20250714_125456.png b/outputs/depth/depth_20250714_125456.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b38c2fd5c266de3eb06a9a033dfd6d62e14a4cb
--- /dev/null
+++ b/outputs/depth/depth_20250714_125456.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47f28b875efe304d22a2d48b26a500b934d6f53bac83ab55168688bb8d01cc93
+size 213507
diff --git a/outputs/depth/depth_20250714_125608.png b/outputs/depth/depth_20250714_125608.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd25f4be054025929aac9deb9a69afb200a1e5f7
--- /dev/null
+++ b/outputs/depth/depth_20250714_125608.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1829f3f5ae4503815f5a1ae3eeeceb2d00b10f6f999ffa38a542d92c90b1dac9
+size 209668
diff --git a/outputs/depth/depth_20250714_125650.png b/outputs/depth/depth_20250714_125650.png
new file mode 100644
index 0000000000000000000000000000000000000000..f5972bcb3e3369cd5e2d4971f5b84088c31b3053
--- /dev/null
+++ b/outputs/depth/depth_20250714_125650.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:401c45a38bce8552b7633c3606182c22ed132fb9f66d381a5b8b5bb1d00f85fc
+size 200509
diff --git a/outputs/depth/depth_20250714_125834.png b/outputs/depth/depth_20250714_125834.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d12c7b75434163f6d886c65ca985121062ad872
--- /dev/null
+++ b/outputs/depth/depth_20250714_125834.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:336f50e352b04a96d7f89212cf841250ab641d29f3d7d2a0033322934e28423e
+size 154764
diff --git a/outputs/depth/depth_20250714_125959.png b/outputs/depth/depth_20250714_125959.png
new file mode 100644
index 0000000000000000000000000000000000000000..11b3be7957795523c4cad03c6e42cdef1e3061ef
--- /dev/null
+++ b/outputs/depth/depth_20250714_125959.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0eb005b102ab2365a05e3eb22e06ea2738ca5d46ac813dc5f45a80d48bbcac9d
+size 208247
diff --git a/outputs/depth/depth_20250714_130113.png b/outputs/depth/depth_20250714_130113.png
new file mode 100644
index 0000000000000000000000000000000000000000..83b3668957af05621d9f4d06d8b25a5bc8c50a11
--- /dev/null
+++ b/outputs/depth/depth_20250714_130113.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:261e7c657422309b6e78cd202610cb983f2daf9a88130a953795ff8128d3df76
+size 160103
diff --git a/outputs/depth/depth_20250714_130229.png b/outputs/depth/depth_20250714_130229.png
new file mode 100644
index 0000000000000000000000000000000000000000..46ce8fd3b3798fc0892deb85506f7f18a931a1b2
--- /dev/null
+++ b/outputs/depth/depth_20250714_130229.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcdaea1beefa7cffcace1a3bbb2fab3c728b53cd57142f2eb1bbce499b1c58ab
+size 184653
diff --git a/outputs/pointclouds/pointcloud_20250714_125057.ply b/outputs/pointclouds/pointcloud_20250714_125057.ply
new file mode 100644
index 0000000000000000000000000000000000000000..ab7362e75eb31149a0e23db5560298b698703345
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125057.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f07eab61319eeceb3da33df6c66da1a761345d57834edd1b7825f4e47f2851ee
+size 1184941
diff --git a/outputs/pointclouds/pointcloud_20250714_125314.ply b/outputs/pointclouds/pointcloud_20250714_125314.ply
new file mode 100644
index 0000000000000000000000000000000000000000..a1874f495adbf28c7829ac086987d7ecab8738ce
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125314.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9638604c044a819a28a9bd08b1e2be1195e7ff16c5a090309f6b813edc84785c
+size 1158778
diff --git a/outputs/pointclouds/pointcloud_20250714_125405.ply b/outputs/pointclouds/pointcloud_20250714_125405.ply
new file mode 100644
index 0000000000000000000000000000000000000000..622e046f3c6688b2f7d13334c1bf8bac37718f84
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125405.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3138ad4d68e7a18ba3c05eeeb2ffc4fbd2f3c7ba1f522ce9a4a3680e13b44d6a
+size 840961
diff --git a/outputs/pointclouds/pointcloud_20250714_125456.ply b/outputs/pointclouds/pointcloud_20250714_125456.ply
new file mode 100644
index 0000000000000000000000000000000000000000..82635f46d3a75510161742241b46a47fe175d2bb
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125456.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d81148f95fcb702a75f36f5fb5622ace24837f8970a6038ce35cd2a7724a281
+size 1169929
diff --git a/outputs/pointclouds/pointcloud_20250714_125608.ply b/outputs/pointclouds/pointcloud_20250714_125608.ply
new file mode 100644
index 0000000000000000000000000000000000000000..20878f9008e0cb4e3fe32c265087450676ebf9a3
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125608.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:567ef46a3ca77fb2b4cf8f39676f5ef9f65b832b5a94fa2ce5664dfc5f72b45a
+size 1348021
diff --git a/outputs/pointclouds/pointcloud_20250714_125650.ply b/outputs/pointclouds/pointcloud_20250714_125650.ply
new file mode 100644
index 0000000000000000000000000000000000000000..12c99393ad4e2c09143470decb2e29780d2b1bc4
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125650.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7c694283b48f7f25bb80f6baaebd298fdf2ec194e75778b85919172715c6e43
+size 1184941
diff --git a/outputs/pointclouds/pointcloud_20250714_125834.ply b/outputs/pointclouds/pointcloud_20250714_125834.ply
new file mode 100644
index 0000000000000000000000000000000000000000..53ea5d5ed49875edc7014cdb240629e47b559ae4
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125834.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7763c963d92363bfc6dd1eaef5345012a87b0d6ed5c9f0f0eb3eb1f33ca8725e
+size 1370809
diff --git a/outputs/pointclouds/pointcloud_20250714_125959.ply b/outputs/pointclouds/pointcloud_20250714_125959.ply
new file mode 100644
index 0000000000000000000000000000000000000000..d05dcf3f794e8372a9d0938c796dc873da987f8e
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_125959.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b705ff6cd6185ce44b7a9c2070803b5f3e9569ffae0c61e76884eea3f3887630
+size 1270153
diff --git a/outputs/pointclouds/pointcloud_20250714_130113.ply b/outputs/pointclouds/pointcloud_20250714_130113.ply
new file mode 100644
index 0000000000000000000000000000000000000000..0c78d6ac949f7f4341be5b83c4c1c8d4e524229f
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_130113.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c451509aab01141153203b0cf59900032320b67150e91d0fa9bfe66ce677c6c
+size 1415116
diff --git a/outputs/pointclouds/pointcloud_20250714_130229.ply b/outputs/pointclouds/pointcloud_20250714_130229.ply
new file mode 100644
index 0000000000000000000000000000000000000000..8e394273e6e374bea7ed540e20d4d60a5921fd5b
--- /dev/null
+++ b/outputs/pointclouds/pointcloud_20250714_130229.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:385aaf07d721abb355b199e53d265910a8d4ee5cf7e986471317da089eeefe79
+size 1374697
diff --git a/outputs/rgb/rgb_20250714_125057.png b/outputs/rgb/rgb_20250714_125057.png
new file mode 100644
index 0000000000000000000000000000000000000000..b95544682b8da9b9c4d1775064d358e1634bb2d4
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125057.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc574b266f3a32b8f237a2e82d7223ee39cf78fbd128e568255bb926db26a686
+size 2243566
diff --git a/outputs/rgb/rgb_20250714_125314.png b/outputs/rgb/rgb_20250714_125314.png
new file mode 100644
index 0000000000000000000000000000000000000000..668d48b42f6b98a06b97416cb3fa27f591666a38
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125314.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36d92a8346572ab245a0e03b648bafddc5a601ae21cc8154c1002993a7020af0
+size 1837539
diff --git a/outputs/rgb/rgb_20250714_125405.png b/outputs/rgb/rgb_20250714_125405.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b69be4b6528cc1aba902146db8b6164ac102057
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125405.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9bc1f998b4a943868d8580aa153b51770fcf17ec2797f51ba2c57a495b516c8
+size 1824275
diff --git a/outputs/rgb/rgb_20250714_125456.png b/outputs/rgb/rgb_20250714_125456.png
new file mode 100644
index 0000000000000000000000000000000000000000..06c87ee7e9e29f2aa50920713ee80b7e45eda4bc
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125456.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:512a7ba8783182a5ab6c51d08b1642c8ca57e14d049f8cd9d0a9d34a24e1393b
+size 2389334
diff --git a/outputs/rgb/rgb_20250714_125608.png b/outputs/rgb/rgb_20250714_125608.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7bdbe6fe596ae9f80cfcd6b8e253c79e3b41b87
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125608.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab848cd4e7bbfb69991ec8203840e964e57430c7d0dd4bb6e8f29079d4f372eb
+size 1661428
diff --git a/outputs/rgb/rgb_20250714_125834.png b/outputs/rgb/rgb_20250714_125834.png
new file mode 100644
index 0000000000000000000000000000000000000000..246c5a4230ea35c28a4e22e0e82b9357645c2e23
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125834.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57e5e573bc467eb7a1be0ff6c7290a2429b537c1ff5b532c79d3660147eff2c1
+size 2251917
diff --git a/outputs/rgb/rgb_20250714_125959.png b/outputs/rgb/rgb_20250714_125959.png
new file mode 100644
index 0000000000000000000000000000000000000000..9d607fde0f38cfee44a70ac0eb0c65b1280b7345
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_125959.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:633ce48f7279fab5b83b2c2242bb26eeabd59687adb967494cbf751022c3b9ae
+size 2368581
diff --git a/outputs/rgb/rgb_20250714_130113.png b/outputs/rgb/rgb_20250714_130113.png
new file mode 100644
index 0000000000000000000000000000000000000000..3cd79e6353bed43332b9ff1e6453baff0f9c13d4
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_130113.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c28d8f294bf4f1a451c15c046b90778e8d05fbb0396f2471e26703617d994612
+size 2524941
diff --git a/outputs/rgb/rgb_20250714_130229.png b/outputs/rgb/rgb_20250714_130229.png
new file mode 100644
index 0000000000000000000000000000000000000000..c80245331ff1a6ccea100db1ec7d6a03d552cdb7
--- /dev/null
+++ b/outputs/rgb/rgb_20250714_130229.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8bd214bb04b24121038535f8a4decdc4cd86e0ca8954472b1ddf1dc74f400a6
+size 2256106
diff --git a/outputs/segmentation/segmentation_20250714_125057.png b/outputs/segmentation/segmentation_20250714_125057.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe33a96511679007d7192410f3026bd06b4d156d
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125057.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41e7d8f6c2ad61b5b5b9c20d8f681629954006024d590d6833f11a30003dee95
+size 620764
diff --git a/outputs/segmentation/segmentation_20250714_125314.png b/outputs/segmentation/segmentation_20250714_125314.png
new file mode 100644
index 0000000000000000000000000000000000000000..d0659271f4dbcdf13091c9642e73682aa08e0c6e
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125314.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46e3dac205435b3821a17c27f330593265bd7b6a070e0780dd31f53ccc48de92
+size 704737
diff --git a/outputs/segmentation/segmentation_20250714_125405.png b/outputs/segmentation/segmentation_20250714_125405.png
new file mode 100644
index 0000000000000000000000000000000000000000..e23b54307945e4a2de6e9b0f66d5bf421e19599a
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125405.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:add7021575b4768b2237a0828d8fba3ff412fb1db8117359f4be49fa8bf10b28
+size 699374
diff --git a/outputs/segmentation/segmentation_20250714_125456.png b/outputs/segmentation/segmentation_20250714_125456.png
new file mode 100644
index 0000000000000000000000000000000000000000..5974f78afb7f0110a9dedef68b54ca5bc7ccc58f
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125456.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae339a134a677666d23273eb50cb9403bceeb2c4b70703fecbeb3eb983fdb532
+size 612118
diff --git a/outputs/segmentation/segmentation_20250714_125608.png b/outputs/segmentation/segmentation_20250714_125608.png
new file mode 100644
index 0000000000000000000000000000000000000000..1337ef69d8c0912c96db3b7f1c4f7b2b5bca04f2
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125608.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a9578f22b2a2ce7adfe9d3ba9e716ee32bfcfd35f381fafbad8d5ecaa9d3f1f
+size 633454
diff --git a/outputs/segmentation/segmentation_20250714_125650.png b/outputs/segmentation/segmentation_20250714_125650.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe33a96511679007d7192410f3026bd06b4d156d
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125650.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41e7d8f6c2ad61b5b5b9c20d8f681629954006024d590d6833f11a30003dee95
+size 620764
diff --git a/outputs/segmentation/segmentation_20250714_125834.png b/outputs/segmentation/segmentation_20250714_125834.png
new file mode 100644
index 0000000000000000000000000000000000000000..f8cfa2a852f6da56f673e6ef115490f72822f2d8
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125834.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa3e30e63e36ed6746d2e63e9c129c5ca437af8bf1c7f51aab972f98935b76c9
+size 621134
diff --git a/outputs/segmentation/segmentation_20250714_125959.png b/outputs/segmentation/segmentation_20250714_125959.png
new file mode 100644
index 0000000000000000000000000000000000000000..77f6533ffd4caf8c78d1b5798b8efe34ada264a3
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_125959.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4bbc8af91c187a087830feeaa4c6c7df3afd75130d6bcc9af2c8bb08d29f6508
+size 662183
diff --git a/outputs/segmentation/segmentation_20250714_130113.png b/outputs/segmentation/segmentation_20250714_130113.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4e45ff04c1ffd5c7d6785d723cfebb1ef89ea92
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_130113.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9522f602b3f53cd482e6641deb2b1c3cfb1756bfe9ea7cdbc7f770736e717de
+size 693246
diff --git a/outputs/segmentation/segmentation_20250714_130229.png b/outputs/segmentation/segmentation_20250714_130229.png
new file mode 100644
index 0000000000000000000000000000000000000000..625007c9b701cd4c00bd5ea372c39fb8d0b52e22
--- /dev/null
+++ b/outputs/segmentation/segmentation_20250714_130229.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddce41a9080116b8622eefb127a8fd0c20a2d28e7f8022349b18d2fff762ed38
+size 620482
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..884028bbfeb11b08e8a1dd6d6f3154d66d1fb09d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,40 @@
+# Core ML and Deep Learning
+torch>=2.0.0
+torchvision>=0.15.0
+transformers>=4.21.0
+
+# Computer Vision and Image Processing
+opencv-python>=4.5.0
+pillow>=8.0.0
+scikit-image>=0.19.0
+
+# Scientific Computing
+numpy>=1.21.0
+scipy>=1.7.0
+
+# Visualization and Plotting
+matplotlib>=3.5.0
+plotly>=5.0.0
+seaborn>=0.11.0
+
+# 3D Processing
+open3d>=0.15.0
+
+# Web Interface
+gradio>=3.50.2
+fastapi>=0.68.0
+uvicorn>=0.15.0
+
+# Utilities
+requests>=2.25.0
+tqdm>=4.62.0
+pyyaml>=5.4.0
+
+# Optional: For better performance
+albumentations>=1.0.0
+timm>=0.6.0
+
+# For DepthAnythingV2 (these are installed when setting up the cloned repo)
+# The following are typically included in DepthAnythingV2's setup.py
+einops
+xformers
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9629b935ae69a8cbbb3326d7d22bb07ad9e416b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+"""
+Setup script for the Semantic Segmentation Gradio App
+This script helps install dependencies and set up the environment
+"""
+
+import subprocess
+import sys
+import os
+from pathlib import Path
+
+def run_command(command, description):
+ """Run a command and handle errors."""
+ print(f"\nš {description}...")
+ try:
+ result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
+ print(f"ā
{description} completed successfully")
+ return True
+ except subprocess.CalledProcessError as e:
+ print(f"ā Error during {description}:")
+ print(f"Command: {command}")
+ print(f"Error: {e.stderr}")
+ return False
+
+def check_python_version():
+ """Check if Python version is compatible."""
+ version = sys.version_info
+ if version.major < 3 or (version.major == 3 and version.minor < 8):
+ print("ā Python 3.8 or higher is required")
+ sys.exit(1)
+ print(f"ā
Python {version.major}.{version.minor}.{version.micro} detected")
+
+def install_dependencies():
+ """Install required dependencies."""
+ requirements = [
+ "gradio>=4.0.0",
+ "torch>=1.9.0",
+ "torchvision>=0.10.0",
+ "transformers>=4.21.0",
+ "pillow>=8.0.0",
+ "numpy>=1.21.0",
+ "matplotlib>=3.5.0",
+ "requests>=2.25.0",
+ ]
+
+ print("\nš¦ Installing dependencies...")
+ for req in requirements:
+ if not run_command(f"pip install {req}", f"Installing {req.split('>=')[0]}"):
+ return False
+ return True
+
+def create_directory_structure():
+ """Create necessary directories."""
+ directories = [
+ "src",
+ "src/models",
+ "sample_images",
+ "outputs"
+ ]
+
+ for directory in directories:
+ Path(directory).mkdir(parents=True, exist_ok=True)
+ print(f"š Created directory: {directory}")
+
+def download_sample_images():
+ """Download some sample images for testing."""
+ import requests
+ from PIL import Image
+ import io
+
+ sample_urls = {
+ "street_scene_1.jpg": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=800",
+ "street_scene_2.jpg": "https://images.unsplash.com/photo-1502920917128-1aa500764cbd?w=800",
+ "urban_road.jpg": "https://images.unsplash.com/photo-1516738901171-8eb4fc13bd20?w=800",
+ }
+
+ sample_dir = Path("sample_images")
+ sample_dir.mkdir(exist_ok=True)
+
+ print("\nš¼ļø Downloading sample images...")
+ for filename, url in sample_urls.items():
+ try:
+ response = requests.get(url, timeout=30)
+ response.raise_for_status()
+
+ image = Image.open(io.BytesIO(response.content))
+ image_path = sample_dir / filename
+ image.save(image_path)
+ print(f"ā
Downloaded: {filename}")
+
+ except Exception as e:
+ print(f"ā ļø Failed to download {filename}: {e}")
+
+def create_launch_script():
+ """Create a simple launch script."""
+ launch_script = '''#!/usr/bin/env python3
+"""
+Launch script for the Semantic Segmentation App
+"""
+
+import sys
+import os
+
+# Add the current directory to the path
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+# Import and run the app
+try:
+ from complete_gradio_app import create_gradio_interface
+ import torch
+
+ print("š Starting Semantic Segmentation App...")
+ print("š» Device:", "CUDA" if torch.cuda.is_available() else "CPU")
+
+ demo = create_gradio_interface()
+ demo.launch(
+ share=True,
+ debug=True,
+ server_name="0.0.0.0",
+ server_port=7860
+ )
+
+except ImportError as e:
+ print(f"ā Import error: {e}")
+ print("Please make sure all dependencies are installed by running: python setup.py")
+
+except Exception as e:
+ print(f"ā Error starting app: {e}")
+'''
+
+ with open("launch_app.py", "w") as f:
+ f.write(launch_script)
+
+ # Make it executable on Unix systems
+ if os.name != 'nt':
+ os.chmod("launch_app.py", 0o755)
+
+ print("ā
Created launch script: launch_app.py")
+
+def create_readme():
+ """Create a README file with usage instructions."""
+ readme_content = '''# Semantic Segmentation Gradio App
+
+A user-friendly web interface for semantic segmentation using OneFormer and Mask2Former models.
+
+## Quick Start
+
+1. **Install dependencies:**
+ ```bash
+ python setup.py
+ ```
+
+2. **Launch the app:**
+ ```bash
+ python launch_app.py
+ ```
+
+ Or run directly:
+ ```bash
+ python complete_gradio_app.py
+ ```
+
+3. **Open your browser** and go to the provided local URL (usually http://localhost:7860)
+
+## Requirements
+
+- Python 3.8+
+- CUDA-compatible GPU (optional, but recommended)
+- At least 8GB RAM
+- Internet connection (for model downloads)
+
+## Features
+
+- **Two State-of-the-Art Models:**
+ - OneFormer: Universal segmentation (semantic, instance, panoptic)
+ - Mask2Former: High-accuracy semantic segmentation
+
+- **User-Friendly Interface:**
+ - Upload custom images
+ - Select from sample images
+ - Adjustable overlay transparency
+ - Real-time processing
+
+- **Professional Output:**
+ - Colored segmentation overlays
+ - Detailed class statistics
+ - High-quality visualizations
+
+## Troubleshooting
+
+### Common Issues:
+
+1. **CUDA out of memory:**
+ - Reduce image size
+ - Use CPU instead of GPU
+
+2. **Model download fails:**
+ - Check internet connection
+ - Try again (models are large ~1-2GB each)
+
+3. **ImportError:**
+ - Run `python setup.py` again
+ - Check Python version (3.8+ required)
+
+### Performance Tips:
+
+- First model load takes time (downloading from HuggingFace)
+- GPU acceleration significantly speeds up processing
+- Images are automatically resized to prevent memory issues
+
+## Supported Classes
+
+The models are trained on Cityscapes dataset and can recognize:
+- Road, sidewalk, building, wall, fence
+- Traffic light, traffic sign, pole
+- Vegetation, terrain, sky
+- Person, rider, car, truck, bus, train, motorcycle, bicycle
+
+## Color Coding
+
+Each class is visualized with a specific color following Cityscapes conventions:
+- Road: Dark purple
+- Sky: Steel blue
+- Person: Crimson
+- Car: Dark blue
+- Vegetation: Olive green
+- And more...
+
+## License
+
+This project uses pre-trained models from HuggingFace:
+- OneFormer: [Model License](https://huggingface.co/shi-labs/oneformer_cityscapes_swin_large)
+- Mask2Former: [Model License](https://huggingface.co/facebook/mask2former-swin-large-cityscapes-semantic)
+
+## Contributing
+
+Feel free to submit issues and enhancement requests!
+'''
+
+ with open("README.md", "w") as f:
+ f.write(readme_content)
+
+ print("ā
Created README.md")
+
+def main():
+ """Main setup function."""
+ print("šÆ Semantic Segmentation App Setup")
+ print("=" * 50)
+
+ # Check Python version
+ check_python_version()
+
+ # Create directory structure
+ create_directory_structure()
+
+ # Install dependencies
+ if not install_dependencies():
+ print("\nā Failed to install some dependencies. Please check the errors above.")
+ return False
+
+ # Download sample images
+ try:
+ download_sample_images()
+ except Exception as e:
+ print(f"ā ļø Warning: Could not download sample images: {e}")
+
+ # Create launch script
+ create_launch_script()
+
+ # Create README
+ create_readme()
+
+ print("\n" + "=" * 50)
+ print("ā
Setup completed successfully!")
+ print("\nš To launch the app, run:")
+ print(" python launch_app.py")
+ print("\nš For more information, see README.md")
+
+ return True
+
+if __name__ == "__main__":
+ success = main()
+ sys.exit(0 if success else 1)
\ No newline at end of file