import os
os.environ.setdefault("OPENCV_AVFOUNDATION_SKIP_AUTH", "1")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import warnings
warnings.filterwarnings("ignore")

import re
import io
import json
import tempfile
import time
import threading
import uuid
import importlib.util
from pathlib import Path
from collections import deque
from typing import Optional, Tuple, Dict, List

import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import signal
from scipy.signal import find_peaks, welch, get_window

# .mat support (SCAMPS)
try:
    from scipy.io import loadmat
    MAT_SUPPORT = True
except Exception:
    MAT_SUPPORT = False

# Matplotlib headless backend
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import gradio as gr

class PhysMambaattention_viz:
    """Simplified Grad-CAM for PhysMamba."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.activations = None
        self.gradients = None
        self.hook_handles = []

    def _get_target_layer(self):
        """Auto-detect the best layer for visualization, excluding Mamba/SSM layers."""
        # Strategy 1: Look for temporal difference convolution
        for name, module in reversed(list(self.model.named_modules())):
            if ('tdc' in name.lower() or 'temporal_diff' in name.lower()) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
                if isinstance(module, (nn.Conv2d, nn.Conv3d)):
                    print(f"[attention_viz] Selected TDC layer: {name} ({type(module).__name__})")
                    return name, module
        # Strategy 2: Look for 3D convolutions (temporal-spatial)
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, nn.Conv3d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
                print(f"[attention_viz] Selected Conv3d layer: {name}")
                return name, module
        # Strategy 3: Look for 2D convolutions (spatial)
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, nn.Conv2d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
                print(f"[attention_viz] Selected Conv2d layer: {name}")
                return name, module
        # Strategy 4: Look for any convolution (last resort)
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                print(f"[attention_viz] Selected Conv layer (fallback): {name} ({type(module).__name__})")
                return name, module
        # Strategy 5: Print all available layers and pick the last non-Mamba one
        print("\n[attention_viz] Available layers:")
        suitable_layers = []
        for name, module in self.model.named_modules():
            layer_type = type(module).__name__
            if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
                is_mamba = 'mamba' in name.lower() or 'ssm' in name.lower()
                print(f"  - {name}: {layer_type} {'(MAMBA - SKIP)' if is_mamba else ''}")
                if not is_mamba:
                    suitable_layers.append((name, module))
        if suitable_layers:
            name, module = suitable_layers[-1]
            print(f"\n[attention_viz] Selected last suitable layer: {name} ({type(module).__name__})")
            return name, module
        raise ValueError("No suitable layer found for Grad-CAM (all layers are Mamba/SSM)")

    def _forward_hook(self, module, input, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def register_hooks(self, target_layer_name: Optional[str] = None):
        """Register forward/backward hooks on the target layer."""
        self._remove_hooks()
        if target_layer_name and target_layer_name.strip():
            # Manual selection
            try:
                target_module = dict(self.model.named_modules())[target_layer_name.strip()]
                print(f"Using manually specified layer: {target_layer_name}")
            except KeyError:
                print(f"⚠ Layer '{target_layer_name}' not found, falling back to auto-detection")
                target_layer_name, target_module = self._get_target_layer()
        else:
            # Auto-detection
            target_layer_name, target_module = self._get_target_layer()
        fwd_handle = target_module.register_forward_hook(self._forward_hook)
        bwd_handle = target_module.register_full_backward_hook(self._backward_hook)
        self.hook_handles = [fwd_handle, bwd_handle]

    def _remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []
        self.activations = None
        self.gradients = None

    def generate(self, input_tensor: torch.Tensor) -> np.ndarray:
        """Generate a Grad-CAM heatmap with improved gradient handling."""
        # Set model to eval but keep gradient computation
        self.model.eval()
        # Ensure the input tensor requires grad
        input_tensor = input_tensor.requires_grad_(True)
        # Forward pass
        output = self.model(input_tensor)
        # Handle different output types
        if isinstance(output, dict):
            # Try common keys
            for key in ('pred', 'output', 'bvp', 'logits', 'out'):
                if key in output and isinstance(output[key], torch.Tensor):
                    output = output[key]
                    break
        # Get a scalar for backward
        if output.numel() == 1:
            target = output
        else:
            # Use the mean of the output
            target = output.mean()
        print(f"[attention_viz] Output shape: {output.shape}, Target: {target.item():.4f}")
        # Backward pass
        self.model.zero_grad()
        target.backward(retain_graph=False)
        # Check that the hooks captured something
        if self.activations is None:
            print("⚠ No activations captured!")
            return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
        if self.gradients is None:
            print("⚠ No gradients captured!")
            return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
        print(f"[attention_viz] Activations shape: {self.activations.shape}")
        print(f"[attention_viz] Gradients shape: {self.gradients.shape}")
        activations = self.activations
        gradients = self.gradients
        # Compute weights (global average pooling of gradients)
        if gradients.dim() == 5:  # [B, C, T, H, W]
            weights = gradients.mean(dim=[2, 3, 4], keepdim=True)
        elif gradients.dim() == 4:  # [B, C, H, W]
            weights = gradients.mean(dim=[2, 3], keepdim=True)
        elif gradients.dim() == 3:  # [B, C, T]
            weights = gradients.mean(dim=2, keepdim=True).unsqueeze(-1).unsqueeze(-1)
        else:
            print(f"⚠ Unexpected gradient dimensions: {gradients.dim()}")
            return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
        # Weighted combination
        cam = (weights * activations).sum(dim=1, keepdim=True)
        # If 5D, average over time
        if cam.dim() == 5:
            cam = cam.mean(dim=2)
        # ReLU
        cam = F.relu(cam)
        # Convert to numpy
        cam = cam.squeeze().cpu().detach().numpy()
        # Normalize
        if cam.max() > cam.min():
            cam = (cam - cam.min()) / (cam.max() - cam.min())
        print(f"[attention_viz] Final heatmap shape: {cam.shape}, range: [{cam.min():.3f}, {cam.max():.3f}]")
        return cam

    def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
        """Overlay the heatmap on a frame."""
        h, w = frame.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (w, h))
        heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
        return overlay

    def cleanup(self):
        self._remove_hooks()

def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
    """Apply DiffNormalized preprocessing from the PhysMamba paper."""
    if not frames:
        return np.zeros((0,), dtype=np.float32)  # guard: empty input
    if len(frames) < 2:
        return np.zeros((len(frames), *frames[0].shape), dtype=np.float32)
    diff_frames = []
    for i in range(len(frames)):
        if i == 0:
            diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
        else:
            curr = frames[i].astype(np.float32)
            prev = frames[i - 1].astype(np.float32)
            denominator = curr + prev + 1e-8
            diff = (curr - prev) / denominator
            diff_frames.append(diff)
    diff_array = np.stack(diff_frames)
    std = diff_array.std()
    if std > 0:
        diff_array = diff_array / std
    return diff_array
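
# Illustrative sketch (not used by the app): the DiffNormalized rule maps a
# constant-brightness clip to zeros, so only frame-to-frame intensity changes
# (e.g., the blood-volume pulse) survive preprocessing.
def _example_diff_normalized_on_constant_clip() -> bool:
    """Return True if a constant clip yields an (almost) all-zero output."""
    demo = [np.full((4, 4, 3), 128, dtype=np.uint8) for _ in range(3)]
    out = apply_diff_normalized(demo)
    return bool(np.allclose(out, 0.0, atol=1e-6))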

def preprocess_for_physmamba(frames: List[np.ndarray],
                             target_frames: int = 128,
                             target_size: int = 128) -> torch.Tensor:
    """Complete preprocessing pipeline for the PhysMamba model."""
    if len(frames) < target_frames:
        frames = frames + [frames[-1]] * (target_frames - len(frames))
    elif len(frames) > target_frames:
        indices = np.linspace(0, len(frames) - 1, target_frames).astype(int)
        frames = [frames[i] for i in indices]
    frames_rgb = [f[..., ::-1].copy() for f in frames]  # BGR -> RGB
    frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
    frames_diff = apply_diff_normalized(frames_resized)          # (T, H, W, C)
    frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2))  # (C, T, H, W)
    frames_batched = np.expand_dims(frames_transposed, axis=0)   # (1, C, T, H, W)
    tensor = torch.from_numpy(frames_batched.astype(np.float32))
    return tensor
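
# Illustrative sketch (not used by the app): the pipeline above emits a
# [B, C, T, H, W] float tensor. The frame count (16) is arbitrary here; the
# function pads or subsamples to target_frames.
def _example_preprocess_shape() -> torch.Size:
    frames = [np.zeros((64, 48, 3), dtype=np.uint8) for _ in range(16)]
    clip = preprocess_for_physmamba(frames, target_frames=128, target_size=128)
    assert clip.shape == (1, 3, 128, 128, 128)
    return clip.shape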

HERE = Path(__file__).parent
MODEL_DIR = HERE / "final_model_release"
LOG_DIR = HERE / "logs"
ANALYSIS_DIR = HERE / "analysis"
for d in [MODEL_DIR, LOG_DIR, ANALYSIS_DIR]:
    d.mkdir(exist_ok=True)

DEVICE = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)

FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")

DEFAULT_SIZE = 128   # input H = W to the model
DEFAULT_T = 128      # clip length
DEFAULT_STRIDE = 8   # inference hop
DISPLAY_FPS = 10
VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']

_HR_SMOOTH = None
REST_HR_TARGET = 72.0
REST_HR_RANGE = (55.0, 95.0)
MAX_JUMP_BPM = 8.0

# Recognized GT files for subject folders
GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
GT_EXTS = {".txt", ".csv", ".json"}

def _as_path(maybe) -> Optional[str]:
    """Return a filesystem path from Gradio values (str, dict, tempfile objects)."""
    if maybe is None:
        return None
    if isinstance(maybe, str):
        return maybe
    if isinstance(maybe, dict):
        return maybe.get("name") or maybe.get("path")
    name = getattr(maybe, "name", None)  # tempfile-like object
    if isinstance(name, str):
        return name
    try:
        return str(maybe)
    except Exception:
        return None


def _import_from_file(py_path: Path):
    spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
    if not spec or not spec.loader:
        raise ImportError(f"Cannot import module from {py_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

def _looks_like_video(p: Path) -> bool:
    if p.suffix.lower() == ".mat":
        return MAT_SUPPORT  # .mat clips are only readable when scipy.io imported
    return p.suffix.lower() in VIDEO_EXTENSIONS

class SimpleActivationAttention:
    """Lightweight attention visualization without gradients."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.activations = None
        self.hook_handle = None

    def _activation_hook(self, module, input, output):
        """Capture activations during the forward pass."""
        self.activations = output.detach()

    def register_hook(self):
        """Register a hook on a suitable layer."""
        # Find the last convolutional layer before Mamba
        target = None
        target_name = None
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
                target = module
                target_name = name
        if target is None:
            print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
            return
        self.hook_handle = target.register_forward_hook(self._activation_hook)
        print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")

    def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
        """Generate an attention map from activations (call after a forward pass)."""
        try:
            if self.activations is None:
                return None
            # Process activations to create spatial attention
            act = self.activations
            # Handle different tensor shapes
            if act.dim() == 5:  # [B, C, T, H, W]
                # Average over time and channels
                attention = act.mean(dim=[1, 2])  # -> [B, H, W]
            elif act.dim() == 4:  # [B, C, H, W]
                attention = act.mean(dim=1)  # -> [B, H, W]
            else:
                print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
                return None
            # Convert to numpy
            attention = attention.squeeze().cpu().numpy()
            # Normalize to [0, 1]
            if attention.max() > attention.min():
                attention = (attention - attention.min()) / (attention.max() - attention.min())
            return attention
        except Exception as e:
            print(f"⚠ [attention_viz] Generation failed: {e}")
            return None

    def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
        """Overlay the heatmap on a frame."""
        h, w = frame.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (w, h))
        heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
        return overlay

    def cleanup(self):
        if self.hook_handle is not None:
            self.hook_handle.remove()
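
# Illustrative sketch (not used by the app, assuming a model with at least one
# conv layer): hook once, run a normal forward pass, then read the captured map.
def _example_activation_attention(model: nn.Module, clip: torch.Tensor) -> Optional[np.ndarray]:
    viz = SimpleActivationAttention(model, DEVICE)
    viz.register_hook()
    with torch.no_grad():
        _ = model(clip)  # the hook stores activations as a side effect
    heat = viz.generate(clip)  # normalized [0, 1] spatial map, or None
    viz.cleanup()
    return heat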

class VideoReader:
    """
    Unified frame reader:
      • Regular videos via cv2.VideoCapture
      • .mat 'videos' (e.g., SCAMPS): expects an array like (T,H,W,3), (H,W,T,3) or (T,H,W)
    Returns frames as BGR uint8.
    """

    def __init__(self, path: str):
        self.path = str(path)
        self._cap = None
        self._mat = None
        self._idx = 0
        self._len = 0
        self._shape = None
        if self.path.lower().endswith(".mat") and MAT_SUPPORT:
            self._open_mat(self.path)
        else:
            self._open_cv(self.path)

    def _open_cv(self, path: str):
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise RuntimeError("Cannot open video")
        self._cap = cap
        self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)

    def _open_mat(self, path: str):
        try:
            md = loadmat(path)
            # Common keys in SCAMPS-like dumps
            for key in ("video", "frames", "vid", "data"):
                if key in md and isinstance(md[key], np.ndarray):
                    arr = md[key]
                    break
            else:
                arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
            if arr is None:
                raise RuntimeError("No ndarray found in .mat")
            a = np.asarray(arr)
            # Normalize to (T,H,W,3)
            if a.ndim == 4:
                if a.shape[-1] == 3:
                    if a.shape[0] < a.shape[2]:  # (T,H,W,3) heuristic
                        v = a
                    else:  # (H,W,T,3) -> (T,H,W,3)
                        v = np.transpose(a, (2, 0, 1, 3))
                else:
                    v = a[..., :1]  # take first channel
            elif a.ndim == 3:
                if a.shape[0] < a.shape[2]:  # (T,H,W)
                    v = a
                else:  # (H,W,T) -> (T,H,W)
                    v = np.transpose(a, (2, 0, 1))
                v = v[..., None]
            else:
                raise RuntimeError(f"Unsupported .mat video shape: {a.shape}")
            v = np.ascontiguousarray(v)
            if v.shape[-1] == 1:
                v = np.repeat(v, 3, axis=-1)
            # Assumption: SCAMPS-style dumps often store floats in [0, 1];
            # scale before the uint8 cast so frames do not come out black.
            if np.issubdtype(v.dtype, np.floating) and v.size and float(v.max()) <= 1.0:
                v = v * 255.0
            v = np.clip(v, 0, 255).astype(np.uint8)
            self._mat = v
            self._len = v.shape[0]
            self._shape = v.shape[1:3]
        except Exception as e:
            raise RuntimeError(f"Failed to open .mat video: {e}")

    def read(self):
        """Return (ret, frame[BGR]) like cv2.VideoCapture.read()."""
        if self._mat is not None:
            if self._idx >= self._len:
                return False, None
            frame = self._mat[self._idx]
            self._idx += 1
            return True, frame
        else:
            return self._cap.read()

    def fps(self, fallback: int = 30) -> int:
        if self._mat is not None:
            return fallback  # .mat typically lacks FPS; caller can override
        f = self._cap.get(cv2.CAP_PROP_FPS)
        return int(f) if f and f > 0 else fallback

    def length(self) -> int:
        return self._len

    def release(self):
        if self._cap is not None:
            self._cap.release()
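
# Illustrative sketch (not used by the app): VideoReader mirrors the familiar
# cv2.VideoCapture loop for both container videos and SCAMPS-style .mat files.
def _example_count_frames(path: str) -> int:
    reader = VideoReader(path)
    n = 0
    while True:
        ret, _frame = reader.read()
        if not ret:
            break
        n += 1
    reader.release()
    return n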

def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
    x, y, w, h = face
    # forehead
    fh = frame[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)]
    # cheeks
    ck = frame[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)]
    # full face
    ff = frame[y:y + h, x:x + w]
    return {"forehead": fh, "cheeks": ck, "face": ff}


def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
    if patch is None or patch.size == 0:
        return -1e9
    g = patch[..., 1].astype(np.float32) / 255.0  # green channel
    g = cv2.resize(g, (64, 64)).mean(axis=1)      # crude spatial pooling
    g = g - g.mean()
    b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
    try:
        y = signal.filtfilt(b, a, g, method="gust")
    except Exception:
        y = g
    return float((y ** 2).mean())

def pick_auto_roi(face: Tuple[int, int, int, int],
                  frame: np.ndarray,
                  attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
    """Simple ROI selection."""
    cands = roi_candidates(face, frame)
    scores = {k: roi_quality_score(v) for k, v in cands.items()}
    if attn is not None and attn.size:
        H, W = frame.shape[:2]
        try:
            attn_resized = cv2.resize(attn, (W, H))
            x, y, w, h = face
            fh_attn = attn_resized[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)].mean() if attn_resized.size > 0 else 0.0
            ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
            ff_attn = attn_resized[y:y + h, x:x + w].mean() if attn_resized.size > 0 else 0.0
            scores['forehead'] += fh_attn * 0.2
            scores['cheeks'] += ck_attn * 0.2
            scores['face'] += ff_attn * 0.2
        except Exception:
            pass
    best = max(scores, key=scores.get)
    return cands[best], best

def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
    """
    Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
    Heuristics:
      - Subject folder: any directory containing at least one video-like file (.mat or common video).
      - If multiple videos, pick the largest by size.
      - GT file: prefer known names in GT_FILENAMES, else any .txt/.csv/.json not named readme.
    """
    pairs: List[Tuple[str, Optional[str]]] = []
    if not root_dir.exists():
        return pairs

    def pick_pair(folder: Path) -> Optional[Tuple[str, Optional[str]]]:
        vids = [p for p in folder.rglob("*") if p.is_file() and _looks_like_video(p)]
        if not vids:
            return None
        vids.sort(key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
        video = vids[0]
        gt: Optional[Path] = None
        for p in folder.rglob("*"):
            if p.is_file() and p.name.lower() in GT_FILENAMES:
                gt = p
                break
        if gt is None:
            cands = [
                p for p in folder.rglob("*")
                if p.is_file() and p.suffix.lower() in GT_EXTS and "readme" not in p.name.lower()
            ]
            if cands:
                gt = cands[0]
        return str(video), (str(gt) if gt else None)

    subs = [d for d in root_dir.iterdir() if d.is_dir()]
    if subs:
        for sub in subs:
            pair = pick_pair(sub)
            if pair:
                pairs.append(pair)
    else:
        pair = pick_pair(root_dir)  # the root itself might be a single-subject folder
        if pair:
            pairs.append(pair)
    # Deduplicate
    seen = set()
    uniq: List[Tuple[str, Optional[str]]] = []
    for v, g in pairs:
        key = (v, g or "")
        if key not in seen:
            seen.add(key)
            uniq.append((v, g))
    return uniq
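
# Directory layout sketch (hypothetical names, matching the heuristics above):
#
#   dataset_root/
#     subject01/
#       video.avi            <- largest video wins if several exist
#       ground_truth.txt     <- matched by name via GT_FILENAMES
#     subject02/
#       clip.mat
#       gt.csv               <- fallback: first .txt/.csv/.json not named "readme"
#
# discover_subjects(Path("dataset_root")) would then return
# [(".../subject01/video.avi", ".../subject01/ground_truth.txt"),
#  (".../subject02/clip.mat", ".../subject02/gt.csv")].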

def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
    import inspect
    if model_file:
        model_path = (repo_root / model_file).resolve()
        if model_path.exists():
            try:
                mod = _import_from_file(model_path)
                if hasattr(mod, model_class):
                    return getattr(mod, model_class)
            except Exception:
                pass
    search_dirs = [
        repo_root / "neural_methods" / "model",
        repo_root / "neural_methods",
        repo_root
    ]
    name_pattern = re.compile(r"mamba", re.IGNORECASE)
    for base_dir in search_dirs:
        if not base_dir.exists():
            continue
        for py_file in base_dir.rglob("*.py"):
            if "__pycache__" in str(py_file) or "mamba_ssm" in str(py_file):
                continue
            try:
                mod = _import_from_file(py_file)
                for name, obj in inspect.getmembers(mod):
                    if callable(obj) and name_pattern.search(name) and "ssm" not in name.lower():
                        if inspect.isclass(obj) and issubclass(obj, nn.Module):
                            return obj
            except Exception:
                continue
    raise ImportError("Could not find PhysMamba model class")

def load_physmamba_model(ckpt_path: Path, device: torch.device,
                         model_file: str = "", model_class: str = "PhysMamba"):
    repo_root = Path(".").resolve()
    Builder = find_physmamba_builder(repo_root, model_file, model_class)
    ctor_trials = [
        {},
        {"d_model": 96},
        {"dim": 96},
        {"d_model": 96, "frames": 128, "img_size": 128, "in_chans": 3},
        {"frames": 128, "img_size": 128, "in_chans": 3},
        {"frame_depth": 3},
    ]
    model = None
    for kwargs in ctor_trials:
        try:
            candidate = Builder(**kwargs)
            if isinstance(candidate, nn.Module):
                model = candidate
                break
        except Exception:
            continue
    if model is None:
        raise RuntimeError("Could not construct PhysMamba model")
    try:
        checkpoint = torch.load(str(ckpt_path), map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        try:
            model.load_state_dict(state_dict, strict=False)
        except Exception:
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
    except Exception:
        pass
    model.to(device).eval()
    try:
        with torch.no_grad():
            _ = model(torch.zeros(1, 3, 8, 128, 128, device=device))  # smoke test
    except Exception:
        pass
    # Disable attention visualization since the model forward pass is incompatible
    attention_viz = None
    return model, attention_viz

def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
    """
    Stable band-pass with edge-safety and parameter clipping.
    """
    x = np.asarray(x, dtype=float)
    n = int(fs * 2)
    if x.size < max(n, 8):  # need at least ~2 s
        return x
    nyq = 0.5 * fs
    lo = max(low / nyq, 1e-6)
    hi = min(high / nyq, 0.999_999)
    if not (0.0 < lo < hi < 1.0):
        return x
    try:
        b, a = signal.butter(order, [lo, hi], btype="band")
        # padlen must be < len(x); reduce when short
        padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
        return signal.filtfilt(b, a, x, padlen=padlen)
    except Exception:
        return x
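
# Illustrative sketch (not used by the app): a 1.2 Hz tone (72 BPM) sits inside
# the 0.7-3.5 Hz passband and should survive, while slow drift is attenuated.
def _example_bandpass_preserves_pulse(fs: int = 30) -> float:
    t = np.arange(0, 10, 1.0 / fs)
    pulse = np.sin(2 * np.pi * 1.2 * t)
    drift = 0.5 * np.sin(2 * np.pi * 0.1 * t)
    y = bandpass_filter(pulse + drift, fs=fs)
    return float(np.corrcoef(y, pulse)[0, 1])  # expected close to 1.0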

def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
    """
    HR (BPM) via the Welch PSD peak in [lo, hi] Hz.
    """
    x = np.asarray(x, dtype=float)
    if x.size < int(fs * 4.0):  # need ~4 s for a usable PSD
        return 0.0
    try:
        # nperseg tuned for short windows while avoiding tiny segments
        nper = int(min(max(64, fs * 2), min(512, x.size)))
        f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
        mask = (f >= lo) & (f <= hi)
        if not np.any(mask):
            return 0.0
        f_band = f[mask]
        p_band = pxx[mask]
        if p_band.size == 0 or np.all(~np.isfinite(p_band)):
            return 0.0
        fpk = float(f_band[np.argmax(p_band)])
        bpm = fpk * 60.0
        # clip to a plausible range
        return float(np.clip(bpm, 30.0, 220.0))
    except Exception:
        return 0.0
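
# Illustrative sketch (not used by the app): with fs=32 the Welch bin spacing
# here is fs/nperseg = 0.5 Hz, so a 1.0 Hz tone sits exactly on a bin and the
# estimate comes back as 1.0 Hz * 60 = 60 BPM.
def _example_hr_from_tone(fs: int = 32) -> float:
    t = np.arange(0, 20, 1.0 / fs)
    return hr_from_welch(np.sin(2 * np.pi * 1.0 * t), fs=fs)  # expected ≈ 60.0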

def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
    """
    HRV RMSSD from peaks; robust to short/flat segments.
    """
    x = np.asarray(x, dtype=float)
    if x.size < int(fs * 5.0):
        return 0.0
    try:
        # peak distance ~0.5 s minimum (avoid double counting)
        peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
        if len(peaks) < 3:
            return 0.0
        rr = np.diff(peaks) / fs * 1000.0  # RR intervals in ms
        if rr.size < 2:
            return 0.0
        return float(np.sqrt(np.mean(np.diff(rr) ** 2)))
    except Exception:
        return 0.0
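
# Illustrative sketch (not used by the app): a metronome-regular 60 BPM pulse
# has identical RR intervals, so its RMSSD should be ~0 ms.
def _example_rmssd_regular_pulse(fs: int = 30) -> float:
    t = np.arange(0, 20, 1.0 / fs)
    return compute_rmssd(np.sin(2 * np.pi * 1.0 * t), fs=fs)  # expected ≈ 0.0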

def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
    """
    Filter BVP to the HR band and return a smoothed HR (BPM) with a gentle
    pull toward the resting band. Signature unchanged to avoid breaking callers.
    """
    global _HR_SMOOTH
    y = np.asarray(pred, dtype=float)
    if y.size == 0:
        return y, 0.0
    # 1) band-limit
    y_filt = bandpass_filter(y, fs=fs, low=0.7, high=3.5, order=4)
    # 2) HR estimate
    hr = hr_from_welch(y_filt, fs=fs, lo=0.7, hi=3.5)
    # 3) gentle attraction to the resting band (if way off)
    if hr > 0:
        lo, hi = REST_HR_RANGE
        if hr < lo or hr > hi:
            dist = abs(hr - REST_HR_TARGET)
            # farther away -> stronger pull
            alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
            hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
    # 4) temporal smoothing to limit frame-to-frame jumps
    if hr > 0:
        if _HR_SMOOTH is None:
            _HR_SMOOTH = hr
        else:
            step = float(np.clip(hr - _HR_SMOOTH, -MAX_JUMP_BPM, MAX_JUMP_BPM))
            _HR_SMOOTH = _HR_SMOOTH + 0.6 * step
        hr = float(_HR_SMOOTH)
    return y_filt, float(hr)

def draw_face_and_roi(frame_bgr: np.ndarray,
                      face_bbox: Optional[Tuple[int, int, int, int]],
                      roi_bbox: Optional[Tuple[int, int, int, int]],
                      label: str = "ROI") -> np.ndarray:
    """
    Draw face (green) and ROI (cyan) rectangles on a copy of the frame.
    """
    vis = frame_bgr.copy()
    if face_bbox is not None:
        x, y, w, h = face_bbox
        cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 230, 0), 2)
        cv2.putText(vis, "FACE", (x, max(20, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 230, 0), 2)
    if roi_bbox is not None:
        rx, ry, rw, rh = roi_bbox
        cv2.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (255, 220, 0), 2)
        cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
    return vis


def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
                       roi_type: str,
                       frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
    """
    Compute the ROI rectangle (x, y, w, h) from a face bbox and the ROI rule.
    Matches the crop_roi geometry.
    """
    x, y, w, h = face_bbox
    H, W = frame_shape[:2]
    if roi_type == "forehead":
        rx = int(x + 0.25 * w); rw = int(0.50 * w)
        ry = int(y + 0.10 * h); rh = int(0.20 * h)
    elif roi_type == "cheeks":
        rx = int(x + 0.15 * w); rw = int(0.70 * w)
        ry = int(y + 0.55 * h); rh = int(0.30 * h)
    else:
        rx, ry, rw, rh = x, y, w, h
    rx2 = min(W, rx + rw)
    ry2 = min(H, ry + rh)
    rx = max(0, rx); ry = max(0, ry)
    if rx2 <= rx or ry2 <= ry:
        return (0, 0, 0, 0)
    return (rx, ry, rx2 - rx, ry2 - ry)
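
# Illustrative sketch (not used by the app): for a face box (x=100, y=100,
# w=200, h=200) in a 640x480 frame, the "forehead" rule yields (150, 120,
# 100, 40): half the face width over the 10-30% height band.
def _example_forehead_bbox() -> Tuple[int, int, int, int]:
    return roi_bbox_from_face((100, 100, 200, 200), "forehead", (480, 640, 3))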

def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
    """
    Visualize the model input (C, H, W, normalized). Returns an HxWx3 uint8 BGR image.
    Assumes chw = (3, H, W) with zero-mean, unit-variance normalization per frame.
    """
    if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
        return np.zeros((128, 128, 3), dtype=np.uint8)
    # Undo channel-first layout & normalization to get a viewable image
    img = chw.copy()
    # Re-normalize to 0..1 by min-max of the tensor to "show" contrast
    vmin, vmax = float(img.min()), float(img.max())
    if vmax <= vmin + 1e-6:
        img = np.zeros_like(img)
    else:
        img = (img - vmin) / (vmax - vmin)
    img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8)  # RGB -> BGR
    return img


def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
    if gt_len <= 1:
        return None
    if gt_fs and gt_fs > 0:
        return np.arange(gt_len, dtype=float) / float(gt_fs)
    return None  # will fall back to length-matching overlay

def plot_signals_with_gt(time_axis: np.ndarray,
                         raw_signal: np.ndarray,
                         post_signal: np.ndarray,
                         fs: int,
                         out_path: str,
                         gt_time: Optional[np.ndarray] = None,
                         gt_bvp: Optional[np.ndarray] = None,
                         title: str = "rPPG Signals (Pred vs GT)") -> str:
    """
    Save a 3-pane plot: (1) predicted raw, (2) predicted post, (3) overlay Pred vs GT (normalized).
    If GT is provided, it is resampled to the prediction time grid and lag-aligned
    (±5 s) using cross-correlation. The overlay includes Pearson r, lag, and HR stats.
    """
    import numpy as _np
    import matplotlib.pyplot as _plt
    from matplotlib.gridspec import GridSpec as _GridSpec

    def z(x):
        x = _np.asarray(x, dtype=float)
        if x.size == 0:
            return x
        m = float(_np.nanmean(x))
        s = float(_np.nanstd(x)) + 1e-8
        return (x - m) / s

    def _bandpass(x, fs_local, lo=0.7, hi=3.5, order=4):
        try:
            return bandpass_filter(_np.asarray(x, float), fs=fs_local, low=lo, high=hi, order=order)
        except Exception:
            return _np.asarray(x, float)

    def _welch_hr(x, fs_local):
        try:
            return float(hr_from_welch(_np.asarray(x, float), fs=fs_local, lo=0.7, hi=3.5))
        except Exception:
            return 0.0

    def _safe_interp(x_t, y, t_new):
        """Monotonic-safe 1D interpolation with clipping to the valid domain."""
        x_t = _np.asarray(x_t, dtype=float).ravel()
        y = _np.asarray(y, dtype=float).ravel()
        t_new = _np.asarray(t_new, dtype=float).ravel()
        if x_t.size < 2 or y.size != x_t.size:
            # Fallback: length-based resize to t_new length
            if y.size == 0 or t_new.size == 0:
                return _np.zeros_like(t_new)
            idx = _np.linspace(0, y.size - 1, num=t_new.size)
            return _np.interp(_np.arange(t_new.size), idx, y)
        # Enforce strictly increasing time (dedup if needed)
        order = _np.argsort(x_t)
        x_t = x_t[order]
        y = y[order]
        mask = _np.concatenate(([True], _np.diff(x_t) > 0))
        x_t = x_t[mask]
        y = y[mask]
        # Clip t_new to the valid domain to avoid edge extrapolation artifacts
        t_clip = _np.clip(t_new, x_t[0], x_t[-1])
        return _np.interp(t_clip, x_t, y)

    def _best_lag(pred, gt, fs_local, max_lag_s=5.0):
        """Return the lag (s) that best aligns GT to Pred using cross-correlation on z-scored signals."""
        x = z(pred); y = z(gt)
        if x.size < 8 or y.size < 8:
            return 0.0
        n = int(min(len(x), len(y)))
        x = x[:n]; y = y[:n]
        max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
        # valid lags: negative means GT should be shifted left (advanced) relative to Pred
        lags = _np.arange(-max_lag, max_lag + 1)
        # compute the correlation for each lag
        best_corr = -_np.inf
        best_lag = 0
        for L in lags:
            if L < 0:
                xx = x[-L:n]
                yy = y[0:n + L]
            elif L > 0:
                xx = x[0:n - L]
                yy = y[L:n]
            else:
                xx = x
                yy = y
            if xx.size < 8 or yy.size < 8:
                continue
            c = _np.corrcoef(xx, yy)[0, 1]
            if _np.isfinite(c) and c > best_corr:
                best_corr = c
                best_lag = L
        return float(best_lag / float(fs_local))

    def _apply_lag(y, lag_sec, fs_local):
        """Shift y by lag_sec (positive => delay GT) via a sample shift; edges are set to NaN."""
        y = _np.asarray(y, float)
        if y.size == 0 or fs_local <= 0:
            return y
        shift = int(round(lag_sec * fs_local))
        if shift == 0:
            return y
        out = _np.empty_like(y)
        out[:] = _np.nan
        if shift > 0:
            # delay: move content right
            out[shift:] = y[:-shift]
        else:
            # advance: move content left
            out[:shift] = y[-shift:]
        return out

    t = _np.asarray(time_axis, dtype=float)
    raw = _np.asarray(raw_signal, dtype=float)
    post = _np.asarray(post_signal, dtype=float)
    # guard
    if t.size == 0:
        t = _np.arange(post.size, dtype=float) / max(fs, 1)
    have_gt = gt_bvp is not None and _np.asarray(gt_bvp).size > 0
    gt_on_pred = None
    lag_sec = 0.0
    pearson_r = _np.nan
    hr_pred = _welch_hr(_bandpass(post, fs), fs)
    hr_gt = 0.0
    if have_gt:
        gt = _np.asarray(gt_bvp, dtype=float).ravel()
        if gt_time is not None and _np.asarray(gt_time).size == gt.size:
            gt_t = _np.asarray(gt_time, dtype=float).ravel()
            gt_on_pred = _safe_interp(gt_t, gt, t)
        else:
            # No time vector: try a length-based mapping onto the pred time grid
            gt_on_pred = _safe_interp(_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
                                      gt, t)
        # Band-limit both before correlation/HR
        pred_bp = _bandpass(post, fs)
        gt_bp = _bandpass(gt_on_pred, fs)
        # Estimate the best lag (s) of GT relative to Pred
        lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
        # Apply the lag to GT for visualization and correlation
        gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
        # Compute Pearson r on overlapping valid samples
        valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
        if valid.sum() >= 16:
            pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
        else:
            pearson_r = _np.nan
        hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)

    _plt.figure(figsize=(13, 6), dpi=110)
    gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
    # (1) Raw Pred
    ax1 = _plt.subplot(gs[0, 0])
    ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
    ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
    ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
    ax1.grid(True, alpha=0.3)
    # (2) Post Pred
    ax2 = _plt.subplot(gs[0, 1])
    ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
    ax2.set_title("Predicted (Post-processed)")
    ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
    ax2.grid(True, alpha=0.3)
    # (3) Overlay Pred vs GT (z-scored) OR just post
    ax3 = _plt.subplot(gs[1, :])
    ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
    if have_gt and gt_on_pred is not None:
        gt_bp = _bandpass(gt_on_pred, fs)
        gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
        ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
        # metrics box
        txt = [
            f"HR_pred: {hr_pred:.1f} BPM",
            f"HR_gt: {hr_gt:.1f} BPM",
            f"Pearson r: {pearson_r:.3f}" if _np.isfinite(pearson_r) else "Pearson r: --",
            f"Lag: {lag_sec:+.2f} s"
        ]
        ax3.text(0.01, 0.98, "\n".join(txt), transform=ax3.transAxes,
                 va="top", ha="left", fontsize=9,
                 bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="0.8", alpha=0.9))
        ax3.set_title("Pred vs GT (z-score overlay, lag-aligned)")
    else:
        ax3.set_title("Pred vs GT (no GT provided)")
    ax3.set_xlabel("Time (s)"); ax3.set_ylabel("z")
    ax3.grid(True, alpha=0.3)
    ax3.legend(loc="upper right")
    _plt.suptitle(title, fontweight="bold")
    _plt.tight_layout(rect=[0, 0.02, 1, 0.97])
    _plt.savefig(out_path, bbox_inches="tight")
    _plt.close()
    return out_path

def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Robust single-face detector with a few practical guards:
      - converts to gray safely
      - equalizes the histogram (helps underexposure)
      - tries multiple scales / minNeighbors
      - returns the largest face bbox (x, y, w, h) or None
    """
    if frame is None or frame.size == 0:
        return None
    try:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    except Exception:
        # If color conversion fails, assume the frame is already gray
        gray = frame.copy() if frame.ndim == 2 else cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
    # Light preprocessing to improve Haar performance
    gray = cv2.equalizeHist(gray)
    faces_all = []
    # Try a couple of parameter combos to be more forgiving
    params = [
        dict(scaleFactor=1.05, minNeighbors=3),
        dict(scaleFactor=1.10, minNeighbors=4),
        dict(scaleFactor=1.20, minNeighbors=5),
    ]
    for p in params:
        try:
            faces = FACE_CASCADE.detectMultiScale(gray, **p)
            if faces is not None and len(faces) > 0:
                faces_all.extend([tuple(map(int, f)) for f in faces])
        except Exception:
            continue
    if not faces_all:
        return None
    # Return the largest (by area)
    return max(faces_all, key=lambda f: f[2] * f[3])

def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
    """
    Crop an ROI from the frame based on a face bbox and the selected roi_type.
    Returns the cropped BGR ROI or None if invalid.
    """
    if face_bbox is None or frame is None or frame.size == 0:
        return None
    x, y, w, h = map(int, face_bbox)
    H, W = frame.shape[:2]
    if roi_type == "forehead":
        rx = int(x + 0.25 * w); rw = int(0.50 * w)
        ry = int(y + 0.10 * h); rh = int(0.20 * h)
    elif roi_type == "cheeks":
        rx = int(x + 0.15 * w); rw = int(0.70 * w)
        ry = int(y + 0.55 * h); rh = int(0.30 * h)
    else:
        rx, ry, rw, rh = x, y, w, h
    # clamp in-bounds
    rx = max(0, rx); ry = max(0, ry)
    rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
    if rx2 <= rx or ry2 <= ry:
        return None
    roi = frame[ry:ry2, rx:rx2]
    # Avoid empty or 1-pixel slivers
    if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
        return None
    return roi


def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
                       roi_type: str,
                       frame: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[Tuple[int, int, int, int]]]:
    """Like crop_roi, but also returns the (x, y, w, h) rectangle that was cropped."""
    if face_bbox is None or frame is None or frame.size == 0:
        return None, None
    x, y, w, h = map(int, face_bbox)
    H, W = frame.shape[:2]
    if roi_type == "forehead":
        rx = int(x + 0.25 * w); rw = int(0.50 * w)
        ry = int(y + 0.10 * h); rh = int(0.20 * h)
    elif roi_type == "cheeks":
        rx = int(x + 0.15 * w); rw = int(0.70 * w)
        ry = int(y + 0.55 * h); rh = int(0.30 * h)
    else:
        rx, ry, rw, rh = x, y, w, h
    rx = max(0, rx); ry = max(0, ry)
    rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
    if rx2 <= rx or ry2 <= ry:
        return None, None
    roi = frame[ry:ry2, rx:rx2]
    if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
        return None, None
    return roi, (rx, ry, rx2 - rx, ry2 - ry)

def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
    """
    PhysMamba-compatible normalization with DiffNormalized support.
    Returns (3, size, size).
    """
    if face_bgr is None or face_bgr.size == 0:
        return np.zeros((3, size, size), dtype=np.float32)
    try:
        face = cv2.resize(face_bgr, (size, size), interpolation=cv2.INTER_AREA)
    except Exception:
        face = cv2.resize(face_bgr, (size, size))
    face = face.astype(np.float32) / 255.0
    # Per-frame standardization
    mean = face.mean(axis=(0, 1), keepdims=True)
    std = face.std(axis=(0, 1), keepdims=True) + 1e-6
    face = (face - mean) / std
    # BGR -> RGB and HWC -> CHW
    chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
    return chw
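
# Illustrative sketch (not used by the app): normalize_frame returns a
# channel-first (3, size, size) array standardized per frame, so the output
# statistics should be roughly zero mean and unit variance.
def _example_normalize_stats() -> Tuple[float, float]:
    roi = (np.random.rand(60, 40, 3) * 255).astype(np.uint8)
    chw = normalize_frame(roi, 128)
    return float(chw.mean()), float(chw.std())  # expected ≈ (0.0, 1.0)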

def extract_attention_map(model, clip_tensor: torch.Tensor,
                          attention_viz) -> Optional[np.ndarray]:
    """Attention visualization disabled - model architecture incompatible."""
    return None


def create_attention_overlay(frame: np.ndarray, attention: Optional[np.ndarray],
                             attention_viz: Optional[SimpleActivationAttention] = None) -> np.ndarray:
    """Create an attention heatmap overlay."""
    # Attention disabled - return the original frame
    return frame

def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
    H, W = roi_bgr.shape[:2]
    base_bvp = forward_bvp(model, torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
                           .unsqueeze(0).unsqueeze(2).to(DEVICE))  # fake T=1 path if needed
    base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
    heat = np.zeros((H, W), np.float32)
    for y in range(0, H - patch + 1, stride):
        for x in range(0, W - patch + 1, stride):
            tmp = roi_bgr.copy()
            tmp[y:y + patch, x:x + patch] = 127  # occlude
            bvp = forward_bvp(model, torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
                              .unsqueeze(0).unsqueeze(2).to(DEVICE))
            power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
            drop = max(0.0, base_power - power)
            heat[y:y + patch, x:x + patch] += drop
    heat -= heat.min()
    if heat.max() > 1e-8:
        heat /= heat.max()
    return heat

def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
    """
    Try common 5D layouts:
    [B, C, T, H, W] first, then [B, T, C, H, W].
    """
    last_err = None
    try:
        return model(clip_tensor)
    except Exception as e:
        last_err = e
    try:
        return model(clip_tensor.permute(0, 2, 1, 3, 4).contiguous())
    except Exception as e:
        last_err = e
    raise last_err

def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
    """
    Forward and extract a 1D time-like BVP vector with length T_clip.
    Tolerant of dict/list/tuple heads and odd shapes.
    """
    T_clip = int(clip_tensor.shape[2])  # intended time length for [B,C,T,H,W]
    with torch.no_grad():
        out = _call_model_try_orders(model, clip_tensor)
    # unwrap common containers
    if isinstance(out, dict):
        for key in ("bvp", "ppg", "signal", "pred", "y", "out", "logits"):
            if key in out and isinstance(out[key], torch.Tensor):
                out = out[key]
                break
    if isinstance(out, (list, tuple)):
        tensors = [t for t in out if isinstance(t, torch.Tensor)]
        if not tensors:
            return np.zeros(T_clip, dtype=np.float32)

        def score(t: torch.Tensor):
            has_T = 1 if T_clip in t.shape else 0
            return (has_T, t.numel())

        out = max(tensors, key=score)
    if not isinstance(out, torch.Tensor):
        return np.zeros(T_clip, dtype=np.float32)
    out = out.detach().cpu().float()
    # ---- 1D
    if out.ndim == 1:
        v = out
        if v.shape[0] == T_clip:
            return v.numpy()
        if v.numel() == 1:
            return np.full(T_clip, float(v.item()), dtype=np.float32)
        return np.resize(v.numpy(), T_clip).astype(np.float32)
    # ---- 2D
    if out.ndim == 2:
        B, K = out.shape
        if B == 1:
            v = out[0]
            return (v.numpy() if v.shape[0] == T_clip else np.resize(v.numpy(), T_clip).astype(np.float32))
        if B == T_clip:
            return out[:, 0].numpy()
        if K == T_clip:
            return out[0, :].numpy()
        return np.resize(out.flatten().numpy(), T_clip).astype(np.float32)
    # ---- 3D
    if out.ndim == 3:
        B, D1, D2 = out.shape[0], out.shape[1], out.shape[2]
        if D1 == T_clip:  # [B, T, C]
            return out[0, :, 0].numpy()
        if D2 == T_clip:  # [B, C, T]
            return out[0, 0, :].numpy()
        v = out.mean(dim=tuple(range(1, out.ndim))).squeeze(0)
        return np.resize(v.numpy(), T_clip).astype(np.float32)
    # ---- 4D
    if out.ndim == 4:
        B, A, H, W = out.shape
        if A == T_clip:  # [B, T, H, W]
            return out[0].mean(dim=(1, 2)).numpy()
        v = out[0].mean(dim=(1, 2))
        return np.resize(v.numpy(), T_clip).astype(np.float32)
    # ---- 5D+
    if out.ndim >= 5:
        shape = list(out.shape)
        try:
            t_idx = next(i for i, s in enumerate(shape) if (i != 0 and s == T_clip))
        except StopIteration:
            # No axis matches T_clip: pool everything except the leading axis
            v = out[0].mean(dim=tuple(range(1, out[0].ndim))).flatten()
            return np.resize(v.numpy(), T_clip).astype(np.float32)
        axes = list(range(out.ndim))
        perm = [0, t_idx] + [i for i in axes[1:] if i != t_idx]
        o2 = out.permute(*perm)  # [B, T, ...]
        pooled = o2.mean(dim=tuple(range(2, o2.ndim)))  # -> [B, T]
        return pooled[0].numpy()
    # fallback: constant vector with the mean value
    val = float(out.mean().item()) if out.numel() else 0.0
    return np.full(T_clip, val, dtype=np.float32)
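
# Illustrative sketch (not used by the app): whatever the head emits (here a
# toy module returning [B, T]), forward_bvp should resolve it to a length-T
# vector matching the clip's time axis.
def _example_forward_bvp_shape() -> np.ndarray:
    class _TimeHead(nn.Module):
        def forward(self, x):  # x: [B, C, T, H, W]
            return x.mean(dim=(1, 3, 4))  # -> [B, T]

    clip = torch.randn(1, 3, 128, 8, 8)
    return forward_bvp(_TimeHead(), clip)  # shape (128,)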

def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
    """
    Classical rPPG from green-channel means when the model yields nothing.
    Detrend -> bandpass -> z-normalize.
    """
    if means is None:
        return np.array([], dtype=np.float32)
    x = np.asarray(means, dtype=np.float32)
    if x.size == 0:
        return np.array([], dtype=np.float32)
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    try:
        x = signal.detrend(x, type="linear")
    except Exception:
        pass
    y = bandpass_filter(x, fs=fs, low=0.7, high=3.5, order=4)
    std = float(np.std(y)) + 1e-6
    return (y / std).astype(np.float32)
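
# Illustrative sketch (not used by the app): the classical fallback chain turns
# a noisy green-channel mean trace with a 1.0 Hz pulse into a BVP whose Welch
# peak recovers ~60 BPM (fs=32 puts the tone exactly on a PSD bin).
def _example_fallback_chain(fs: int = 32) -> float:
    t = np.arange(0, 12, 1.0 / fs)
    means = 120 + 2 * np.sin(2 * np.pi * 1.0 * t) + np.random.randn(t.size) * 0.2
    bvp = _fallback_bvp_from_means(means, fs=fs)
    return hr_from_welch(bvp, fs=fs)  # expected ≈ 60 BPM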

def _to_floats(s: str) -> List[float]:
    """
    Extract all real numbers from free-form text, including scientific notation.
    Gracefully ignores 'nan', 'inf', units, and comments.
    """
    if not isinstance(s, str) or not s:
        return []
    s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
    s = s.replace(",", " ").replace(";", " ")
    toks = re.findall(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", s)
    out: List[float] = []
    for t in toks:
        try:
            v = float(t)
            if np.isfinite(v):
                out.append(v)
        except Exception:
            continue
    return out
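
# Illustrative sketch (not used by the app): the extractor is permissive about
# delimiters, scientific notation, and trailing comments.
def _example_to_floats() -> List[float]:
    return _to_floats("1.0, 2e-1 3  # trailing comment")  # [1.0, 0.2, 3.0]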

def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
    """
    Parse ground-truth files from:
      • UBFC/UBFC-rPPG style TXT: 3 lines => PPG, HR, timestep(s)
      • Generic TXT: one column (or free-form) numeric sequence
      • JSON: keys like 'ppg' / 'bvp' (optionally nested), 'hr', 'fs'
      • CSV: columns named BVP/PPG/Signal + optional HR + optional Time
      • MAT: array under keys ['ppg','bvp','signal','wave']; optionally 'fs'/'hr'
      • NPY: np.load() 1-D array; optional sidecar .json with fs/hr (same stem)
    Returns:
      bvp     : np.ndarray (float) — may be empty
      hr      : float (mean HR if available, else estimated)
      fs_hint : float — sampling rate if derivable (0.0 if unknown)
    """
    if not gt_path or not os.path.exists(gt_path):
        return np.array([]), 0.0, 0.0
    p = Path(gt_path)
    ext = p.suffix.lower()

    def _fs_from_time_vector(tv: np.ndarray) -> float:
        tv = np.asarray(tv, dtype=float)
        if tv.ndim != 1 or tv.size < 2:
            return 0.0
        diffs = np.diff(tv)
        diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
        return (1.0 / float(np.median(diffs))) if diffs.size else 0.0

    def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
        if bvp is None or bvp.size == 0:
            return 0.0
        fs_use = fs_hint if (fs_hint and fs_hint > 0) else 30.0
        bp = bandpass_filter(bvp.astype(float), fs=fs_use)
        return hr_from_welch(bp, fs=fs_use)

    # UBFC-style TXT (three lines: PPG, HR, timestamps)
    if p.name.lower() == "ground_truth.txt" or (ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2):
        try:
            lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines() if ln.strip()]
            ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
            hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
            t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
            bvp = np.asarray(ppg_vals, dtype=float) if ppg_vals else np.array([], dtype=float)
            hr = float(np.nanmean(hr_vals)) if hr_vals else 0.0
            fs_hint = 0.0
            if t_vals:
                if len(t_vals) == 1:
                    dt = float(t_vals[0])
                    if dt > 0:
                        fs_hint = 1.0 / dt
                else:
                    fs_hint = _fs_from_time_vector(np.asarray(t_vals, dtype=float))
            if hr == 0.0 and bvp.size:
                hr = _hr_from_bvp(bvp, fs_hint)
            return bvp, hr, fs_hint
        except Exception:
            # Fall through to the generic handlers
            pass

    if ext == ".txt":
        try:
            nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
            bvp = np.asarray(nums, dtype=float) if nums else np.array([], dtype=float)
            hr = _hr_from_bvp(bvp, fs_hint=0.0) if bvp.size else 0.0
            return bvp, hr, 0.0
        except Exception:
            return np.array([]), 0.0, 0.0

    if ext == ".json":
        try:
            data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
            # Try several paths for the BVP array
            bvp = None

            def _seek(obj, keys):
                for k in keys:
                    if isinstance(obj, dict) and k in obj:
                        return obj[k]
                return None

            # Direct top-level
            bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
            # Common nested containers
            if bvp is None:
                for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
                    if container_key in data:
                        cand = _seek(data[container_key], ("ppg", "bvp", "signal", "wave"))
                        if cand is not None:
                            bvp = cand
                            break
            if bvp is not None:
                bvp = np.asarray(bvp, dtype=float).ravel()
            else:
                bvp = np.array([], dtype=float)
            # fs / hr (accept scalar or array)
            fs_hint = 0.0
            if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
                fs_hint = float(data["fs"])
            hr = 0.0
            if "hr" in data:
                v = data["hr"]
                hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
            if hr == 0.0 and bvp.size:
                hr = _hr_from_bvp(bvp, fs_hint)
            return bvp, hr, fs_hint
        except Exception:
            return np.array([]), 0.0, 0.0

    if ext == ".csv":
        try:
            df = pd.read_csv(p)
            # Normalize column names
            cols = {str(c).strip().lower(): c for c in df.columns}

            def _first_match(names):
                for nm in names:
                    if nm in cols:
                        return cols[nm]
                return None

            bvp_col = _first_match(["bvp", "ppg", "wave", "signal", "bvp_signal", "ppg_signal"])
            hr_col = _first_match(["hr", "heart_rate", "hr_bpm", "bpm"])
            t_col = _first_match(["time", "t", "timestamp", "sec", "seconds", "time_s"])
            bvp = np.asarray(df[bvp_col].values, dtype=float) if bvp_col else np.array([], dtype=float)
            fs_hint = 0.0
            if t_col is not None and len(df[t_col].values) >= 2:
                fs_hint = _fs_from_time_vector(np.asarray(df[t_col].values, dtype=float))
            hr = float(np.nanmean(df[hr_col].values)) if (hr_col and df[hr_col].notna().any()) else 0.0
            if hr == 0.0 and bvp.size:
                hr = _hr_from_bvp(bvp, fs_hint)
            return bvp, hr, fs_hint
        except Exception:
            return np.array([]), 0.0, 0.0

    if ext == ".mat":
        try:
            md = loadmat(str(p))
            # look for the most likely array
            arr = None
            for key in ("ppg", "bvp", "signal", "wave"):
                if key in md and isinstance(md[key], np.ndarray):
                    arr = md[key]
                    break
            if arr is None:
                # fallback: first 1-D array
                for v in md.values():
                    if isinstance(v, np.ndarray) and v.ndim == 1:
                        arr = v
                        break
            bvp = np.asarray(arr, dtype=float).ravel() if arr is not None else np.array([], dtype=float)
            fs_hint = 0.0
            for k in ("fs", "Fs", "sampling_rate", "sr"):
                if k in md:
                    try:
                        fs_hint = float(np.ravel(md[k])[0])
                        break
                    except Exception:
                        pass
            hr = 0.0
            if "hr" in md:
                try:
                    hr = float(np.nanmean(np.ravel(md["hr"])))
                except Exception:
                    hr = 0.0
            if hr == 0.0 and bvp.size:
                hr = _hr_from_bvp(bvp, fs_hint)
            return bvp, hr, fs_hint
        except Exception:
            return np.array([]), 0.0, 0.0

    if ext == ".npy":
        try:
            bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
            fs_hint, hr = 0.0, 0.0
            # optional sidecar JSON (same stem) with fs/hr
            sidecar = p.with_suffix(".json")
            if sidecar.exists():
                try:
                    meta = json.loads(sidecar.read_text(encoding="utf-8", errors="ignore"))
                    if isinstance(meta.get("fs", None), (int, float)) and meta["fs"] > 0:
                        fs_hint = float(meta["fs"])
                    if "hr" in meta:
                        v = meta["hr"]
                        hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
                except Exception:
                    pass
            if hr == 0.0 and bvp.size:
                hr = _hr_from_bvp(bvp, fs_hint)
            return bvp, hr, fs_hint
        except Exception:
            return np.array([]), 0.0, 0.0

    # Fallback (unsupported extension)
    return np.array([]), 0.0, 0.0
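
# Illustrative sketch (not used by the app): a minimal UBFC-style
# ground_truth.txt is three whitespace-separated lines (PPG samples, a
# per-sample HR trace, timestamps), which parse_ground_truth_file maps to
# (bvp, mean HR, fs hint).
def _example_write_ubfc_gt(path: str) -> None:
    lines = [
        "0.1 0.5 0.9 0.5 0.1",          # PPG wave
        "72 72 72 72 72",               # HR trace (BPM)
        "0.0 0.033 0.067 0.100 0.133",  # timestamps (s) -> fs ≈ 30
    ]
    Path(path).write_text("\n".join(lines), encoding="utf-8")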

def scan_models() -> List[str]:
    if not MODEL_DIR.exists():
        return []
    models = []
    for f in sorted(MODEL_DIR.iterdir()):
        if f.suffix.lower() == '.pth':
            models.append(f.name)
    return models


_GLOBAL_CONTROLS: Dict[str, Dict] = {}


def ensure_controls(control_id: str) -> Tuple[str, Dict]:
    # Use a stable default so Pause/Resume/Stop work for the current run
    if not control_id:
        control_id = "default-session"
    if control_id not in _GLOBAL_CONTROLS:
        _GLOBAL_CONTROLS[control_id] = {
            'pause': threading.Event(),
            'stop': threading.Event()
        }
    return control_id, _GLOBAL_CONTROLS[control_id]
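
# Illustrative sketch (not used by the app): a UI "Stop" button callback only
# needs the control id; the processing loop polls the shared events.
def _example_request_stop(control_id: str) -> None:
    _cid, controls = ensure_controls(control_id)
    controls['stop'].set()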
| def process_video_file( | |
| video_path: str, | |
| gt_file: Optional[str], | |
| model_name: str, | |
| fps_input: int, | |
| max_seconds: int, | |
| roi_type: str, | |
| control_id: str | |
| ): | |
| """ | |
| Enhanced video processing with Grad-CAM attention visualization. | |
| """ | |
| global _HR_SMOOTH | |
| _HR_SMOOTH = None | |
| def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]: | |
| if gt_len <= 1: | |
| return None | |
| if gt_fs and gt_fs > 0: | |
| return np.arange(gt_len, dtype=float) / float(gt_fs) | |
| return None | |
| control_id, controls = ensure_controls(control_id) | |
| controls['stop'].clear() | |
| if not model_name: | |
| yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None) | |
| return | |
| if isinstance(model_name, int): | |
| model_name = str(model_name) | |
| model_path = MODEL_DIR / model_name | |
| if not model_path.exists(): | |
| yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None) | |
| return | |
| try: | |
| model, attention_viz = load_physmamba_model(model_path, DEVICE) | |
| except Exception as e: | |
| yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None) | |
| return | |
| gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0) | |
| if not video_path or not os.path.exists(video_path): | |
| yield ("ERROR: Video not found", None, None, None, None, None, None, None, None, None) | |
| return | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| yield ("ERROR: Cannot open video", None, None, None, None, None, None, None, None, None) | |
| return | |
| fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30) | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) | |
| max_frames = int(max_seconds * fps) if max_seconds and max_seconds > 0 else total_frames | |
| MAX_SIGNAL_LENGTH = max_frames if max_frames > 0 else fps * 600 | |
| frames_chw = deque(maxlen=DEFAULT_T) | |
| raw_g_means = deque(maxlen=DEFAULT_T) | |
| bvp_stream: List[float] = [] | |
| frame_idx = 0 | |
| last_infer = -1 | |
| last_bpm = 0.0 | |
| last_rmssd = 0.0 | |
| last_attention = None | |
| start_time = time.time() | |
| next_display = time.time() | |
| tmpdir = Path(tempfile.gettempdir()) | |
| frame_path = tmpdir / "frame.jpg" | |
| attention_path = tmpdir / "attention.jpg" | |
| signal_path = tmpdir / "signal.png" | |
| raw_path = tmpdir / "raw_signal.png" | |
| post_path = tmpdir / "post_signal.png" | |
| yield ("Starting… reading video frames", None, f"{gt_hr:.1f}" if gt_hr > 0 else "--", | |
| None, None, None, None, None, None, None) | |
| while True: | |
| if controls['stop'].is_set(): | |
| break | |
| while controls['pause'].is_set(): | |
| time.sleep(0.2) | |
| if controls['stop'].is_set(): | |
| break | |
| ret, frame = cap.read() | |
| if not ret or (max_frames > 0 and frame_idx >= max_frames): | |
| break | |
| frame_idx += 1 | |
| face = detect_face(frame) | |
| vis_frame = frame.copy() | |
| if face is not None: | |
| x, y, w, h = face | |
| cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 2) | |
| if roi_type == "auto": | |
| roi, _ = pick_auto_roi(face, frame, attn=last_attention) | |
| else: | |
| roi = crop_roi(face, roi_type, frame) | |
| if roi is not None and roi.size > 0: | |
| try: | |
| g_mean = float(roi[..., 1].astype(np.float32).mean()) | |
| raw_g_means.append(g_mean) | |
| except Exception: | |
| pass | |
| face_norm = normalize_frame(roi, DEFAULT_SIZE) | |
| frames_chw.append(face_norm) | |
| if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE: | |
| try: | |
| clip = np.stack(list(frames_chw), axis=1).astype(np.float32) | |
| except Exception as e: | |
| print(f"[infer] clip stack failed: {e}") | |
| clip = None | |
| bvp_out = None | |
| if clip is not None: | |
| clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE) | |
| try: | |
| raw = forward_bvp(model, clip_t) | |
| if isinstance(raw, np.ndarray): | |
| raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False) | |
| bvp_out = raw if raw.size > 0 else None | |
| else: | |
| bvp_out = None | |
| except Exception as e: | |
| print(f"[infer] forward_bvp error: {e}") | |
| bvp_out = None | |
| # Generate the attention map with Grad-CAM | |
| try: | |
| last_attention = extract_attention_map(model, clip_t, attention_viz) | |
| except Exception as e: | |
| print(f"⚠ Attention generation error: {e}") | |
| last_attention = None | |
| if bvp_out is None or bvp_out.size == 0: | |
| gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0) | |
| fb = _fallback_bvp_from_means(gbuf, fs=fps) | |
| if isinstance(fb, np.ndarray) and fb.size > 0: | |
| bvp_out = fb | |
| else: | |
| print("[infer] fallback produced empty output") | |
| if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0: | |
| tail = min(DEFAULT_STRIDE, bvp_out.size) | |
| bvp_stream.extend(bvp_out[-tail:].tolist()) | |
| if len(bvp_stream) > MAX_SIGNAL_LENGTH: | |
| bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:] | |
| if len(bvp_stream) >= int(5 * fps): | |
| seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32) | |
| _, last_bpm = postprocess_bvp(seg, fs=fps) | |
| last_rmssd = compute_rmssd(seg, fs=fps) | |
| if frame_idx % (DEFAULT_STRIDE * 2) == 0: | |
| print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}") | |
| else: | |
| print("[infer] no usable bvp_out after fallback") | |
| last_infer = frame_idx | |
| else: | |
| cv2.putText(vis_frame, "No face detected", (20, 40), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2) | |
| if last_bpm > 0: | |
| color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255) | |
| cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1) | |
| cv2.putText(vis_frame, f"HR: {last_bpm:.1f} BPM", (20, 48), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) | |
| vis_attention = create_attention_overlay(frame, last_attention, attention_viz) | |
| cv2.imwrite(str(frame_path), vis_frame) | |
| cv2.imwrite(str(attention_path), vis_attention) | |
| now = time.time() | |
| if now >= next_display: | |
| if len(bvp_stream) >= max(10, int(1 * fps)): | |
| try: | |
| signal_array = np.array(bvp_stream, dtype=float) | |
| time_axis = np.arange(len(signal_array)) / fps | |
| raw_signal = bandpass_filter(signal_array, fs=fps) | |
| post_signal, _ = postprocess_bvp(signal_array, fs=fps) | |
| raw_vis = raw_signal - np.mean(raw_signal) | |
| post_vis = post_signal - np.mean(post_signal) | |
| plt.figure(figsize=(10, 4), dpi=100) | |
| plt.plot(time_axis, raw_vis, linewidth=1.6) | |
| plt.xlabel('Time (s)'); plt.ylabel('Amplitude') | |
| plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM') | |
| plt.grid(True, alpha=0.3) | |
| plt.tight_layout(); plt.savefig(str(raw_path), dpi=100, bbox_inches='tight'); plt.close() | |
| plt.figure(figsize=(10, 4), dpi=100) | |
| plt.plot(time_axis, post_vis, linewidth=1.6) | |
| plt.xlabel('Time (s)'); plt.ylabel('Amplitude') | |
| plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM') | |
| plt.grid(True, alpha=0.3) | |
| plt.tight_layout(); plt.savefig(str(post_path), dpi=100, bbox_inches='tight'); plt.close() | |
| plotted = False | |
| if gt_bvp is not None and gt_bvp.size > 0: | |
| try: | |
| gt_time = _gt_time_axis(len(gt_bvp), gt_fs) | |
| plot_signals_with_gt( | |
| time_axis=time_axis, | |
| raw_signal=raw_signal, | |
| post_signal=post_signal, | |
| fs=fps, | |
| out_path=str(signal_path), | |
| gt_time=gt_time, | |
| gt_bvp=gt_bvp, | |
| title=f"Pred vs GT — HR: {last_bpm:.1f} BPM" | |
| ) | |
| plotted = True | |
| except Exception: | |
| plotted = False | |
| if not plotted: | |
| # Fallback: side-by-side raw/post panels when no GT overlay is available | |
| fig = plt.figure(figsize=(12, 5), dpi=100) | |
| gs = GridSpec(1, 2, figure=fig, wspace=0.3) | |
| ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6) | |
| ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3) | |
| ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6) | |
| ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3) | |
| plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold') | |
| plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all') | |
| except Exception: | |
| pass | |
| elapsed = now - start_time | |
| status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM" | |
| yield ( | |
| status, | |
| f"{last_bpm:.1f}" if last_bpm > 0 else None, | |
| f"{gt_hr:.1f}" if gt_hr > 0 else "--", | |
| f"{last_rmssd:.1f}" if last_rmssd > 0 else None, | |
| str(frame_path), | |
| str(attention_path), | |
| str(signal_path) if signal_path.exists() else None, | |
| str(raw_path) if raw_path.exists() else None, | |
| str(post_path) if post_path.exists() else None, | |
| None | |
| ) | |
| next_display = now + (1.0 / DISPLAY_FPS) | |
| cap.release() | |
| # Cleanup Grad-CAM | |
| if attention_viz: | |
| attention_viz.cleanup() | |
| csv_path = None | |
| if bvp_stream: | |
| csv_path = Path(tempfile.gettempdir()) / "bvp_output.csv" | |
| time_array = np.arange(len(bvp_stream)) / fps | |
| signal_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps) | |
| try: | |
| pd.DataFrame({'time_s': time_array, 'bvp': signal_final}).to_csv(csv_path, index=False) | |
| except Exception: | |
| csv_path = None | |
| try: | |
| final_overlay = Path(tempfile.gettempdir()) / "signal_final_overlay.png" | |
| if gt_bvp is not None and gt_bvp.size > 0: | |
| gt_time = _gt_time_axis(len(gt_bvp), gt_fs) | |
| plot_signals_with_gt( | |
| time_axis=time_array, | |
| raw_signal=bandpass_filter(np.array(bvp_stream, dtype=float), fs=fps), | |
| post_signal=signal_final, | |
| fs=fps, | |
| out_path=str(final_overlay), | |
| gt_time=gt_time, | |
| gt_bvp=gt_bvp, | |
| title=f"Final Pred vs GT — HR: {last_bpm:.1f} BPM" | |
| ) | |
| if final_overlay.exists(): | |
| signal_path = final_overlay | |
| except Exception: | |
| pass | |
| elapsed = time.time() - start_time | |
| final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM" | |
| yield ( | |
| final_status, | |
| f"{last_bpm:.1f}" if last_bpm > 0 else None, | |
| f"{gt_hr:.1f}" if gt_hr > 0 else "--", | |
| f"{last_rmssd:.1f}" if last_rmssd > 0 else None, | |
| str(frame_path), | |
| str(attention_path), | |
| str(signal_path) if signal_path.exists() else None, | |
| str(raw_path) if raw_path.exists() else None, | |
| str(post_path) if post_path.exists() else None, | |
| str(csv_path) if csv_path else None | |
| ) | |
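| # Illustrative sketch (not called above): the stitching rule the inference loop | |
| # applies — run the model on a sliding window of DEFAULT_T frames every | |
| # DEFAULT_STRIDE frames and keep only the trailing `stride` samples of each | |
| # prediction, so the streamed BVP grows in step with the video. The function | |
| # name is hypothetical. | |
| def _example_stitch_windows(windows: List[np.ndarray], stride: int) -> np.ndarray: | |
| out: List[float] = [] | |
| for w in windows: | |
| tail = min(stride, int(w.size)) | |
| out.extend(w[-tail:].tolist())  # mirrors bvp_stream.extend(bvp_out[-tail:]) | |
| return np.asarray(out, dtype=np.float32) | |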
| def process_live_webcam( | |
| model_name: str, | |
| fps_input: int, | |
| roi_type: str, | |
| control_id: str | |
| ): | |
| """Stream live webcam with Grad-CAM attention visualization.""" | |
| global _HR_SMOOTH | |
| _HR_SMOOTH = None | |
| def _perf_heartbeat(frame_idx, t0, bvp_len, frames_chw_len, fps): | |
| if frame_idx == 1: | |
| print(f"[run] device={DEVICE} target_fps={fps}") | |
| if frame_idx % 60 == 0: | |
| elapsed = time.time() - t0 | |
| cur_fps = frame_idx / max(elapsed, 1e-6) | |
| print(f"[perf] frames={frame_idx} ({cur_fps:.1f} FPS) " | |
| f"clip={frames_chw_len}/{DEFAULT_T} bvp_len={bvp_len}") | |
| control_id, controls = ensure_controls(control_id) | |
| controls['stop'].clear() | |
| if not model_name: | |
| yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None) | |
| return | |
| model_path = MODEL_DIR / model_name | |
| if not model_path.exists(): | |
| yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None) | |
| return | |
| try: | |
| model, attention_viz = load_physmamba_model(model_path, DEVICE) | |
| except Exception as e: | |
| yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None) | |
| return | |
| cap = None | |
| for camera_id in [0, 1]: | |
| for backend in [cv2.CAP_AVFOUNDATION, cv2.CAP_ANY]: | |
| try: | |
| test_cap = cv2.VideoCapture(camera_id, backend) | |
| if test_cap.isOpened(): | |
| test_cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) | |
| test_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) | |
| ret, frame = test_cap.read() | |
| if ret and frame is not None: | |
| cap = test_cap | |
| break | |
| test_cap.release() | |
| except Exception: | |
| pass | |
| if cap is not None: | |
| break | |
| if cap is None: | |
| yield ("ERROR: Cannot access webcam", None, None, None, None, None, None, None, None, None) | |
| return | |
| fps = int(fps_input) if fps_input else 30 | |
| frames_chw = deque(maxlen=DEFAULT_T) | |
| raw_g_means = deque(maxlen=DEFAULT_T) | |
| bvp_stream: List[float] = [] | |
| frame_idx = 0 | |
| last_infer = -1 | |
| last_bpm = 0.0 | |
| last_rmssd = 0.0 | |
| last_attention = None | |
| t0 = time.time() | |
| next_display = time.time() | |
| DISPLAY_INTERVAL = 0.5  # seconds between streamed UI updates | |
| frame_path = Path(tempfile.gettempdir()) / "live_frame.jpg" | |
| attention_path = Path(tempfile.gettempdir()) / "live_attention.jpg" | |
| signal_path = Path(tempfile.gettempdir()) / "live_signal.png" | |
| raw_path = Path(tempfile.gettempdir()) / "live_raw.png" | |
| post_path = Path(tempfile.gettempdir()) / "live_post.png" | |
| MAX_SIGNAL_LENGTH = fps * 60 | |
| yield ("Starting… waiting for frames", None, "--", None, | |
| None, None, None, None, None, None) | |
| while True: | |
| if controls['stop'].is_set(): | |
| break | |
| while controls['pause'].is_set(): | |
| time.sleep(0.2) | |
| if controls['stop'].is_set(): | |
| break | |
| ret, frame = cap.read() | |
| if not ret: | |
| time.sleep(0.05) | |
| continue | |
| frame_idx += 1 | |
| _perf_heartbeat(frame_idx, t0, len(bvp_stream), len(frames_chw), fps) | |
| face = detect_face(frame) | |
| vis_frame = frame.copy() | |
| if face is not None: | |
| x, y, w, h = face | |
| cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 3) | |
| cv2.putText(vis_frame, "FACE", (x, max(20, y - 10)), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) | |
| if roi_type == "auto": | |
| roi, _ = pick_auto_roi(face, frame, attn=last_attention) | |
| else: | |
| roi = crop_roi(face, roi_type, frame) | |
| if roi is not None and roi.size > 0: | |
| try: | |
| g_mean = float(roi[..., 1].astype(np.float32).mean()) | |
| raw_g_means.append(g_mean) | |
| except Exception: | |
| pass | |
| chw = normalize_frame(roi, DEFAULT_SIZE) | |
| frames_chw.append(chw) | |
| if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE: | |
| try: | |
| clip = np.stack(list(frames_chw), axis=1).astype(np.float32) | |
| except Exception as e: | |
| print(f"[infer] clip stack failed: {e}") | |
| clip = None | |
| bvp_out = None | |
| if clip is not None: | |
| clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE) | |
| try: | |
| raw = forward_bvp(model, clip_t) | |
| if isinstance(raw, np.ndarray): | |
| raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False) | |
| bvp_out = raw if raw.size > 0 else None | |
| else: | |
| bvp_out = None | |
| except Exception as e: | |
| print(f"[infer] forward_bvp error: {e}") | |
| bvp_out = None | |
| try: | |
| last_attention = extract_attention_map(model, clip_t, attention_viz) | |
| except Exception as e: | |
| print(f" Attention generation error: {e}") | |
| last_attention = None | |
| if bvp_out is None or bvp_out.size == 0: | |
| gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0) | |
| fb = _fallback_bvp_from_means(gbuf, fs=fps) | |
| if isinstance(fb, np.ndarray) and fb.size > 0: | |
| bvp_out = fb | |
| else: | |
| print("[infer] fallback produced empty output") | |
| if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0: | |
| tail = min(DEFAULT_STRIDE, bvp_out.size) | |
| bvp_stream.extend(bvp_out[-tail:].tolist()) | |
| if len(bvp_stream) > MAX_SIGNAL_LENGTH: | |
| bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:] | |
| if len(bvp_stream) >= int(5 * fps): | |
| seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32) | |
| _, last_bpm = postprocess_bvp(seg, fs=fps) | |
| last_rmssd = compute_rmssd(seg, fs=fps) | |
| if frame_idx % (DEFAULT_STRIDE * 2) == 0: | |
| print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}") | |
| else: | |
| print("[infer] no usable bvp_out after fallback") | |
| last_infer = frame_idx | |
| else: | |
| cv2.putText(vis_frame, "No face detected", (20, 40), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2) | |
| cv2.putText(vis_frame, f"Fill: {len(frames_chw)}/{DEFAULT_T}", (20, 25), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) | |
| cv2.putText(vis_frame, f"BVP: {len(bvp_stream)}", (20, 45), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) | |
| if last_bpm > 0: | |
| color = (50, 255, 50) if 55 <= last_bpm <= 100 else (50, 50, 255) if last_bpm > 100 else (255, 200, 50) | |
| overlay = vis_frame.copy() | |
| cv2.rectangle(overlay, (10, 10), (450, 100), (0, 0, 0), -1) | |
| vis_frame = cv2.addWeighted(vis_frame, 0.6, overlay, 0.4, 0) | |
| cv2.putText(vis_frame, f"HR: {last_bpm:.0f} BPM", (20, 80), | |
| cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 3) | |
| cv2.circle(vis_frame, (30, 30), 10, (0, 255, 0), -1) | |
| cv2.putText(vis_frame, "LIVE", (50, 38), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) | |
| else: | |
| cv2.putText(vis_frame, "Collecting…", (20, 80), | |
| cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2) | |
| vis_attention = create_attention_overlay(frame, last_attention, attention_viz) | |
| cv2.imwrite(str(frame_path), vis_frame) | |
| cv2.imwrite(str(attention_path), vis_attention) | |
| now = time.time() | |
| if now >= next_display: | |
| if len(bvp_stream) >= max(10, int(1 * fps)): | |
| try: | |
| sig = np.array(bvp_stream, dtype=float) | |
| t_axis = np.arange(len(sig)) / fps | |
| raw_sig = bandpass_filter(sig, fs=fps) | |
| post_sig, _ = postprocess_bvp(sig, fs=fps) | |
| rv = raw_sig - np.mean(raw_sig) | |
| pv = post_sig - np.mean(post_sig) | |
| plt.figure(figsize=(10, 4), dpi=100) | |
| plt.plot(t_axis, rv, linewidth=2) | |
| plt.xlabel('Time (s)'); plt.ylabel('Amplitude') | |
| plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM') | |
| plt.grid(True, alpha=0.3) | |
| plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) | |
| plt.tight_layout(); plt.savefig(str(raw_path)); plt.close() | |
| plt.figure(figsize=(10, 4), dpi=100) | |
| plt.plot(t_axis, pv, linewidth=2) | |
| plt.xlabel('Time (s)'); plt.ylabel('Amplitude') | |
| plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM') | |
| plt.grid(True, alpha=0.3) | |
| plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) | |
| plt.tight_layout(); plt.savefig(str(post_path)); plt.close() | |
| fig = plt.figure(figsize=(12, 5), dpi=100) | |
| gs = GridSpec(1, 2, figure=fig, wspace=0.3) | |
| ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(t_axis, rv, linewidth=2) | |
| ax1.set_title('Raw Signal'); ax1.grid(True, alpha=0.3) | |
| ax1.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) | |
| ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(t_axis, pv, linewidth=2) | |
| ax2.set_title('Post-processed'); ax2.grid(True, alpha=0.3) | |
| ax2.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) | |
| plt.suptitle(f'LIVE rPPG - HR: {last_bpm:.1f} BPM') | |
| plt.savefig(str(signal_path)); plt.close('all') | |
| except Exception: | |
| pass | |
| elapsed = now - t0 | |
| status = (f"Frame {frame_idx} | Fill {len(frames_chw)}/{DEFAULT_T} | " | |
| f"BVP {len(bvp_stream)} | HR {last_bpm:.1f} BPM | " | |
| f"Time {int(elapsed)}s") | |
| yield ( | |
| status, | |
| f"{last_bpm:.1f}" if last_bpm > 0 else None, | |
| "--", | |
| f"{last_rmssd:.1f}" if last_rmssd > 0 else None, | |
| str(frame_path), | |
| str(attention_path), | |
| str(signal_path) if signal_path.exists() else None, | |
| str(raw_path) if raw_path.exists() else None, | |
| str(post_path) if post_path.exists() else None, | |
| None | |
| ) | |
| next_display = now + DISPLAY_INTERVAL | |
| time.sleep(0.01) | |
| cap.release() | |
| # Cleanup Grad-CAM | |
| if attention_viz: | |
| attention_viz.cleanup() | |
| csv_path = None | |
| if bvp_stream: | |
| csv_path = Path(tempfile.gettempdir()) / "live_bvp.csv" | |
| t = np.arange(len(bvp_stream)) / fps | |
| sig_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps) | |
| pd.DataFrame({"time_s": t, "bvp": sig_final}).to_csv(csv_path, index=False) | |
| elapsed = time.time() - t0 | |
| final_status = f"Session ended | Frames {frame_idx} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM" | |
| yield ( | |
| final_status, | |
| f"{last_bpm:.1f}" if last_bpm > 0 else None, | |
| "--", | |
| f"{last_rmssd:.1f}" if last_rmssd > 0 else None, | |
| str(frame_path), | |
| str(attention_path), | |
| str(signal_path) if signal_path.exists() else None, | |
| str(raw_path) if raw_path.exists() else None, | |
| str(post_path) if post_path.exists() else None, | |
| str(csv_path) if csv_path else None | |
| ) | |
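| # Illustrative sketch (not used above): one common way a helper like | |
| # compute_rmssd derives HRV from a BVP segment — detect systolic peaks, | |
| # convert inter-beat intervals to milliseconds, then take the root mean | |
| # square of successive differences. The fs/3 minimum peak distance | |
| # (≈180 BPM cap) is an assumption of this sketch. | |
| def _example_rmssd(bvp: np.ndarray, fs: float = 30.0) -> float: | |
| peaks, _ = find_peaks(bvp, distance=max(1, int(fs / 3.0))) | |
| if peaks.size < 3: | |
| return 0.0  # need at least two intervals for a successive difference | |
| ibi_ms = np.diff(peaks) / float(fs) * 1000.0 | |
| return float(np.sqrt(np.mean(np.diff(ibi_ms) ** 2))) | |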
| def process_stream( | |
| input_source: str, | |
| video_path: Optional[str], | |
| gt_file: Optional[str], | |
| model_name: str, | |
| fps_input: int, | |
| max_seconds: int, | |
| roi_type: str, | |
| control_id: str | |
| ): | |
| if input_source == "Live Webcam": | |
| yield from process_live_webcam(model_name, fps_input, roi_type, control_id) | |
| else: | |
| yield from process_video_file(video_path, gt_file, model_name, fps_input, | |
| max_seconds, roi_type, control_id) | |
| def pause_processing(control_id: str) -> str: | |
| _, controls = ensure_controls(control_id) | |
| controls['pause'].set() | |
| return "Paused" | |
| def resume_processing(control_id: str) -> str: | |
| _, controls = ensure_controls(control_id) | |
| controls['pause'].clear() | |
| return "Resumed" | |
| def stop_processing(control_id: str) -> str: | |
| _, controls = ensure_controls(control_id) | |
| controls['stop'].set() | |
| controls['pause'].clear() | |
| return "Stopped" | |
| def reset_ui(): | |
| return ("Ready", None, None, None, None, None, None, None, None, None) | |
| def handle_folder_upload(files): | |
| if not files: | |
| return None, None, "No files uploaded" | |
| if not isinstance(files, list): | |
| files = [files] | |
| video_path = None | |
| for file_obj in files: | |
| file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name) | |
| if file_path.suffix.lower() in VIDEO_EXTENSIONS: | |
| video_path = str(file_path) | |
| break | |
| gt_path = None | |
| gt_patterns = ['gtdump.txt', 'ground_truth.txt', 'gt.txt'] | |
| for file_obj in files: | |
| file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name) | |
| if file_path.name.lower() in [p.lower() for p in gt_patterns]: | |
| gt_path = str(file_path) | |
| break | |
| if not gt_path: | |
| for file_obj in files: | |
| file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name) | |
| if file_path.suffix.lower() in ['.txt', '.json', '.csv']: | |
| if 'readme' not in file_path.name.lower(): | |
| gt_path = str(file_path) | |
| break | |
| status = [] | |
| if video_path: | |
| status.append(f"Video: {Path(video_path).name}") | |
| else: | |
| status.append("No video found") | |
| if gt_path: | |
| status.append(f"GT: {Path(gt_path).name}") | |
| else: | |
| status.append("No ground truth") | |
| return video_path, gt_path, " | ".join(status) | |
| with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# rPPG Analysis Tool with Attention Visualization") | |
| with gr.Row(): | |
| input_source = gr.Radio( | |
| choices=["Video File", "Subject Folder", "Live Webcam"], | |
| value="Live Webcam", | |
| label="Input Source" | |
| ) | |
| with gr.Row(): | |
| target_layer_input = gr.Textbox( | |
| label="Grad-CAM Target Layer (optional - leave empty for auto)", | |
| placeholder="e.g., backbone.conv3, encoder.layer2", | |
| value="" | |
| ) | |
| with gr.Row(visible=False) as video_inputs: | |
| with gr.Column(): | |
| video_upload = gr.Video(label="Upload Video", sources=["upload"]) | |
| with gr.Column(): | |
| gt_upload = gr.File(label="Ground Truth (optional)", | |
| file_types=[".txt", ".csv", ".json"]) | |
| with gr.Row(visible=False) as folder_inputs: | |
| with gr.Column(): | |
| folder_upload = gr.File( | |
| label="Upload Subject Folder", | |
| file_count="directory", | |
| file_types=None, | |
| type="filepath" | |
| ) | |
| folder_status = gr.Textbox(label="Folder Status", interactive=False) | |
| with gr.Row(visible=True) as webcam_inputs: | |
| gr.Markdown("### Live Webcam - Click Run to start") | |
| def toggle_input_source(source): | |
| return [ | |
| gr.update(visible=(source == "Video File")), | |
| gr.update(visible=(source == "Subject Folder")), | |
| gr.update(visible=(source == "Live Webcam")) | |
| ] | |
| input_source.change( | |
| toggle_input_source, | |
| inputs=[input_source], | |
| outputs=[video_inputs, folder_inputs, webcam_inputs] | |
| ) | |
| folder_video = gr.State(None) | |
| folder_gt = gr.State(None) | |
| folder_upload.upload( | |
| handle_folder_upload, | |
| inputs=[folder_upload], | |
| outputs=[folder_video, folder_gt, folder_status] | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| _model_choices = scan_models() | |
| model_dropdown = gr.Dropdown( | |
| choices=_model_choices, | |
| value=_model_choices[0] if _model_choices else None, | |
| label="PhysMamba Model", | |
| interactive=True | |
| ) | |
| with gr.Column(scale=1): | |
| refresh_models_btn = gr.Button("Refresh", variant="secondary") | |
| refresh_models_btn.click( | |
| lambda: gr.update(choices=scan_models()), | |
| inputs=None, | |
| outputs=[model_dropdown] | |
| ) | |
| with gr.Row(): | |
| fps_slider = gr.Slider( | |
| minimum=10, maximum=120, value=30, step=5, | |
| label="FPS" | |
| ) | |
| max_seconds_slider = gr.Slider( | |
| minimum=10, maximum=600, value=180, step=10, | |
| label="Max Duration (s)" | |
| ) | |
| with gr.Row(): | |
| roi_dropdown = gr.Dropdown(choices=["auto","forehead","cheeks","face"], value="auto", label="ROI") | |
| control_state = gr.State(value="") | |
| with gr.Row(): | |
| run_btn = gr.Button("Run", variant="primary") | |
| pause_btn = gr.Button("Pause", variant="secondary") | |
| resume_btn = gr.Button("Resume", variant="secondary") | |
| stop_btn = gr.Button("Stop", variant="stop") | |
| status_text = gr.Textbox(label="Status", lines=2, value="Ready") | |
| with gr.Row(): | |
| hr_output = gr.Textbox(label="HR (BPM)", interactive=False) | |
| gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False) | |
| rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False) | |
| with gr.Row(): | |
| with gr.Column(): | |
| frame_output = gr.Image(label="Video Feed", type="filepath") | |
| with gr.Column(): | |
| attention_output = gr.Image(label="Attention Map", type="filepath") | |
| with gr.Row(): | |
| signal_output = gr.Image(label="Signal Comparison", type="filepath") | |
| with gr.Row(): | |
| with gr.Column(): | |
| raw_signal_output = gr.Image(label="Raw Signal", type="filepath") | |
| with gr.Column(): | |
| post_signal_output = gr.Image(label="Post-processed Signal", type="filepath") | |
| with gr.Row(): | |
| csv_output = gr.File(label="Download CSV") | |
| pause_btn.click( | |
| pause_processing, | |
| inputs=[control_state], | |
| outputs=[status_text] | |
| ) | |
| resume_btn.click( | |
| resume_processing, | |
| inputs=[control_state], | |
| outputs=[status_text] | |
| ) | |
| stop_btn.click( | |
| stop_processing, | |
| inputs=[control_state], | |
| outputs=[status_text] | |
| ).then( | |
| reset_ui, | |
| inputs=None, | |
| outputs=[status_text, hr_output, gt_hr_output, rmssd_output, | |
| frame_output, attention_output, signal_output, | |
| raw_signal_output, post_signal_output, csv_output] | |
| ) | |
| def run_processing(input_source, video_upload, gt_upload, folder_video, folder_gt, | |
| model_name, fps, max_sec, roi, ctrl_id): | |
| """Fixed version that handles model_name type conversion.""" | |
| if isinstance(model_name, int): | |
| model_name = str(model_name) | |
| if not model_name: | |
| yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None) | |
| return | |
| if input_source == "Video File": | |
| video_path = _as_path(video_upload) | |
| gt_file = _as_path(gt_upload) | |
| elif input_source == "Subject Folder": | |
| video_path = _as_path(folder_video) | |
| gt_file = _as_path(folder_gt) | |
| else: # Live Webcam | |
| video_path, gt_file = None, None | |
| yield from process_stream( | |
| input_source, video_path, gt_file, | |
| model_name, fps, max_sec, roi, ctrl_id | |
| ) | |
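| # Note: run_processing is a generator, so Gradio streams each yielded tuple | |
| # into the ten outputs below (generator streaming relies on the queue enabled | |
| # at launch via demo.queue()). | |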
| run_btn.click( | |
| fn=run_processing, | |
| inputs=[ | |
| input_source, | |
| video_upload, | |
| gt_upload, | |
| folder_video, | |
| folder_gt, | |
| model_dropdown, | |
| fps_slider, | |
| max_seconds_slider, | |
| roi_dropdown, | |
| control_state | |
| ], | |
| outputs=[ | |
| status_text, | |
| hr_output, | |
| gt_hr_output, | |
| rmssd_output, | |
| frame_output, | |
| attention_output, | |
| signal_output, | |
| raw_signal_output, | |
| post_signal_output, | |
| csv_output | |
| ] | |
| ) | |
| if __name__ == "__main__": | |
| import os | |
| demo.queue(max_size=10).launch( | |
| server_name="0.0.0.0", | |
| server_port=int(os.environ.get("PORT", 7860)), | |
| share=False, | |
| ) | |