# demo-space2 / app.py
import os
os.environ.setdefault("OPENCV_AVFOUNDATION_SKIP_AUTH", "1")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import warnings
warnings.filterwarnings("ignore")
import re
import io
import json
import tempfile
import time
import threading
import uuid
import importlib.util
from pathlib import Path
from collections import deque
from typing import Optional, Tuple, Dict, List, Any
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import signal
from scipy.signal import find_peaks, welch, get_window
# .mat support (SCAMPS)
try:
from scipy.io import loadmat
MAT_SUPPORT = True
except Exception:
MAT_SUPPORT = False
# Matplotlib headless backend
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import gradio as gr
# Global buffer for per-frame illumination/motion metrics
FRAME_METRICS: List[Dict] = []
class PhysMambaattention_viz:
"""Simplified Grad-CAM for PhysMamba."""
def __init__(self, model: nn.Module, device: torch.device):
self.model = model
self.device = device
self.activations = None
self.gradients = None
self.hook_handles = []
def _get_target_layer(self):
"""Auto-detect the best layer for visualization, excluding Mamba/SSM layers."""
# Strategy 1: Look for temporal difference convolution
for name, module in reversed(list(self.model.named_modules())):
if ('tdc' in name.lower() or 'temporal_diff' in name.lower()) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
print(f"[attention_viz] Selected TDC layer: {name} ({type(module).__name__})")
return name, module
# Strategy 2: Look for 3D convolutions (temporal-spatial)
for name, module in reversed(list(self.model.named_modules())):
if isinstance(module, nn.Conv3d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
print(f"[attention_viz] Selected Conv3d layer: {name}")
return name, module
# Strategy 3: Look for 2D convolutions (spatial)
for name, module in reversed(list(self.model.named_modules())):
if isinstance(module, nn.Conv2d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
print(f"[attention_viz] Selected Conv2d layer: {name}")
return name, module
# Strategy 4: Look for any convolution (last resort)
for name, module in reversed(list(self.model.named_modules())):
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
print(f"[attention_viz] Selected Conv layer (fallback): {name} ({type(module).__name__})")
return name, module
# Strategy 5: Print all available layers and pick the last non-Mamba one
print("\n[attention_viz] Available layers:")
suitable_layers = []
for name, module in self.model.named_modules():
layer_type = type(module).__name__
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
is_mamba = 'mamba' in name.lower() or 'ssm' in name.lower()
print(f" - {name}: {layer_type} {'(MAMBA - SKIP)' if is_mamba else ''}")
if not is_mamba:
suitable_layers.append((name, module))
if suitable_layers:
name, module = suitable_layers[-1]
print(f"\n[attention_viz] Selected last suitable layer: {name} ({type(module).__name__})")
return name, module
raise ValueError("No suitable layer found for Grad-CAM (all layers are Mamba/SSM)")
def _forward_hook(self, module, input, output):
self.activations = output.detach()
def _backward_hook(self, module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
def register_hooks(self, target_layer_name: Optional[str] = None):
"""Register hooks on target layer"""
self._remove_hooks()
if target_layer_name and target_layer_name.strip():
# Manual selection
try:
target_module = dict(self.model.named_modules())[target_layer_name.strip()]
print(f"Using manually specified layer: {target_layer_name}")
except KeyError:
print(f"⚠ Layer '{target_layer_name}' not found, falling back to auto-detection")
target_layer_name, target_module = self._get_target_layer()
else:
# Auto-detection
target_layer_name, target_module = self._get_target_layer()
fwd_handle = target_module.register_forward_hook(self._forward_hook)
bwd_handle = target_module.register_full_backward_hook(self._backward_hook)
self.hook_handles = [fwd_handle, bwd_handle]
def _remove_hooks(self):
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
self.activations = None
self.gradients = None
def generate(self, input_tensor: torch.Tensor) -> np.ndarray:
"""Generate Grad-CAM heatmap with improved gradient handling."""
# Set model to eval but keep gradient computation
self.model.eval()
# Ensure tensor requires grad
input_tensor = input_tensor.requires_grad_(True)
# Forward pass
output = self.model(input_tensor)
# Handle different output types
if isinstance(output, dict):
# Try common keys
for key in ('pred', 'output', 'bvp', 'logits', 'out'):
if key in output and isinstance(output[key], torch.Tensor):
output = output[key]
break
# Get scalar for backward
if output.numel() == 1:
target = output
else:
# Use mean of output
target = output.mean()
print(f"[attention_viz] Output shape: {output.shape}, Target: {target.item():.4f}")
# Backward pass
self.model.zero_grad()
target.backward(retain_graph=False)
# Check if we got gradients
if self.activations is None:
print("⚠ No activations captured!")
return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
if self.gradients is None:
print("⚠ No gradients captured!")
return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
print(f"[attention_viz] Activations shape: {self.activations.shape}")
print(f"[attention_viz] Gradients shape: {self.gradients.shape}")
activations = self.activations
gradients = self.gradients
# Compute weights (global average pooling of gradients)
if gradients.dim() == 5: # [B, C, T, H, W]
weights = gradients.mean(dim=[2, 3, 4], keepdim=True)
elif gradients.dim() == 4: # [B, C, H, W]
weights = gradients.mean(dim=[2, 3], keepdim=True)
elif gradients.dim() == 3: # [B, C, T]
weights = gradients.mean(dim=2, keepdim=True).unsqueeze(-1).unsqueeze(-1)
else:
print(f"⚠ Unexpected gradient dimensions: {gradients.dim()}")
return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
# Weighted combination
cam = (weights * activations).sum(dim=1, keepdim=True)
# If 5D, average over time
if cam.dim() == 5:
cam = cam.mean(dim=2)
# ReLU
cam = F.relu(cam)
# Convert to numpy
cam = cam.squeeze().cpu().detach().numpy()
# Normalize
if cam.max() > cam.min():
cam = (cam - cam.min()) / (cam.max() - cam.min())
print(f"[attention_viz] Final heatmap shape: {cam.shape}, range: [{cam.min():.3f}, {cam.max():.3f}]")
return cam
def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
"""Overlay heatmap on frame."""
h, w = frame.shape[:2]
heatmap_resized = cv2.resize(heatmap, (w, h))
heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
return overlay
def cleanup(self):
self._remove_hooks()
def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
"""
Apply DiffNormalized preprocessing from the PhysMamba paper:
diff_t = (I_t - I_{t-1}) / (I_t + I_{t-1} + eps)
then global std-normalize.
frames: list of HxWx3 uint8 or float32 arrays (RGB or BGR, consistent).
Returns: (T, H, W, C) float32.
"""
if not frames:
return np.zeros((0,), dtype=np.float32)
if len(frames) < 2:
f0 = frames[0].astype(np.float32)
return np.stack([np.zeros_like(f0, dtype=np.float32)], axis=0)
diff_frames = []
for i in range(len(frames)):
if i == 0:
diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
else:
curr = frames[i].astype(np.float32)
prev = frames[i - 1].astype(np.float32)
denom = curr + prev + 1e-8
diff = (curr - prev) / denom
diff_frames.append(diff)
diff_array = np.stack(diff_frames).astype(np.float32)
std = float(diff_array.std()) + 1e-8
diff_array /= std
return diff_array
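# Illustrative sketch (not used by the app): apply_diff_normalized on a tiny random
# clip, just to show the (T, H, W, C) output layout and the global std-normalization.
def _example_diff_normalized() -> None:
    rng = np.random.default_rng(0)
    frames = [rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8) for _ in range(4)]
    diff = apply_diff_normalized(frames)
    assert diff.shape == (4, 8, 8, 3)        # frame 0 is all zeros by construction
    print("diff std ~", float(diff.std()))   # ~1.0 after the global std-normalization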
def preprocess_for_physmamba(
frames: List[np.ndarray],
target_frames: int = 128,
target_size: int = 128
) -> torch.Tensor:
"""
Complete DiffNormalized preprocessing pipeline to produce
a PhysMamba-compatible clip tensor of shape [1, 3, T, H, W].
NOTE: This path is *not* used in the current live demo, which instead
uses normalize_frame() + forward_bvp(). Keep for future experiments.
"""
if not frames:
# Dummy tensor; caller should guard length > 0
return torch.zeros(1, 3, target_frames, target_size, target_size, dtype=torch.float32)
# Temporal sampling / padding to target_frames
if len(frames) < target_frames:
frames = frames + [frames[-1]] * (target_frames - len(frames))
elif len(frames) > target_frames:
idx = np.linspace(0, len(frames) - 1, target_frames).astype(int)
frames = [frames[i] for i in idx]
# Convert to RGB and resize
frames_rgb = [f[..., ::-1].copy() for f in frames] # BGR->RGB
frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
# DiffNormalized
diff_array = apply_diff_normalized(frames_resized) # (T, H, W, C)
# To [B, C, T, H, W]
diff_array = np.transpose(diff_array, (3, 0, 1, 2)) # (C, T, H, W)
diff_array = np.expand_dims(diff_array, axis=0) # (1, C, T, H, W)
return torch.from_numpy(diff_array.astype(np.float32))
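# Minimal sketch of the (currently unused) clip path above, with deliberately tiny
# frame count / size so it is cheap to run; the live demo uses 128-frame, 128x128 clips.
def _example_preprocess_clip() -> None:
    rng = np.random.default_rng(1)
    frames = [rng.integers(0, 256, size=(16, 16, 3), dtype=np.uint8) for _ in range(10)]
    clip = preprocess_for_physmamba(frames, target_frames=8, target_size=16)
    assert clip.shape == (1, 3, 8, 16, 16)  # [B, C, T, H, W]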
# ---------------------------------------------------------------------------
# Paths, device, constants
# ---------------------------------------------------------------------------
HERE = Path(__file__).resolve().parent
MODEL_DIR = HERE / "final_model_release"
LOG_DIR = HERE / "logs"
ANALYSIS_DIR = HERE / "analysis"
for d in (MODEL_DIR, LOG_DIR, ANALYSIS_DIR):
d.mkdir(exist_ok=True, parents=True)
DEVICE = (
torch.device("cuda") if torch.cuda.is_available()
else torch.device("mps") if torch.backends.mps.is_available()
else torch.device("cpu")
)
FACE_CASCADE = cv2.CascadeClassifier(
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
DEFAULT_SIZE = 128 # input H=W to model
DEFAULT_T = 128 # clip length
DEFAULT_STRIDE = 8 # inference hop
DISPLAY_FPS = 10
VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']
_HR_SMOOTH = None
REST_HR_TARGET = 72.0
REST_HR_RANGE = (55.0, 95.0)
MAX_JUMP_BPM = 8.0
# Recognized GT files for subject folders
GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
GT_EXTS = {".txt", ".csv", ".json"}
def _as_path(maybe) -> Optional[str]:
"""
Return a filesystem path from Gradio values (str, dict, Path, tempfile objects, lists).
Handles:
- plain strings
- pathlib.Path
- Gradio dicts with keys: 'name', 'path', 'file'
- file-like objects with .name
- lists (takes first element)
"""
if maybe is None:
return None
# Gradio can pass a list (e.g., multiple files / directory upload)
if isinstance(maybe, list):
if not maybe:
return None
return _as_path(maybe[0])
if isinstance(maybe, str):
return maybe
if isinstance(maybe, Path):
return str(maybe)
# Gradio v4 File/Video components often pass a dict
if isinstance(maybe, dict):
for key in ("name", "path", "file"):
v = maybe.get(key)
if isinstance(v, str) and v:
return v
return None
# tempfile-like / UploadedFile objects
name = getattr(maybe, "name", None)
if isinstance(name, str) and name:
return name
try:
return str(maybe)
except Exception:
return None
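# Quick sanity sketch of the value shapes _as_path accepts (the literals below are
# illustrative stand-ins for real Gradio payloads, not actual uploads).
def _example_as_path() -> None:
    assert _as_path("/tmp/a.mp4") == "/tmp/a.mp4"
    assert _as_path(Path("/tmp/a.mp4")) == "/tmp/a.mp4"
    assert _as_path({"path": "/tmp/a.mp4"}) == "/tmp/a.mp4"
    assert _as_path(["/tmp/a.mp4", "/tmp/b.mp4"]) == "/tmp/a.mp4"
    assert _as_path(None) is None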
def _import_from_file(py_path: Path):
spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
if not spec or not spec.loader:
raise ImportError(f"Cannot import module from {py_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _looks_like_video(p: Path) -> bool:
"""
Heuristic for 'video-like' files used in subject-folder discovery.
Treat .mat as video, plus common video extensions.
"""
if p.suffix.lower() == ".mat":
return True
return p.suffix.lower() in VIDEO_EXTENSIONS
class SimpleActivationAttention:
"""Lightweight attention visualization using forward activations (no gradients)."""
def __init__(self, model: nn.Module, device: torch.device):
self.model = model
self.device = device
self.activations: Optional[torch.Tensor] = None
self.hook_handle: Optional[Any] = None
def _activation_hook(self, module, input, output):
"""Capture activations during forward pass."""
try:
self.activations = output.detach()
except Exception:
self.activations = None
def register_hook(self):
"""Register hook on a suitable conv layer (last conv before Mamba if possible)."""
target = None
target_name = None
for name, module in self.model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
target = module
target_name = name
if target is None:
print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
return
self.hook_handle = target.register_forward_hook(self._activation_hook)
print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")
def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
"""
Generate attention map from stored activations (call AFTER the forward pass).
Returns a 2D numpy array in [0,1] or None if unavailable.
"""
try:
if self.activations is None:
return None
act = self.activations
# Handle different tensor shapes
if act.dim() == 5: # [B, C, T, H, W]
# Average over channels and time: -> [B, H, W]
attention = act.mean(dim=[1, 2])
elif act.dim() == 4: # [B, C, H, W]
attention = act.mean(dim=1) # -> [B, H, W]
else:
print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
return None
# Take first batch
attention = attention[0].detach().cpu().numpy()
# Normalize to [0, 1]
a_min, a_max = attention.min(), attention.max()
if a_max > a_min:
attention = (attention - a_min) / (a_max - a_min)
else:
attention = np.zeros_like(attention, dtype=np.float32)
return attention
except Exception as e:
print(f"⚠ [attention_viz] Generation failed: {e}")
return None
def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
"""Overlay heatmap on frame."""
if heatmap is None or frame is None or frame.size == 0:
return frame
h, w = frame.shape[:2]
heatmap_resized = cv2.resize(heatmap, (w, h))
heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
return overlay
def cleanup(self):
if self.hook_handle is not None:
self.hook_handle.remove()
self.hook_handle = None
self.activations = None
class VideoReader:
"""
Unified frame reader:
• Regular videos via cv2.VideoCapture
• .mat 'videos' (e.g., SCAMPS): expects array like (T,H,W,3) or (H,W,T,3) or (T,H,W)
Returns frames as BGR uint8.
"""
def __init__(self, path: str):
self.path = str(path)
self._cap = None
self._mat = None
self._idx = 0
self._len = 0
self._shape = None
self._fps = 0
if self.path.lower().endswith(".mat") and MAT_SUPPORT:
self._open_mat(self.path)
else:
self._open_cv(self.path)
def _open_cv(self, path: str):
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise RuntimeError("Cannot open video")
self._cap = cap
self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
self._fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
def _open_mat(self, path: str):
try:
md = loadmat(path)
# Common keys in SCAMPS-like dumps
for key in ("video", "frames", "vid", "data"):
if key in md and isinstance(md[key], np.ndarray):
arr = md[key]
break
else:
arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
if arr is None:
raise RuntimeError("No ndarray found in .mat")
a = np.asarray(arr)
# Normalize to (T,H,W,3)
if a.ndim == 4:
if a.shape[-1] == 3:
# Heuristic: if first dim is much smaller than spatial dims -> assume T
if a.shape[0] <= a.shape[1] and a.shape[0] <= a.shape[2]: # (T,H,W,3)
v = a
else: # (H,W,T,3) -> (T,H,W,3)
v = np.transpose(a, (2, 0, 1, 3))
else:
v = a[..., :1] # take first channel
elif a.ndim == 3:
# (T,H,W) or (H,W,T)
if a.shape[0] <= a.shape[1] and a.shape[0] <= a.shape[2]: # (T,H,W)
v = a
else: # (H,W,T) -> (T,H,W)
v = np.transpose(a, (2, 0, 1))
v = v[..., None]
else:
raise RuntimeError(f"Unsupported .mat video shape: {a.shape}")
v = np.ascontiguousarray(v)
if v.shape[-1] == 1:
v = np.repeat(v, 3, axis=-1)
v = v.astype(np.uint8)
self._mat = v
self._len = v.shape[0]
self._shape = v.shape[1:3]
self._fps = 0.0 # unknown; caller can override
except Exception as e:
raise RuntimeError(f"Failed to open .mat video: {e}")
def read(self):
"""Return (ret, frame[BGR]) like cv2.VideoCapture.read()."""
if self._mat is not None:
if self._idx >= self._len:
return False, None
frame = self._mat[self._idx]
self._idx += 1
return True, frame
else:
return self._cap.read()
def fps(self, fallback: int = 30) -> int:
if self._mat is not None:
return int(fallback)
if self._fps and self._fps > 0:
return int(self._fps)
f = self._cap.get(cv2.CAP_PROP_FPS)
return int(f) if f and f > 0 else int(fallback)
def length(self) -> int:
return self._len
def release(self):
if self._cap is not None:
self._cap.release()
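# Usage sketch for VideoReader (the path below is a placeholder, not a bundled file):
# read a handful of frames and release the capture; the same loop works for regular
# videos and SCAMPS-style .mat inputs.
def _example_video_reader(path: str = "sample.mp4", max_frames: int = 5) -> int:
    """Return how many frames could be decoded from `path` (illustrative helper)."""
    reader = VideoReader(path)
    n = 0
    try:
        while n < max_frames:
            ok, _frame = reader.read()
            if not ok:
                break
            n += 1
    finally:
        reader.release()
    return n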
def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
x, y, w, h = face
# forehead
fh = frame[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)]
# cheeks
ck = frame[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)]
# full face
ff = frame[y:y + h, x:x + w]
return {"forehead": fh, "cheeks": ck, "face": ff}
def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
if patch is None or patch.size == 0:
return -1e9
g = patch[..., 1].astype(np.float32) / 255.0 # green channel
g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling
g = g - g.mean()
try:
b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
y = signal.filtfilt(b, a, g, method="gust")
except Exception:
y = g
return float((y ** 2).mean())
def pick_auto_roi(face: Tuple[int, int, int, int],
frame: np.ndarray,
attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
"""Simple ROI selection using signal quality + optional attention weighting."""
cands = roi_candidates(face, frame)
scores = {k: roi_quality_score(v) for k, v in cands.items()}
if attn is not None and attn.size:
H, W = frame.shape[:2]
try:
attn_resized = cv2.resize(attn, (W, H))
x, y, w, h = face
fh_attn = attn_resized[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)].mean() if attn_resized.size > 0 else 0.0
ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
scores['forehead'] += fh_attn * 0.2
scores['cheeks'] += ck_attn * 0.2
scores['face'] += ff_attn * 0.2
except Exception:
pass
best = max(scores, key=scores.get)
return cands[best], best
def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
"""
Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
Heuristics:
- Subject folder: any directory containing at least one video-like file (.mat or common video).
- If multiple videos, pick the largest by size.
- GT file: prefer known names in GT_FILENAMES, else any .txt/.csv/.json not named readme.
"""
pairs: List[Tuple[str, Optional[str]]] = []
if not root_dir.exists():
return pairs
def pick_pair(folder: Path) -> Optional[Tuple[str, Optional[str]]]:
vids = [p for p in folder.rglob("*") if p.is_file() and _looks_like_video(p)]
if not vids:
return None
vids.sort(key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
video = vids[0]
gt: Optional[Path] = None
for p in folder.rglob("*"):
if p.is_file() and p.name.lower() in GT_FILENAMES:
gt = p
break
if gt is None:
cands = [
p for p in folder.rglob("*")
if p.is_file() and p.suffix.lower() in GT_EXTS and "readme" not in p.name.lower()
]
if cands:
gt = cands[0]
return str(video), (str(gt) if gt else None)
subs = [d for d in root_dir.iterdir() if d.is_dir()]
if subs:
for sub in subs:
pair = pick_pair(sub)
if pair:
pairs.append(pair)
else:
pair = pick_pair(root_dir) # the root itself might be a single-subject folder
if pair:
pairs.append(pair)
# Deduplicate
seen = set()
uniq: List[Tuple[str, Optional[str]]] = []
for v, g in pairs:
key = (v, g or "")
if key not in seen:
seen.add(key)
uniq.append((v, g))
return uniq
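# Illustrative call (the dataset root is a placeholder): print every discovered
# (video, ground-truth) pair under a UBFC-style directory tree.
def _example_discover_subjects(root: str = "data/UBFC") -> None:
    for video_path, gt_path in discover_subjects(Path(root)):
        print(video_path, "->", gt_path or "no GT file")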
def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
import inspect
if model_file:
model_path = (repo_root / model_file).resolve()
if model_path.exists():
try:
mod = _import_from_file(model_path)
if hasattr(mod, model_class):
return getattr(mod, model_class)
except Exception:
pass
search_dirs = [
repo_root / "neural_methods" / "model",
repo_root / "neural_methods",
repo_root
]
name_pattern = re.compile(r"mamba", re.IGNORECASE)
for base_dir in search_dirs:
if not base_dir.exists():
continue
for py_file in base_dir.rglob("*.py"):
if "__pycache__" in str(py_file) or "mamba_ssm" in str(py_file):
continue
try:
mod = _import_from_file(py_file)
for name, obj in inspect.getmembers(mod):
if callable(obj) and name_pattern.search(name) and "ssm" not in name.lower():
if inspect.isclass(obj) and issubclass(obj, nn.Module):
return obj
except Exception:
continue
raise ImportError("Could not find PhysMamba model class")
def load_physmamba_model(ckpt_path: Path, device: torch.device,
model_file: str = "", model_class: str = "PhysMamba"):
repo_root = Path(".").resolve()
Builder = find_physmamba_builder(repo_root, model_file, model_class)
import inspect
ctor_trials = [
{},
{"d_model": 96},
{"dim": 96},
{"d_model": 96, "frames": 128, "img_size": 128, "in_chans": 3},
{"frames": 128, "img_size": 128, "in_chans": 3},
{"frame_depth": 3},
]
model = None
for kwargs in ctor_trials:
try:
            candidate = Builder(**kwargs)
if isinstance(candidate, nn.Module):
model = candidate
break
except Exception:
continue
if model is None:
raise RuntimeError("Could not construct PhysMamba model")
try:
checkpoint = torch.load(str(ckpt_path), map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)
try:
model.load_state_dict(state_dict, strict=False)
except Exception:
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
except Exception:
# If loading fails, you still get an uninitialized model (for debugging)
pass
model.to(device).eval()
try:
with torch.no_grad():
_ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
except Exception:
# Shape sanity check failed, but we keep the model usable.
pass
# For now: attention visualization disabled (extract_attention_map returns None)
attention_viz = None
return model, attention_viz
def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
"""
Stable band-pass with edge-safety and parameter clipping.
"""
x = np.asarray(x, dtype=float)
n = int(fs * 2)
if x.size < max(n, 8): # need at least ~2s
return x
nyq = 0.5 * fs
lo = max(low / nyq, 1e-6)
hi = min(high / nyq, 0.999_999)
if not (0.0 < lo < hi < 1.0):
return x
try:
b, a = signal.butter(order, [lo, hi], btype="band")
padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
return signal.filtfilt(b, a, x, padlen=padlen)
except Exception:
return x
def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
"""
HR (BPM) via Welch PSD peak in [lo, hi] Hz.
"""
x = np.asarray(x, dtype=float)
if x.size < int(fs * 4.0): # need ~4s for a usable PSD
return 0.0
try:
nper = int(min(max(64, fs * 2), min(512, x.size)))
f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
mask = (f >= lo) & (f <= hi)
if not np.any(mask):
return 0.0
f_band = f[mask]
p_band = pxx[mask]
if p_band.size == 0 or np.all(~np.isfinite(p_band)):
return 0.0
fpk = float(f_band[np.argmax(p_band)])
bpm = fpk * 60.0
return float(np.clip(bpm, 30.0, 220.0))
except Exception:
return 0.0
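# Minimal self-check sketch (not called by the app): run a clean 1.2 Hz tone (72 BPM)
# through the band-pass + Welch pipeline above. The estimate snaps to the Welch bin
# spacing (fs / nperseg, about 0.47 Hz at fs=30 with the 64-sample segments used
# above), so the returned BPM lands on the in-band bin nearest 1.2 Hz rather than
# exactly 72.0.
def _example_hr_from_sine(fs: int = 30, seconds: float = 20.0) -> float:
    t = np.arange(int(fs * seconds)) / float(fs)
    bvp = np.sin(2.0 * np.pi * 1.2 * t)  # 1.2 Hz carrier ~ 72 BPM
    return hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)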
def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
"""
HRV RMSSD from peaks; robust to short/flat segments.
"""
x = np.asarray(x, dtype=float)
if x.size < int(fs * 5.0):
return 0.0
try:
peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
if len(peaks) < 3:
return 0.0
rr = np.diff(peaks) / fs * 1000.0 # ms
if rr.size < 2:
return 0.0
return float(np.sqrt(np.mean(np.diff(rr) ** 2)))
except Exception:
return 0.0
def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
"""
Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
Signature unchanged to avoid breaking callers.
"""
global _HR_SMOOTH
y = np.asarray(pred, dtype=float)
if y.size == 0:
return y, 0.0
# 1) band-limit
y_filt = bandpass_filter(y, fs=fs, low=0.7, high=3.5, order=4)
# 2) HR estimate
hr = hr_from_welch(y_filt, fs=fs, lo=0.7, hi=3.5)
# 3) gentle attraction to resting band (if way off)
if hr > 0:
lo, hi = REST_HR_RANGE
if hr < lo or hr > hi:
dist = abs(hr - REST_HR_TARGET)
alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
# 4) temporal smoothing to limit frame-to-frame jumps
if hr > 0:
if _HR_SMOOTH is None:
_HR_SMOOTH = hr
else:
step = float(np.clip(hr - _HR_SMOOTH, -MAX_JUMP_BPM, MAX_JUMP_BPM))
_HR_SMOOTH = _HR_SMOOTH + 0.6 * step
hr = float(_HR_SMOOTH)
return y_filt, float(hr)
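# Worked numeric sketch of step (3) above: an estimate of 120 BPM is 48 BPM away from
# the 72 BPM resting target, so alpha = clip(0.25 + 0.02*48, 0.25, 0.65) = 0.65 and the
# reported HR becomes 0.65*120 + 0.35*72 = 103.2 BPM, before the step (4) smoothing.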
def draw_face_and_roi(frame_bgr: np.ndarray,
face_bbox: Optional[Tuple[int, int, int, int]],
roi_bbox: Optional[Tuple[int, int, int, int]],
label: str = "ROI") -> np.ndarray:
"""
Draw face (green) and ROI (cyan) rectangles on a copy of the frame.
"""
vis = frame_bgr.copy()
if face_bbox is not None:
x, y, w, h = face_bbox
cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 230, 0), 2)
cv2.putText(vis, "FACE", (x, max(20, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 230, 0), 2)
if roi_bbox is not None:
rx, ry, rw, rh = roi_bbox
cv2.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (255, 220, 0), 2)
cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
return vis
def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
roi_type: str,
frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
"""
Compute the ROI rectangle (x,y,w,h) from a face bbox and your ROI rule.
Matches your crop_roi geometry.
"""
x, y, w, h = face_bbox
H, W = frame_shape[:2]
if roi_type == "forehead":
rx = int(x + 0.25 * w); rw = int(0.5 * w)
ry = int(y + 0.10 * h); rh = int(0.20 * h)
elif roi_type == "cheeks":
rx = int(x + 0.15 * w); rw = int(0.70 * w)
ry = int(y + 0.55 * h); rh = int(0.30 * h)
else:
rx, ry, rw, rh = x, y, w, h
rx2 = min(W, rx + rw)
ry2 = min(H, ry + rh)
rx = max(0, rx); ry = max(0, ry)
if rx2 <= rx or ry2 <= ry:
return (0, 0, 0, 0)
return (rx, ry, rx2 - rx, ry2 - ry)
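# Geometry sketch for the rules above: for a face bbox (x=100, y=100, w=200, h=200),
# the "forehead" ROI spans x in [150, 250) and y in [120, 160), i.e. 25-75% of the
# face width and 10-30% of the face height.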
def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
"""
Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
Assumes chw = (3, H, W) with zero-mean, unit-var normalization per-frame.
"""
if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
return np.zeros((128, 128, 3), dtype=np.uint8)
img = chw.copy()
vmin, vmax = float(img.min()), float(img.max())
if vmax <= vmin + 1e-6:
img = np.zeros_like(img)
else:
img = (img - vmin) / (vmax - vmin)
img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR
return img
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
if gt_len <= 1:
return None
if gt_fs and gt_fs > 0:
return np.arange(gt_len, dtype=float) / float(gt_fs)
return None # will fall back to length-matching overlay
def plot_signals_with_gt(time_axis: np.ndarray,
raw_signal: np.ndarray,
post_signal: np.ndarray,
fs: int,
out_path: str,
gt_time: Optional[np.ndarray] = None,
gt_bvp: Optional[np.ndarray] = None,
title: str = "rPPG Signals (Pred vs GT)") -> str:
"""
Save a 3-pane plot: (1) predicted raw, (2) predicted post, (3) overlay Pred vs GT (normalized).
If GT is provided, it is resampled to the prediction time grid and lag-aligned
(±5 s) using cross-correlation. The overlay includes Pearson r, lag, and HR stats.
"""
import numpy as _np
import matplotlib.pyplot as _plt
from matplotlib.gridspec import GridSpec as _GridSpec
def z(x):
x = _np.asarray(x, dtype=float)
if x.size == 0:
return x
m = float(_np.nanmean(x))
s = float(_np.nanstd(x)) + 1e-8
return (x - m) / s
def _bandpass(x, fs_local, lo=0.7, hi=3.5, order=4):
try:
return bandpass_filter(_np.asarray(x, float), fs=fs_local, low=lo, high=hi, order=order)
except Exception:
return _np.asarray(x, float)
def _welch_hr(x, fs_local):
try:
return float(hr_from_welch(_np.asarray(x, float), fs=fs_local, lo=0.7, hi=3.5))
except Exception:
return 0.0
def _safe_interp(x_t, y, t_new):
"""Monotonic-safe 1D interpolation with clipping to valid domain."""
x_t = _np.asarray(x_t, dtype=float).ravel()
y = _np.asarray(y, dtype=float).ravel()
t_new = _np.asarray(t_new, dtype=float).ravel()
if x_t.size < 2 or y.size != x_t.size:
if y.size == 0 or t_new.size == 0:
return _np.zeros_like(t_new)
idx = _np.linspace(0, y.size - 1, num=t_new.size)
return _np.interp(_np.arange(t_new.size), idx, y)
order = _np.argsort(x_t)
x_t = x_t[order]
y = y[order]
mask = _np.concatenate(([True], _np.diff(x_t) > 0))
x_t = x_t[mask]
y = y[mask]
t_clip = _np.clip(t_new, x_t[0], x_t[-1])
return _np.interp(t_clip, x_t, y)
def _best_lag(pred, gt, fs_local, max_lag_s=5.0):
"""Return lag (sec) that best aligns GT to Pred using cross-correlation on z-scored signals."""
x = z(pred); y = z(gt)
if x.size < 8 or y.size < 8:
return 0.0
n = int(min(len(x), len(y)))
x = x[:n]; y = y[:n]
max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
lags = _np.arange(-max_lag, max_lag + 1)
best_corr = -_np.inf
best_lag = 0
for L in lags:
if L < 0:
xx = x[-L:n]
yy = y[0:n+L]
elif L > 0:
xx = x[0:n-L]
yy = y[L:n]
else:
xx = x
yy = y
if xx.size < 8 or yy.size < 8:
continue
c = _np.corrcoef(xx, yy)[0, 1]
if _np.isfinite(c) and c > best_corr:
best_corr = c
best_lag = L
return float(best_lag / float(fs_local))
def _apply_lag(y, lag_sec, fs_local):
"""Shift y by lag_sec (positive => delay GT) using sample roll; edges set to NaN."""
y = _np.asarray(y, float)
if y.size == 0 or fs_local <= 0:
return y
shift = int(round(lag_sec * fs_local))
if shift == 0:
return y
out = _np.empty_like(y)
out[:] = _np.nan
if shift > 0:
out[shift:] = y[:-shift]
else:
out[:shift] = y[-shift:]
return out
t = _np.asarray(time_axis, dtype=float)
raw = _np.asarray(raw_signal, dtype=float)
post = _np.asarray(post_signal, dtype=float)
if t.size == 0:
t = _np.arange(post.size, dtype=float) / max(fs, 1)
have_gt = gt_bvp is not None and _np.asarray(gt_bvp).size > 0
gt_on_pred = None
lag_sec = 0.0
pearson_r = _np.nan
hr_pred = _welch_hr(_bandpass(post, fs), fs)
hr_gt = 0.0
if have_gt:
gt = _np.asarray(gt_bvp, dtype=float).ravel()
if gt_time is not None and _np.asarray(gt_time).size == gt.size:
gt_t = _np.asarray(gt_time, dtype=float).ravel()
gt_on_pred = _safe_interp(gt_t, gt, t)
else:
gt_on_pred = _safe_interp(
_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
gt, t
)
pred_bp = _bandpass(post, fs)
gt_bp = _bandpass(gt_on_pred, fs)
lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
if valid.sum() >= 16:
pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
else:
pearson_r = _np.nan
hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
_plt.figure(figsize=(13, 6), dpi=110)
gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
ax1 = _plt.subplot(gs[0, 0])
ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
ax1.grid(True, alpha=0.3)
ax2 = _plt.subplot(gs[0, 1])
ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
ax2.set_title("Predicted (Post-processed)")
ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
ax2.grid(True, alpha=0.3)
ax3 = _plt.subplot(gs[1, :])
ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
if have_gt and gt_on_pred is not None:
gt_bp = _bandpass(gt_on_pred, fs)
gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
txt = [
f"HR_pred: {hr_pred:.1f} BPM",
f"HR_gt: {hr_gt:.1f} BPM",
f"Pearson r: {pearson_r:.3f}" if _np.isfinite(pearson_r) else "Pearson r: --",
f"Lag: {lag_sec:+.2f} s"
]
ax3.text(0.01, 0.98, "\n".join(txt), transform=ax3.transAxes,
va="top", ha="left", fontsize=9,
bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="0.8", alpha=0.9))
ax3.set_title("Pred vs GT (z-score overlay, lag-aligned)")
else:
ax3.set_title("Pred vs GT (no GT provided)")
ax3.set_xlabel("Time (s)"); ax3.set_ylabel("z")
ax3.grid(True, alpha=0.3)
ax3.legend(loc="upper right")
_plt.suptitle(title, fontweight="bold")
_plt.tight_layout(rect=[0, 0.02, 1, 0.97])
_plt.savefig(out_path, bbox_inches="tight")
_plt.close()
return out_path
def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
"""
Robust single-face detector with a few practical guards:
- converts to gray safely
- equalizes histogram (helps underexposure)
- tries multiple scales / minNeighbors
- returns the largest face bbox (x,y,w,h) or None
"""
if frame is None or frame.size == 0:
return None
    # If the cascade failed to load, fail fast (an unloaded CascadeClassifier is
    # "empty" rather than None, so check both).
    if FACE_CASCADE is None or FACE_CASCADE.empty():
return None
try:
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
except Exception:
# If color conversion fails, assume already gray
if frame.ndim == 2:
gray = frame.copy()
else:
gray = cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
# Light preproc to improve Haar performance
gray = cv2.equalizeHist(gray)
faces_all: List[Tuple[int, int, int, int]] = []
# Try a couple of parameter combos to be more forgiving
params = [
dict(scaleFactor=1.05, minNeighbors=3),
dict(scaleFactor=1.10, minNeighbors=4),
dict(scaleFactor=1.20, minNeighbors=5),
]
for p in params:
try:
faces = FACE_CASCADE.detectMultiScale(gray, **p)
if faces is not None and len(faces) > 0:
faces_all.extend([tuple(map(int, f)) for f in faces])
except Exception:
continue
if not faces_all:
return None
# Return the largest (by area)
return max(faces_all, key=lambda f: f[2] * f[3])
def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
"""
Crop ROI from the frame based on a face bbox and the selected roi_type.
Returns the cropped BGR ROI or None if invalid.
"""
if face_bbox is None or frame is None or frame.size == 0:
return None
x, y, w, h = map(int, face_bbox)
H, W = frame.shape[:2]
if roi_type == "forehead":
rx = int(x + 0.25 * w); rw = int(0.50 * w)
ry = int(y + 0.10 * h); rh = int(0.20 * h)
elif roi_type == "cheeks":
rx = int(x + 0.15 * w); rw = int(0.70 * w)
ry = int(y + 0.55 * h); rh = int(0.30 * h)
else:
rx, ry, rw, rh = x, y, w, h
# clamp in-bounds
rx = max(0, rx); ry = max(0, ry)
rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
if rx2 <= rx or ry2 <= ry:
return None
roi = frame[ry:ry2, rx:rx2]
# Avoid empty or 1-pixel slivers
if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
return None
return roi
def crop_roi_with_bbox(
face_bbox: Tuple[int, int, int, int],
roi_type: str,
frame: np.ndarray
) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
"""
Same as crop_roi, but also returns the ROI bbox (x, y, w, h) in frame coords.
"""
if face_bbox is None or frame is None or frame.size == 0:
return None, None
x, y, w, h = map(int, face_bbox)
H, W = frame.shape[:2]
if roi_type == "forehead":
rx = int(x + 0.25 * w); rw = int(0.50 * w)
ry = int(y + 0.10 * h); rh = int(0.20 * h)
elif roi_type == "cheeks":
rx = int(x + 0.15 * w); rw = int(0.70 * w)
ry = int(y + 0.55 * h); rh = int(0.30 * h)
else:
rx, ry, rw, rh = x, y, w, h
rx = max(0, rx); ry = max(0, ry)
rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
if rx2 <= rx or ry2 <= ry:
return None, None
roi = frame[ry:ry2, rx:rx2]
if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
return None, None
return roi, (rx, ry, rx2 - rx, ry2 - ry)
def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
"""
PhysMamba-compatible normalization with DiffNormalized support.
Returns (3, size, size).
"""
if face_bgr is None or face_bgr.size == 0:
return np.zeros((3, size, size), dtype=np.float32)
try:
face = cv2.resize(face_bgr, (size, size), interpolation=cv2.INTER_AREA)
except Exception:
face = cv2.resize(face_bgr, (size, size))
face = face.astype(np.float32) / 255.0
# Per-frame standardization
mean = face.mean(axis=(0, 1), keepdims=True)
std = face.std(axis=(0, 1), keepdims=True) + 1e-6
face = (face - mean) / std
# BGR -> RGB and HWC -> CHW
chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
return chw
def extract_attention_map(model, clip_tensor: torch.Tensor, attention_viz) -> Optional[np.ndarray]:
"""Attention visualization disabled - model architecture incompatible."""
return None
def create_attention_overlay(
frame: np.ndarray,
attention: Optional[np.ndarray],
attention_viz: Optional[SimpleActivationAttention] = None
) -> np.ndarray:
"""Create attention heatmap overlay (currently passthrough)."""
return frame
def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
H, W = roi_bgr.shape[:2]
base_bvp = forward_bvp(
model,
torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
.unsqueeze(0).unsqueeze(2).to(DEVICE) # fake T=1 path if needed
)
base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
heat = np.zeros((H, W), np.float32)
for y in range(0, H - patch + 1, stride):
for x in range(0, W - patch + 1, stride):
tmp = roi_bgr.copy()
tmp[y:y+patch, x:x+patch] = 127 # occlude
bvp = forward_bvp(
model,
torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
.unsqueeze(0).unsqueeze(2).to(DEVICE)
)
power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
drop = max(0.0, base_power - power)
heat[y:y+patch, x:x+patch] += drop
heat -= heat.min()
if heat.max() > 1e-8:
heat /= heat.max()
return heat
def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
"""
Try common 5D layouts:
[B, C, T, H, W] then [B, T, C, H, W].
"""
last_err = None
try:
return model(clip_tensor)
except Exception as e:
last_err = e
try:
return model(clip_tensor.permute(0, 2, 1, 3, 4).contiguous())
except Exception as e:
last_err = e
raise last_err
def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
"""
Forward and extract a 1D time-like BVP vector with length T_clip.
Tolerant to dict/list/tuple heads and odd shapes.
"""
T_clip = int(clip_tensor.shape[2]) # intended time length for [B,C,T,H,W]
with torch.no_grad():
out = _call_model_try_orders(model, clip_tensor)
# unwrap common containers
if isinstance(out, dict):
for key in ("bvp", "ppg", "signal", "pred", "y", "out", "logits"):
if key in out and isinstance(out[key], torch.Tensor):
out = out[key]
break
if isinstance(out, (list, tuple)):
tensors = [t for t in out if isinstance(t, torch.Tensor)]
if not tensors:
return np.zeros(T_clip, dtype=np.float32)
def score(t: torch.Tensor):
has_T = 1 if T_clip in t.shape else 0
return (has_T, t.numel())
out = max(tensors, key=score)
if not isinstance(out, torch.Tensor):
return np.zeros(T_clip, dtype=np.float32)
out = out.detach().cpu().float()
# ---- 1D
if out.ndim == 1:
v = out
if v.shape[0] == T_clip:
return v.numpy()
if v.numel() == 1:
return np.full(T_clip, float(v.item()), dtype=np.float32)
return np.resize(v.numpy(), T_clip).astype(np.float32)
# ---- 2D
if out.ndim == 2:
B, K = out.shape
if B == 1:
v = out[0]
return (v.numpy() if v.shape[0] == T_clip
else np.resize(v.numpy(), T_clip).astype(np.float32))
if B == T_clip:
return out[:, 0].numpy()
if K == T_clip:
return out[0, :].numpy()
return np.resize(out.flatten().numpy(), T_clip).astype(np.float32)
# ---- 3D
if out.ndim == 3:
B, D1, D2 = out.shape[0], out.shape[1], out.shape[2]
if D1 == T_clip: # [B, T, C]
return out[0, :, 0].numpy()
if D2 == T_clip: # [B, C, T]
return out[0, 0, :].numpy()
v = out.mean(dim=tuple(range(1, out.ndim))).squeeze(0)
return np.resize(v.numpy(), T_clip).astype(np.float32)
# ---- 4D
if out.ndim == 4:
B, A, H, W = out.shape
if A == T_clip: # [B, T, H, W]
return out[0].mean(dim=(1, 2)).numpy()
v = out[0].mean(dim=(1, 2))
return np.resize(v.numpy(), T_clip).astype(np.float32)
# ---- 5D+
if out.ndim >= 5:
shape = list(out.shape)
try:
t_idx = next(i for i, s in enumerate(shape) if (i != 0 and s == T_clip))
except StopIteration:
            pooled = out[0].mean(dim=tuple(range(1, out[0].ndim)))
            v = pooled.flatten()
            return np.resize(v.numpy(), T_clip).astype(np.float32)
axes = list(range(out.ndim))
perm = [0, t_idx] + [i for i in axes[1:] if i != t_idx]
o2 = out.permute(*perm) # [B, T, ...]
pooled = o2.mean(dim=tuple(range(2, o2.ndim))) # -> [B, T]
return pooled[0].numpy()
# fallback: constant vector with mean value
val = float(out.mean().item()) if out.numel() else 0.0
return np.full(T_clip, val, dtype=np.float32)
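# Shape-handling sketch for forward_bvp: the tiny head below is an assumption (it is
# not PhysMamba); it emits [B, 1, T], which exercises the 3-D branch above and should
# come back as a length-T vector.
class _ToyBVPHead(nn.Module):
    def forward(self, x):  # x: [B, C, T, H, W]
        # Pool channels and space, keep time: [B, T] -> [B, 1, T]
        return x.mean(dim=(1, 3, 4)).unsqueeze(1)

def _example_forward_bvp() -> None:
    clip = torch.zeros(1, 3, 16, 8, 8)
    bvp = forward_bvp(_ToyBVPHead(), clip)
    assert bvp.shape == (16,)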
def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
"""
Classical rPPG from green-channel means when the model yields nothing.
Detrend -> bandpass -> z-normalize.
"""
if means is None:
return np.array([], dtype=np.float32)
x = np.asarray(means, dtype=np.float32)
if x.size == 0:
return np.array([], dtype=np.float32)
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
try:
x = signal.detrend(x, type="linear")
except Exception:
pass
y = bandpass_filter(x, fs=fs, low=0.7, high=3.5, order=4)
std = float(np.std(y)) + 1e-6
return (y / std).astype(np.float32)
def _to_floats(s: str) -> List[float]:
"""
Extract all real numbers from free-form text, including scientific notation.
Gracefully ignores comments, units, and non-numeric junk.
"""
if not isinstance(s, str) or not s:
return []
# Strip comments starting with #, //, or ;
s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
# Normalize common delimiters
s = s.replace(",", " ").replace(";", " ")
toks = re.findall(
r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?",
s
)
out: List[float] = []
for t in toks:
try:
v = float(t)
if np.isfinite(v):
out.append(v)
except Exception:
continue
return out
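# Sketch: _to_floats pulls every numeric token out of free-form GT text, e.g.
# _to_floats("0.51, 0.49  # ppg\n1.2e2 bpm") -> [0.51, 0.49, 120.0]
# (comments after '#', '//' or ';' are stripped, commas become whitespace).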
def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
"""
Parse ground-truth files from:
• UBFC/UBFC-rPPG style TXT: 3 lines => PPG, HR, timestep(s)
• Generic TXT: one column (or free-form) numeric sequence
• JSON: keys like 'ppg' / 'bvp' (optionally nested), 'hr', 'fs'
• CSV: columns named BVP/PPG/Signal + optional HR + optional Time
• MAT: array under keys ['ppg','bvp','signal','wave']; optionally 'fs'/'hr'
• NPY: np.load() 1-D array; optionally sidecar .json with fs/hr (same stem)
Returns:
bvp : np.ndarray (float) — may be empty
hr : float (mean HR if available or estimated)
fs_hint : float — sampling rate if derivable (0.0 if unknown)
"""
if not gt_path or not os.path.exists(gt_path):
return np.array([]), 0.0, 0.0
p = Path(gt_path)
ext = p.suffix.lower()
def _fs_from_time_vector(tv: np.ndarray) -> float:
tv = np.asarray(tv, dtype=float)
if tv.ndim != 1 or tv.size < 2:
return 0.0
diffs = np.diff(tv)
diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
if bvp is None or bvp.size == 0:
return 0.0
fs_use = fs_hint if (fs_hint and fs_hint > 0) else 30.0
bp = bandpass_filter(bvp.astype(float), fs=fs_use)
return hr_from_welch(bp, fs=fs_use)
# ================= UBFC-style TXT (3 lines) =================
if p.name.lower() == "ground_truth.txt" or (
ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2
):
try:
lines = [
ln.strip()
for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines()
if ln.strip()
]
ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
bvp = np.asarray(ppg_vals, dtype=float) if ppg_vals else np.array([], dtype=float)
hr = float(np.nanmean(hr_vals)) if hr_vals else 0.0
fs_hint = 0.0
if t_vals:
if len(t_vals) == 1:
dt = float(t_vals[0])
if dt > 0:
fs_hint = 1.0 / dt
else:
fs_hint = _fs_from_time_vector(np.asarray(t_vals, dtype=float))
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
# Fall through to generic handlers
pass
# ================= Generic TXT =================
if ext == ".txt":
try:
nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
bvp = np.asarray(nums, dtype=float) if nums else np.array([], dtype=float)
hr = _hr_from_bvp(bvp, fs_hint=0.0) if bvp.size else 0.0
return bvp, hr, 0.0
except Exception:
return np.array([]), 0.0, 0.0
# ================= JSON =================
if ext == ".json":
try:
data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
def _seek(obj, keys):
for k in keys:
if isinstance(obj, dict) and k in obj:
return obj[k]
return None
bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
if bvp is None:
for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
if container_key in data:
cand = _seek(data[container_key], ("ppg", "bvp", "signal", "wave"))
if cand is not None:
bvp = cand
break
if bvp is not None:
bvp = np.asarray(bvp, dtype=float).ravel()
else:
bvp = np.array([], dtype=float)
fs_hint = 0.0
if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
fs_hint = float(data["fs"])
hr = 0.0
if "hr" in data:
v = data["hr"]
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
# ================= CSV =================
if ext == ".csv":
try:
df = pd.read_csv(p)
cols = {str(c).strip().lower(): c for c in df.columns}
def _first_match(names):
for nm in names:
if nm in cols:
return cols[nm]
return None
bvp_col = _first_match(["bvp", "ppg", "wave", "signal", "bvp_signal", "ppg_signal"])
hr_col = _first_match(["hr", "heart_rate", "hr_bpm", "bpm"])
t_col = _first_match(["time", "t", "timestamp", "sec", "seconds", "time_s"])
bvp = np.asarray(df[bvp_col].values, dtype=float) if bvp_col else np.array([], dtype=float)
fs_hint = 0.0
if t_col is not None and len(df[t_col].values) >= 2:
fs_hint = _fs_from_time_vector(np.asarray(df[t_col].values, dtype=float))
hr = float(np.nanmean(df[hr_col].values)) if (hr_col and df[hr_col].notna().any()) else 0.0
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
# ================= MAT =================
if ext == ".mat":
try:
md = loadmat(str(p))
arr = None
for key in ("ppg", "bvp", "signal", "wave"):
if key in md and isinstance(md[key], np.ndarray):
arr = md[key]
break
if arr is None:
for v in md.values():
if isinstance(v, np.ndarray) and v.ndim == 1:
arr = v
break
bvp = np.asarray(arr, dtype=float).ravel() if arr is not None else np.array([], dtype=float)
fs_hint = 0.0
for k in ("fs", "Fs", "sampling_rate", "sr"):
if k in md:
try:
fs_hint = float(np.ravel(md[k])[0])
break
except Exception:
pass
hr = 0.0
if "hr" in md:
try:
hr = float(np.nanmean(np.ravel(md["hr"])))
except Exception:
hr = 0.0
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
# ================= NPY =================
if ext == ".npy":
try:
bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
fs_hint, hr = 0.0, 0.0
sidecar = p.with_suffix(".json")
if sidecar.exists():
try:
meta = json.loads(sidecar.read_text(encoding="utf-8", errors="ignore"))
if isinstance(meta.get("fs", None), (int, float)) and meta["fs"] > 0:
fs_hint = float(meta["fs"])
if "hr" in meta:
v = meta["hr"]
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
except Exception:
pass
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
# Fallback (unsupported extension)
return np.array([]), 0.0, 0.0
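# Format sketch for the UBFC-style ground_truth.txt handled above (values are
# illustrative only): line 1 = PPG samples, line 2 = HR trace, line 3 = timestamps.
#
#   0.12 0.15 0.11 ...
#   71.8 72.0 72.4 ...
#   0.000 0.033 0.066 ...
#
# parse_ground_truth_file() would return (the PPG array, mean HR ~72, fs ~30 derived
# from the median timestamp spacing).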
def scan_models() -> List[str]:
if not MODEL_DIR.exists():
return []
models = []
for f in sorted(MODEL_DIR.iterdir()):
if f.suffix.lower() == '.pth':
models.append(f.name)
return models
_GLOBAL_CONTROLS: Dict[str, Dict] = {}
def ensure_controls(control_id: str) -> Tuple[str, Dict]:
# Use a stable default so Pause/Resume/Stop work for the current run
if not control_id:
control_id = "default-session"
if control_id not in _GLOBAL_CONTROLS:
_GLOBAL_CONTROLS[control_id] = {
'pause': threading.Event(),
'stop': threading.Event()
}
return control_id, _GLOBAL_CONTROLS[control_id]
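# Control-flow sketch: the processing loop polls these events, so a UI "Pause"
# callback only needs ensure_controls(cid)[1]['pause'].set() (and .clear() to resume);
# setting ['stop'] ends the run at the next loop iteration.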
def process_video_file(
video_path: str,
gt_file: Optional[str],
model_name: str,
fps_input: int,
max_seconds: int,
roi_type: str,
control_id: str
):
"""
Enhanced video processing with Grad-CAM attention visualization,
plus per-frame illumination and motion logging.
Returns 14 outputs per yield, matching Gradio UI:
status, pred_hr, gt_hr, rmssd,
frame_path, attention_path, signal_path, raw_path, post_path, csv_path,
global_brightness, global_motion, roi_brightness, roi_motion
"""
global _HR_SMOOTH
global FRAME_METRICS # global frame metrics buffer
_HR_SMOOTH = None
FRAME_METRICS = [] # reset per run
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
if gt_len <= 1:
return None
if gt_fs and gt_fs > 0:
return np.arange(gt_len, dtype=float) / float(gt_fs)
return None
control_id, controls = ensure_controls(control_id)
controls['stop'].clear()
# Helper for consistent error yields (14 outputs)
def _error_status(msg: str):
return (
msg,
None, None, None, # HR, GT HR, RMSSD
None, None, None, # frame, attention, signal
None, None, # raw, post
None, # csv
None, None, None, None # brightness & motion
)
if not model_name:
yield _error_status("ERROR: No model selected")
return
if isinstance(model_name, int):
model_name = str(model_name)
model_path = MODEL_DIR / model_name
if not model_path.exists():
yield _error_status("ERROR: Model not found")
return
try:
model, attention_viz = load_physmamba_model(model_path, DEVICE)
except Exception as e:
yield _error_status(f"ERROR loading model: {str(e)}")
return
gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
if not video_path or not os.path.exists(video_path):
yield _error_status("ERROR: Video not found")
return
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
yield _error_status("ERROR: Cannot open video")
return
fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
max_frames = int(max_seconds * fps) if max_seconds and max_seconds > 0 else total_frames
MAX_SIGNAL_LENGTH = max_frames if max_frames > 0 else fps * 600
frames_chw = deque(maxlen=DEFAULT_T)
raw_g_means = deque(maxlen=DEFAULT_T)
bvp_stream: List[float] = []
frame_idx = 0
last_infer = -1
last_bpm = 0.0
last_rmssd = 0.0
last_attention = None
# previous grayscale frames for motion
prev_gray = None
prev_roi_gray = None
start_time = time.time()
next_display = time.time()
tmpdir = Path(tempfile.gettempdir())
frame_path = tmpdir / "frame.jpg"
attention_path = tmpdir / "attention.jpg"
signal_path = tmpdir / "signal.png"
raw_path = tmpdir / "raw_signal.png"
post_path = tmpdir / "post_signal.png"
# Initial status yield (14 outputs)
yield (
"Starting… reading video frames",
None,
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
None,
None, None, None, None, None, None,
None, None, None, None
)
while True:
if controls['stop'].is_set():
break
while controls['pause'].is_set():
time.sleep(0.2)
if controls['stop'].is_set():
break
ret, frame = cap.read()
if not ret or (max_frames > 0 and frame_idx >= max_frames):
break
frame_idx += 1
# default per-frame metrics for this iteration
global_brightness = None
global_motion = None
roi_brightness = None
roi_motion = None
# global illumination & motion (full frame)
try:
frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
except Exception:
frame_gray = None
if frame_gray is not None:
global_brightness = float(frame_gray.mean())
if prev_gray is not None and prev_gray.shape == frame_gray.shape:
diff = cv2.absdiff(frame_gray, prev_gray)
global_motion = float(diff.mean())
prev_gray = frame_gray
face = detect_face(frame)
vis_frame = frame.copy()
if face is not None:
x, y, w, h = face
cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
if roi_type == "auto":
roi, _ = pick_auto_roi(face, frame, attn=last_attention)
else:
roi = crop_roi(face, roi_type, frame)
if roi is not None and roi.size > 0:
# ROI-level brightness & motion
try:
roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
except Exception:
roi_gray = None
if roi_gray is not None:
roi_brightness = float(roi_gray.mean())
if prev_roi_gray is not None and prev_roi_gray.shape == roi_gray.shape:
roi_diff = cv2.absdiff(roi_gray, prev_roi_gray)
roi_motion = float(roi_diff.mean())
prev_roi_gray = roi_gray
# existing green-channel mean for rPPG
try:
g_mean = float(roi[..., 1].astype(np.float32).mean())
raw_g_means.append(g_mean)
except Exception:
pass
face_norm = normalize_frame(roi, DEFAULT_SIZE)
frames_chw.append(face_norm)
if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
try:
clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
except Exception as e:
print(f"[infer] clip stack failed: {e}")
clip = None
bvp_out = None
if clip is not None:
clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
try:
raw = forward_bvp(model, clip_t)
if isinstance(raw, np.ndarray):
raw = np.nan_to_num(
raw, nan=0.0, posinf=0.0, neginf=0.0
).astype(np.float32, copy=False)
bvp_out = raw if raw.size > 0 else None
else:
bvp_out = None
except Exception as e:
print(f"[infer] forward_bvp error: {e}")
bvp_out = None
# Generate attention with Grad-CAM
try:
last_attention = extract_attention_map(model, clip_t, attention_viz)
except Exception as e:
print(f"⚠ Attention generation error: {e}")
last_attention = None
if bvp_out is None or bvp_out.size == 0:
gbuf = np.nan_to_num(
np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0
)
fb = _fallback_bvp_from_means(gbuf, fs=fps)
if isinstance(fb, np.ndarray) and fb.size > 0:
bvp_out = fb
else:
print("[infer] fallback produced empty output")
if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
tail = min(DEFAULT_STRIDE, bvp_out.size)
bvp_stream.extend(bvp_out[-tail:].tolist())
if len(bvp_stream) > MAX_SIGNAL_LENGTH:
bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
if len(bvp_stream) >= int(5 * fps):
seg = np.asarray(
bvp_stream[-int(10 * fps):], dtype=np.float32
)
_, last_bpm = postprocess_bvp(seg, fs=fps)
last_rmssd = compute_rmssd(seg, fs=fps)
if frame_idx % (DEFAULT_STRIDE * 2) == 0:
print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
else:
print("[infer] no usable bvp_out after fallback")
last_infer = frame_idx
else:
cv2.putText(
vis_frame,
"No face detected",
(20, 40),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
(30, 200, 255),
2,
)
# Log per-frame metrics into global FRAME_METRICS
try:
time_s = frame_idx / float(fps) if fps > 0 else float(frame_idx)
except Exception:
time_s = float(frame_idx)
FRAME_METRICS.append({
"frame_idx": int(frame_idx),
"time_s": float(time_s),
"global_brightness": float(global_brightness) if global_brightness is not None else None,
"global_motion": float(global_motion) if global_motion is not None else None,
"roi_brightness": float(roi_brightness) if roi_brightness is not None else None,
"roi_motion": float(roi_motion) if roi_motion is not None else None,
})
# Pretty strings for UI
gb_str = f"{global_brightness:.2f}" if global_brightness is not None else None
gm_str = f"{global_motion:.2f}" if global_motion is not None else None
rb_str = f"{roi_brightness:.2f}" if roi_brightness is not None else None
rm_str = f"{roi_motion:.2f}" if roi_motion is not None else None
if last_bpm > 0:
color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
cv2.putText(
vis_frame,
f"HR: {last_bpm:.1f} BPM",
(20, 48),
cv2.FONT_HERSHEY_SIMPLEX,
0.9,
color,
2,
)
vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
cv2.imwrite(str(frame_path), vis_frame)
cv2.imwrite(str(attention_path), vis_attention)
now = time.time()
if now >= next_display:
if len(bvp_stream) >= max(10, int(1 * fps)):
try:
signal_array = np.array(bvp_stream, dtype=float)
time_axis = np.arange(len(signal_array)) / fps
raw_signal = bandpass_filter(signal_array, fs=fps)
post_signal, _ = postprocess_bvp(signal_array, fs=fps)
raw_vis = raw_signal - np.mean(raw_signal)
post_vis = post_signal - np.mean(post_signal)
plt.figure(figsize=(10, 4), dpi=100)
plt.plot(time_axis, raw_vis, linewidth=1.6)
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
plt.grid(True, alpha=0.3)
plt.tight_layout(); plt.savefig(str(raw_path), dpi=100, bbox_inches='tight'); plt.close()
plt.figure(figsize=(10, 4), dpi=100)
plt.plot(time_axis, post_vis, linewidth=1.6)
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
plt.grid(True, alpha=0.3)
plt.tight_layout(); plt.savefig(str(post_path), dpi=100, bbox_inches='tight'); plt.close()
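                    # If ground-truth BVP is available, plot prediction vs. GT; otherwise (or if that fails) fall back to a two-panel raw/post figure.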
if gt_bvp is not None and gt_bvp.size > 0:
try:
gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
plot_signals_with_gt(
time_axis=time_axis,
raw_signal=raw_signal,
post_signal=post_signal,
fs=fps,
out_path=str(signal_path),
gt_time=gt_time,
gt_bvp=gt_bvp,
title=f"Pred vs GT — HR: {last_bpm:.1f} BPM"
)
except Exception:
fig = plt.figure(figsize=(12, 5), dpi=100)
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
else:
fig = plt.figure(figsize=(12, 5), dpi=100)
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
except Exception:
pass
elapsed = now - start_time
status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
# Periodic UI update: 14 outputs
yield (
status,
f"{last_bpm:.1f}" if last_bpm > 0 else None,
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
str(frame_path),
str(attention_path),
str(signal_path) if signal_path.exists() else None,
str(raw_path) if raw_path.exists() else None,
str(post_path) if post_path.exists() else None,
None, # CSV placeholder (filled at end)
gb_str, # global brightness
gm_str, # global motion
rb_str, # ROI brightness
rm_str # ROI motion
)
next_display = now + (1.0 / DISPLAY_FPS)
cap.release()
# Cleanup Grad-CAM
if attention_viz:
attention_viz.cleanup()
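    # Export the post-processed BVP trace as a downloadable CSV.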
csv_path = None
if bvp_stream:
csv_path = Path(tempfile.gettempdir()) / "bvp_output.csv"
time_array = np.arange(len(bvp_stream)) / fps
signal_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
try:
pd.DataFrame({'time_s': time_array, 'bvp': signal_final}).to_csv(csv_path, index=False)
except Exception:
csv_path = None
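    # Re-render the prediction-vs-GT overlay over the full recording; if it succeeds it replaces the rolling signal plot.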
try:
final_overlay = Path(tempfile.gettempdir()) / "signal_final_overlay.png"
if gt_bvp is not None and gt_bvp.size > 0:
gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
plot_signals_with_gt(
time_axis=time_array,
raw_signal=bandpass_filter(np.array(bvp_stream, dtype=float), fs=fps),
post_signal=signal_final,
fs=fps,
out_path=str(final_overlay),
gt_time=gt_time,
gt_bvp=gt_bvp,
title=f"Final Pred vs GT — HR: {last_bpm:.1f} BPM"
)
if final_overlay.exists():
signal_path = final_overlay
except Exception:
pass
# Save per-frame illumination & motion metrics
frame_metrics_path = None
if FRAME_METRICS:
try:
frame_metrics_path = tmpdir / "frame_metrics.csv"
pd.DataFrame(FRAME_METRICS).to_csv(frame_metrics_path, index=False)
print(f"[metrics] Saved per-frame illumination & motion to: {frame_metrics_path}")
except Exception as e:
print(f"[metrics] Failed to save frame metrics CSV: {e}")
frame_metrics_path = None
# Decide which CSV to expose: metrics preferred, else BVP
final_csv = frame_metrics_path or csv_path
elapsed = time.time() - start_time
final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
# Final yield: 14 outputs
yield (
final_status,
f"{last_bpm:.1f}" if last_bpm > 0 else None,
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
str(frame_path),
str(attention_path),
str(signal_path) if signal_path.exists() else None,
str(raw_path) if raw_path.exists() else None,
str(post_path) if post_path.exists() else None,
str(final_csv) if final_csv else None,
        None,  # global brightness readout is cleared when the run completes
        None,  # global motion readout cleared
        None,  # ROI brightness readout cleared
        None   # ROI motion readout cleared
)
def process_stream(
input_source: str,
video_path: Optional[str],
gt_file: Optional[str],
model_name: str,
fps_input: int,
max_seconds: int,
roi_type: str,
control_id: str
):
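    """Route to the live-webcam or video-file pipeline and yield its UI updates."""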
if input_source == "Live Webcam":
yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
else:
yield from process_video_file(
video_path, gt_file, model_name, fps_input,
max_seconds, roi_type, control_id
)
def pause_processing(control_id: str) -> str:
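    """Set the pause event for the given control ID."""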
_, controls = ensure_controls(control_id)
controls['pause'].set()
return "Paused"
def resume_processing(control_id: str) -> str:
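    """Clear the pause event so processing resumes."""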
_, controls = ensure_controls(control_id)
controls['pause'].clear()
return "Resumed"
def stop_processing(control_id: str) -> str:
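    """Set the stop event and clear any pause so the processing loop exits."""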
_, controls = ensure_controls(control_id)
controls['stop'].set()
controls['pause'].clear()
return "Stopped"
def reset_ui():
# 14 values in total, matching all outputs from run_processing / process_stream
return (
"Ready", # status_text
None, # hr_output
None, # gt_hr_output
None, # rmssd_output
None, # frame_output
None, # attention_output
None, # signal_output
None, # raw_signal_output
None, # post_signal_output
None, # csv_output
None, # global_brightness_output
None, # global_motion_output
None, # roi_brightness_output
None # roi_motion_output
)
def handle_folder_upload(files):
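    """Pick the first video file and a likely ground-truth file from an uploaded folder."""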
if not files:
return None, None, "No files uploaded"
if not isinstance(files, list):
files = [files]
video_path = None
for file_obj in files:
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
if file_path.suffix.lower() in VIDEO_EXTENSIONS:
video_path = str(file_path)
break
gt_path = None
gt_patterns = ['gtdump.txt', 'ground_truth.txt', 'gt.txt']
for file_obj in files:
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
if file_path.name.lower() in [p.lower() for p in gt_patterns]:
gt_path = str(file_path)
break
if not gt_path:
for file_obj in files:
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
if file_path.suffix.lower() in ['.txt', '.json', '.csv']:
if 'readme' not in file_path.name.lower():
gt_path = str(file_path)
break
status = []
if video_path:
status.append(f"Video: {Path(video_path).name}")
else:
status.append("No video found")
if gt_path:
status.append(f"GT: {Path(gt_path).name}")
else:
status.append("No ground truth")
return video_path, gt_path, " | ".join(status)
with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as demo:
gr.Markdown("# rPPG Analysis Tool with Attention Visualization")
with gr.Row():
input_source = gr.Radio(
choices=["Video File", "Subject Folder", "Live Webcam"],
value="Live Webcam",
label="Input Source"
)
with gr.Row():
target_layer_input = gr.Textbox(
label="Grad-CAM Target Layer (optional - leave empty for auto)",
placeholder="e.g., backbone.conv3, encoder.layer2",
value=""
)
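    # NOTE: target_layer_input is not wired into run_processing below, so the Grad-CAM layer is currently auto-detected.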
with gr.Row(visible=False) as video_inputs:
with gr.Column():
video_upload = gr.Video(label="Upload Video", sources=["upload"])
with gr.Column():
gt_upload = gr.File(
label="Ground Truth (optional)",
file_types=[".txt", ".csv", ".json"]
)
with gr.Row(visible=False) as folder_inputs:
with gr.Column():
folder_upload = gr.File(
label="Upload Subject Folder",
file_count="directory",
file_types=None,
type="filepath"
)
folder_status = gr.Textbox(label="Folder Status", interactive=False)
with gr.Row(visible=True) as webcam_inputs:
gr.Markdown("### Live Webcam - Click Run to start")
def toggle_input_source(source):
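        """Show only the input controls matching the selected source."""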
return [
gr.update(visible=(source == "Video File")),
gr.update(visible=(source == "Subject Folder")),
gr.update(visible=(source == "Live Webcam"))
]
input_source.change(
toggle_input_source,
inputs=[input_source],
outputs=[video_inputs, folder_inputs, webcam_inputs]
)
folder_video = gr.State(None)
folder_gt = gr.State(None)
folder_upload.upload(
handle_folder_upload,
inputs=[folder_upload],
outputs=[folder_video, folder_gt, folder_status]
)
with gr.Row():
with gr.Column(scale=3):
            _model_choices = scan_models()  # scan once instead of three separate calls
            model_dropdown = gr.Dropdown(
                choices=_model_choices,
                value=_model_choices[0] if _model_choices else None,
                label="PhysMamba Model",
                interactive=True
            )
with gr.Column(scale=1):
refresh_models_btn = gr.Button("Refresh", variant="secondary")
refresh_models_btn.click(
lambda: gr.update(choices=scan_models()),
inputs=None,
outputs=[model_dropdown]
)
with gr.Row():
fps_slider = gr.Slider(
minimum=10, maximum=120, value=30, step=5,
label="FPS"
)
max_seconds_slider = gr.Slider(
minimum=10, maximum=600, value=180, step=10,
label="Max Duration (s)"
)
with gr.Row():
roi_dropdown = gr.Dropdown(
choices=["auto", "forehead", "cheeks", "face"],
value="auto",
label="ROI"
)
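    # Session state: control_state carries the pause/stop control ID; placeholder_state is not referenced elsewhere in this file.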
control_state = gr.State(value="")
placeholder_state = gr.State(value=None)
with gr.Row():
run_btn = gr.Button("Run", variant="primary")
pause_btn = gr.Button("Pause", variant="secondary")
resume_btn = gr.Button("Resume", variant="secondary")
stop_btn = gr.Button("Stop", variant="stop")
status_text = gr.Textbox(label="Status", lines=2, value="Ready")
with gr.Row():
hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
    # Illumination & motion front-end readouts
with gr.Row():
global_brightness_output = gr.Textbox(
label="Global Brightness", interactive=False
)
global_motion_output = gr.Textbox(
label="Global Motion", interactive=False
)
roi_brightness_output = gr.Textbox(
label="ROI Brightness", interactive=False
)
roi_motion_output = gr.Textbox(
label="ROI Motion", interactive=False
)
with gr.Row():
with gr.Column():
frame_output = gr.Image(label="Video Feed", type="filepath")
with gr.Column():
attention_output = gr.Image(label="Attention Map", type="filepath")
with gr.Row():
signal_output = gr.Image(label="Signal Comparison", type="filepath")
with gr.Row():
with gr.Column():
raw_signal_output = gr.Image(label="Raw Signal", type="filepath")
with gr.Column():
post_signal_output = gr.Image(label="Post-processed Signal", type="filepath")
with gr.Row():
csv_output = gr.File(label="Download CSV")
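    # Control buttons: Pause/Resume/Stop act on the shared control events; Stop also resets all outputs.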
pause_btn.click(
pause_processing,
inputs=[control_state],
outputs=[status_text]
)
resume_btn.click(
resume_processing,
inputs=[control_state],
outputs=[status_text]
)
stop_btn.click(
stop_processing,
inputs=[control_state],
outputs=[status_text]
).then(
reset_ui,
inputs=None,
outputs=[
status_text,
hr_output,
gt_hr_output,
rmssd_output,
frame_output,
attention_output,
signal_output,
raw_signal_output,
post_signal_output,
csv_output,
global_brightness_output,
global_motion_output,
roi_brightness_output,
roi_motion_output,
]
)
def run_processing(
input_source,
video_upload,
gt_upload,
folder_video,
folder_gt,
model_name,
fps,
max_sec,
roi,
ctrl_id
):
"""Wrapper that resolves paths and streams from process_stream."""
if isinstance(model_name, int):
model_name = str(model_name)
if not model_name:
# must return 14 outputs to match UI wiring
yield (
"ERROR: No model selected",
None, None, None, # HR, GT HR, RMSSD
None, None, None, # frame, attention, signal
None, None, # raw, post
None, # CSV
None, None, None, None # brightness & motion fields
)
return
if input_source == "Video File":
video_path = _as_path(video_upload)
gt_file = _as_path(gt_upload)
elif input_source == "Subject Folder":
video_path = _as_path(folder_video)
gt_file = _as_path(folder_gt)
else: # Live Webcam
video_path, gt_file = None, None
# This yields 14-tuples from process_video_file / process_live_webcam
yield from process_stream(
input_source, video_path, gt_file,
model_name, fps, max_sec, roi, ctrl_id
)
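    # Run: stream every yield from run_processing into the 14 output components below.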
run_btn.click(
fn=run_processing,
inputs=[
input_source,
video_upload,
gt_upload,
folder_video,
folder_gt,
model_dropdown,
fps_slider,
max_seconds_slider,
roi_dropdown,
control_state
],
outputs=[
status_text,
hr_output,
gt_hr_output,
rmssd_output,
frame_output,
attention_output,
signal_output,
raw_signal_output,
post_signal_output,
csv_output,
global_brightness_output,
global_motion_output,
roi_brightness_output,
roi_motion_output,
]
)
if __name__ == "__main__":
demo.queue(max_size=10).launch(
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860)),
share=False,
)