# demo-space / app.py
import os
os.environ.setdefault("OPENCV_AVFOUNDATION_SKIP_AUTH", "1")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import warnings
warnings.filterwarnings("ignore")
import re
import io
import json
import tempfile
import time
import threading
import uuid
import importlib.util
from pathlib import Path
from collections import deque
from typing import Optional, Tuple, Dict, List
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import signal
from scipy.signal import find_peaks, welch, get_window
# .mat support (SCAMPS)
try:
from scipy.io import loadmat
MAT_SUPPORT = True
except Exception:
MAT_SUPPORT = False
# Matplotlib headless backend
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import gradio as gr
class PhysMambaattention_viz:
"""Simplified Grad-CAM for PhysMamba."""
def __init__(self, model: nn.Module, device: torch.device):
self.model = model
self.device = device
self.activations = None
self.gradients = None
self.hook_handles = []
def _get_target_layer(self):
"""Auto-detect the best layer for visualization, excluding Mamba/SSM layers."""
# Strategy 1: Look for temporal difference convolution
for name, module in reversed(list(self.model.named_modules())):
if ('tdc' in name.lower() or 'temporal_diff' in name.lower()) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
print(f"[attention_viz] Selected TDC layer: {name} ({type(module).__name__})")
return name, module
# Strategy 2: Look for 3D convolutions (temporal-spatial)
for name, module in reversed(list(self.model.named_modules())):
if isinstance(module, nn.Conv3d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
print(f"[attention_viz] Selected Conv3d layer: {name}")
return name, module
# Strategy 3: Look for 2D convolutions (spatial)
for name, module in reversed(list(self.model.named_modules())):
if isinstance(module, nn.Conv2d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
print(f"[attention_viz] Selected Conv2d layer: {name}")
return name, module
# Strategy 4: Look for any convolution (last resort)
for name, module in reversed(list(self.model.named_modules())):
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
print(f"[attention_viz] Selected Conv layer (fallback): {name} ({type(module).__name__})")
return name, module
# Strategy 5: Print all available layers and pick the last non-Mamba one
print("\n[attention_viz] Available layers:")
suitable_layers = []
for name, module in self.model.named_modules():
layer_type = type(module).__name__
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
is_mamba = 'mamba' in name.lower() or 'ssm' in name.lower()
print(f" - {name}: {layer_type} {'(MAMBA - SKIP)' if is_mamba else ''}")
if not is_mamba:
suitable_layers.append((name, module))
if suitable_layers:
name, module = suitable_layers[-1]
print(f"\n[attention_viz] Selected last suitable layer: {name} ({type(module).__name__})")
return name, module
raise ValueError("No suitable layer found for Grad-CAM (all layers are Mamba/SSM)")
def _forward_hook(self, module, input, output):
self.activations = output.detach()
def _backward_hook(self, module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
def register_hooks(self, target_layer_name: Optional[str] = None):
"""Register hooks on target layer"""
self._remove_hooks()
if target_layer_name and target_layer_name.strip():
# Manual selection
try:
target_module = dict(self.model.named_modules())[target_layer_name.strip()]
print(f"Using manually specified layer: {target_layer_name}")
except KeyError:
print(f"⚠ Layer '{target_layer_name}' not found, falling back to auto-detection")
target_layer_name, target_module = self._get_target_layer()
else:
# Auto-detection
target_layer_name, target_module = self._get_target_layer()
fwd_handle = target_module.register_forward_hook(self._forward_hook)
bwd_handle = target_module.register_full_backward_hook(self._backward_hook)
self.hook_handles = [fwd_handle, bwd_handle]
def _remove_hooks(self):
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
self.activations = None
self.gradients = None
def generate(self, input_tensor: torch.Tensor) -> np.ndarray:
"""Generate Grad-CAM heatmap with improved gradient handling."""
# Set model to eval but keep gradient computation
self.model.eval()
# Ensure tensor requires grad
input_tensor = input_tensor.requires_grad_(True)
# Forward pass
output = self.model(input_tensor)
# Handle different output types
if isinstance(output, dict):
# Try common keys
for key in ('pred', 'output', 'bvp', 'logits', 'out'):
if key in output and isinstance(output[key], torch.Tensor):
output = output[key]
break
# Get scalar for backward
if output.numel() == 1:
target = output
else:
# Use mean of output
target = output.mean()
print(f"[attention_viz] Output shape: {output.shape}, Target: {target.item():.4f}")
# Backward pass
self.model.zero_grad()
target.backward(retain_graph=False)
# Check if we got gradients
if self.activations is None:
print("⚠ No activations captured!")
return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
if self.gradients is None:
print("⚠ No gradients captured!")
return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
print(f"[attention_viz] Activations shape: {self.activations.shape}")
print(f"[attention_viz] Gradients shape: {self.gradients.shape}")
activations = self.activations
gradients = self.gradients
# Compute weights (global average pooling of gradients)
if gradients.dim() == 5: # [B, C, T, H, W]
weights = gradients.mean(dim=[2, 3, 4], keepdim=True)
elif gradients.dim() == 4: # [B, C, H, W]
weights = gradients.mean(dim=[2, 3], keepdim=True)
elif gradients.dim() == 3: # [B, C, T]
weights = gradients.mean(dim=2, keepdim=True).unsqueeze(-1).unsqueeze(-1)
else:
print(f"⚠ Unexpected gradient dimensions: {gradients.dim()}")
return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
# Weighted combination
cam = (weights * activations).sum(dim=1, keepdim=True)
# If 5D, average over time
if cam.dim() == 5:
cam = cam.mean(dim=2)
# ReLU
cam = F.relu(cam)
# Convert to numpy
cam = cam.squeeze().cpu().detach().numpy()
# Normalize
if cam.max() > cam.min():
cam = (cam - cam.min()) / (cam.max() - cam.min())
print(f"[attention_viz] Final heatmap shape: {cam.shape}, range: [{cam.min():.3f}, {cam.max():.3f}]")
return cam
def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
"""Overlay heatmap on frame."""
h, w = frame.shape[:2]
heatmap_resized = cv2.resize(heatmap, (w, h))
heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
return overlay
def cleanup(self):
self._remove_hooks()
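# Usage sketch (hypothetical; the demo path below sets attention_viz = None in
# load_physmamba_model, so this class is not wired in by default):
#   viz = PhysMambaattention_viz(model, DEVICE)
#   viz.register_hooks()                   # auto-detects a non-Mamba conv layer
#   heatmap = viz.generate(clip_t)         # clip_t: [B, C, T, H, W], gradients enabled
#   overlay = viz.visualize(heatmap, frame_bgr, alpha=0.4)
#   viz.cleanup()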
def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
"""Apply DiffNormalized preprocessing from PhysMamba paper."""
if len(frames) < 2:
return np.zeros((len(frames), *frames[0].shape), dtype=np.float32)
diff_frames = []
for i in range(len(frames)):
if i == 0:
diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
else:
curr = frames[i].astype(np.float32)
prev = frames[i-1].astype(np.float32)
denominator = curr + prev + 1e-8
diff = (curr - prev) / denominator
diff_frames.append(diff)
diff_array = np.stack(diff_frames)
std = diff_array.std()
if std > 0:
diff_array = diff_array / std
return diff_array
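# DiffNormalized (as implemented above): d_t = (x_t - x_{t-1}) / (x_t + x_{t-1} + eps),
# with d_0 = 0, then the whole sequence is divided by its global standard deviation.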
def preprocess_for_physmamba(frames: List[np.ndarray],
target_frames: int = 128,
target_size: int = 128) -> torch.Tensor:
"""Complete preprocessing pipeline for PhysMamba model."""
if len(frames) < target_frames:
frames = frames + [frames[-1]] * (target_frames - len(frames))
elif len(frames) > target_frames:
indices = np.linspace(0, len(frames)-1, target_frames).astype(int)
frames = [frames[i] for i in indices]
frames_rgb = [f[..., ::-1].copy() for f in frames]
frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
frames_diff = apply_diff_normalized(frames_resized)
frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2))
frames_batched = np.expand_dims(frames_transposed, axis=0)
tensor = torch.from_numpy(frames_batched.astype(np.float32))
return tensor
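# Example (sketch): with `frames` a list of BGR uint8 arrays,
#   x = preprocess_for_physmamba(frames)    # -> torch.FloatTensor
#   x.shape == (1, 3, 128, 128, 128)        # [B, C, T, H, W]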
HERE = Path(__file__).parent
MODEL_DIR = HERE / "final_model_release"
LOG_DIR = HERE / "logs"
ANALYSIS_DIR = HERE / "analysis"
for d in [MODEL_DIR, LOG_DIR, ANALYSIS_DIR]:
d.mkdir(exist_ok=True)
DEVICE = (
torch.device("cuda") if torch.cuda.is_available()
else torch.device("mps") if torch.backends.mps.is_available()
else torch.device("cpu")
)
FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
DEFAULT_SIZE = 128 # input H=W to model
DEFAULT_T = 128 # clip length
DEFAULT_STRIDE = 8 # inference hop
DISPLAY_FPS = 10
VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']
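# HR post-processing state and priors (used by postprocess_bvp below):
# _HR_SMOOTH keeps the last smoothed HR across calls, REST_HR_* gently pull
# implausible estimates toward a resting band, MAX_JUMP_BPM caps per-update jumps.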
_HR_SMOOTH = None
REST_HR_TARGET = 72.0
REST_HR_RANGE = (55.0, 95.0)
MAX_JUMP_BPM = 8.0
# Recognized GT files for subject folders
GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
GT_EXTS = {".txt", ".csv", ".json"}
def _as_path(maybe) -> Optional[str]:
"""Return a filesystem path from Gradio values (str, dict, tempfile objects)."""
if maybe is None:
return None
if isinstance(maybe, str):
return maybe
if isinstance(maybe, dict):
return maybe.get("name") or maybe.get("path")
name = getattr(maybe, "name", None) # tempfile-like object
if isinstance(name, str):
return name
try:
return str(maybe)
except Exception:
return None
def _import_from_file(py_path: Path):
spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
if not spec or not spec.loader:
raise ImportError(f"Cannot import module from {py_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _looks_like_video(p: Path) -> bool:
if p.suffix.lower() == ".mat":
return True
return p.suffix.lower() in VIDEO_EXTENSIONS
class SimpleActivationAttention:
"""Lightweight attention visualization without gradients."""
def __init__(self, model: nn.Module, device: torch.device):
self.model = model
self.device = device
self.activations = None
self.hook_handle = None
def _activation_hook(self, module, input, output):
"""Capture activations during forward pass."""
self.activations = output.detach()
def register_hook(self):
"""Register hook on a suitable layer."""
        # Keep the last non-Mamba/SSM Conv2d/Conv3d found while iterating (i.e., the deepest suitable conv)
target = None
target_name = None
for name, module in self.model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
target = module
target_name = name
if target is None:
print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
return
self.hook_handle = target.register_forward_hook(self._activation_hook)
print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")
def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
"""Generate attention map from activations (call after forward pass)."""
try:
if self.activations is None:
return None
# Process activations to create spatial attention
act = self.activations
# Handle different tensor shapes
if act.dim() == 5: # [B, C, T, H, W]
# Average over time and channels
attention = act.mean(dim=[1, 2]) # -> [B, H, W]
elif act.dim() == 4: # [B, C, H, W]
attention = act.mean(dim=1) # -> [B, H, W]
else:
print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
return None
# Convert to numpy
attention = attention.squeeze().cpu().numpy()
# Normalize to [0, 1]
if attention.max() > attention.min():
attention = (attention - attention.min()) / (attention.max() - attention.min())
return attention
except Exception as e:
print(f"⚠ [attention_viz] Generation failed: {e}")
return None
def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
"""Overlay heatmap on frame."""
h, w = frame.shape[:2]
heatmap_resized = cv2.resize(heatmap, (w, h))
heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
return overlay
def cleanup(self):
if self.hook_handle is not None:
self.hook_handle.remove()
class VideoReader:
"""
Unified frame reader:
• Regular videos via cv2.VideoCapture
• .mat 'videos' (e.g., SCAMPS): expects array like (T,H,W,3) or (H,W,T,3) or (T,H,W)
Returns frames as BGR uint8.
"""
def __init__(self, path: str):
self.path = str(path)
self._cap = None
self._mat = None
self._idx = 0
self._len = 0
self._shape = None
if self.path.lower().endswith(".mat") and MAT_SUPPORT:
self._open_mat(self.path)
else:
self._open_cv(self.path)
def _open_cv(self, path: str):
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise RuntimeError("Cannot open video")
self._cap = cap
self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
def _open_mat(self, path: str):
try:
md = loadmat(path)
# Common keys in SCAMPS-like dumps
for key in ("video", "frames", "vid", "data"):
if key in md and isinstance(md[key], np.ndarray):
arr = md[key]
break
else:
arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
if arr is None:
raise RuntimeError("No ndarray found in .mat")
a = np.asarray(arr)
# Normalize to (T,H,W,3)
if a.ndim == 4:
if a.shape[-1] == 3:
if a.shape[0] < a.shape[2]: # (T,H,W,3) heuristic
v = a
else: # (H,W,T,3) -> (T,H,W,3)
v = np.transpose(a, (2, 0, 1, 3))
else:
v = a[..., :1] # take first channel
elif a.ndim == 3:
if a.shape[0] < a.shape[2]: # (T,H,W)
v = a
else: # (H,W,T) -> (T,H,W)
v = np.transpose(a, (2, 0, 1))
v = v[..., None]
else:
raise RuntimeError(f"Unsupported .mat video shape: {a.shape}")
v = np.ascontiguousarray(v)
if v.shape[-1] == 1:
v = np.repeat(v, 3, axis=-1)
v = v.astype(np.uint8)
self._mat = v
self._len = v.shape[0]
self._shape = v.shape[1:3]
except Exception as e:
raise RuntimeError(f"Failed to open .mat video: {e}")
def read(self):
"""Return (ret, frame[BGR]) like cv2.VideoCapture.read()."""
if self._mat is not None:
if self._idx >= self._len:
return False, None
frame = self._mat[self._idx]
self._idx += 1
return True, frame
else:
return self._cap.read()
def fps(self, fallback: int = 30) -> int:
if self._mat is not None:
return fallback # .mat typically lacks FPS; caller can override
f = self._cap.get(cv2.CAP_PROP_FPS)
return int(f) if f and f > 0 else fallback
def length(self) -> int:
return self._len
def release(self):
if self._cap is not None:
self._cap.release()
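# Usage sketch:
#   vr = VideoReader(path)             # regular video or SCAMPS-style .mat
#   ok, frame_bgr = vr.read()          # mirrors cv2.VideoCapture.read()
#   fs = vr.fps(fallback=30)           # .mat files fall back to the given FPS
#   vr.release()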
def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
x, y, w, h = face
# forehead
fh = frame[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)]
# cheeks
ck = frame[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)]
# full face
ff = frame[y:y + h, x:x + w]
return {"forehead": fh, "cheeks": ck, "face": ff}
def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
if patch is None or patch.size == 0:
return -1e9
g = patch[..., 1].astype(np.float32) / 255.0 # green channel
g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling
g = g - g.mean()
b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
try:
y = signal.filtfilt(b, a, g, method="gust")
except Exception:
y = g
return float((y ** 2).mean())
def pick_auto_roi(face: Tuple[int, int, int, int],
frame: np.ndarray,
attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
"""Simple ROI selection."""
cands = roi_candidates(face, frame)
scores = {k: roi_quality_score(v) for k, v in cands.items()}
if attn is not None and attn.size:
H, W = frame.shape[:2]
try:
attn_resized = cv2.resize(attn, (W, H))
x, y, w, h = face
fh_attn = attn_resized[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)].mean() if attn_resized.size > 0 else 0.0
ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
scores['forehead'] += fh_attn * 0.2
scores['cheeks'] += ck_attn * 0.2
scores['face'] += ff_attn * 0.2
except Exception:
pass
best = max(scores, key=scores.get)
return cands[best], best
def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
"""
Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
Heuristics:
- Subject folder: any directory containing at least one video-like file (.mat or common video).
- If multiple videos, pick the largest by size.
- GT file: prefer known names in GT_FILENAMES, else any .txt/.csv/.json not named readme.
"""
pairs: List[Tuple[str, Optional[str]]] = []
if not root_dir.exists():
return pairs
def pick_pair(folder: Path) -> Optional[Tuple[str, Optional[str]]]:
vids = [p for p in folder.rglob("*") if p.is_file() and _looks_like_video(p)]
if not vids:
return None
vids.sort(key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
video = vids[0]
gt: Optional[Path] = None
for p in folder.rglob("*"):
if p.is_file() and p.name.lower() in GT_FILENAMES:
gt = p
break
if gt is None:
cands = [
p for p in folder.rglob("*")
if p.is_file() and p.suffix.lower() in GT_EXTS and "readme" not in p.name.lower()
]
if cands:
gt = cands[0]
return str(video), (str(gt) if gt else None)
subs = [d for d in root_dir.iterdir() if d.is_dir()]
if subs:
for sub in subs:
pair = pick_pair(sub)
if pair:
pairs.append(pair)
else:
pair = pick_pair(root_dir) # the root itself might be a single-subject folder
if pair:
pairs.append(pair)
# Deduplicate
seen = set()
uniq: List[Tuple[str, Optional[str]]] = []
for v, g in pairs:
key = (v, g or "")
if key not in seen:
seen.add(key)
uniq.append((v, g))
return uniq
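# Example (sketch; the dataset root path is purely illustrative):
#   pairs = discover_subjects(Path("/data/UBFC"))
#   -> [("/data/UBFC/subject1/vid.avi", "/data/UBFC/subject1/ground_truth.txt"), ...]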
def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
import inspect
if model_file:
model_path = (repo_root / model_file).resolve()
if model_path.exists():
try:
mod = _import_from_file(model_path)
if hasattr(mod, model_class):
return getattr(mod, model_class)
except Exception:
pass
search_dirs = [
repo_root / "neural_methods" / "model",
repo_root / "neural_methods",
repo_root
]
name_pattern = re.compile(r"mamba", re.IGNORECASE)
for base_dir in search_dirs:
if not base_dir.exists():
continue
for py_file in base_dir.rglob("*.py"):
if "__pycache__" in str(py_file) or "mamba_ssm" in str(py_file):
continue
try:
mod = _import_from_file(py_file)
for name, obj in inspect.getmembers(mod):
if callable(obj) and name_pattern.search(name) and "ssm" not in name.lower():
if inspect.isclass(obj) and issubclass(obj, nn.Module):
return obj
except Exception:
continue
    raise ImportError("Could not find a PhysMamba model class under neural_methods/ or the repo root")
def load_physmamba_model(ckpt_path: Path, device: torch.device,
model_file: str = "", model_class: str = "PhysMamba"):
repo_root = Path(".").resolve()
Builder = find_physmamba_builder(repo_root, model_file, model_class)
import inspect
ctor_trials = [
{},
{"d_model": 96},
{"dim": 96},
{"d_model": 96, "frames": 128, "img_size": 128, "in_chans": 3},
{"frames": 128, "img_size": 128, "in_chans": 3},
{"frame_depth": 3},
]
model = None
for kwargs in ctor_trials:
try:
            candidate = Builder(**kwargs)  # Builder may be a class or a factory callable
if isinstance(candidate, nn.Module):
model = candidate
break
except Exception:
continue
if model is None:
raise RuntimeError("Could not construct PhysMamba model")
try:
checkpoint = torch.load(str(ckpt_path), map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)
try:
model.load_state_dict(state_dict, strict=False)
except Exception:
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
except Exception:
pass
model.to(device).eval()
try:
with torch.no_grad():
_ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
except Exception:
pass
# Disable attention visualization since model forward pass is incompatible
attention_viz = None
return model, attention_viz
def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
"""
Stable band-pass with edge-safety and parameter clipping.
"""
x = np.asarray(x, dtype=float)
n = int(fs * 2)
if x.size < max(n, 8): # need at least ~2s
return x
nyq = 0.5 * fs
lo = max(low / nyq, 1e-6)
hi = min(high / nyq, 0.999_999)
if not (0.0 < lo < hi < 1.0):
return x
try:
b, a = signal.butter(order, [lo, hi], btype="band")
# padlen must be < len(x); reduce when short
padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
return signal.filtfilt(b, a, x, padlen=padlen)
except Exception:
return x
def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
"""
HR (BPM) via Welch PSD peak in [lo, hi] Hz.
"""
x = np.asarray(x, dtype=float)
if x.size < int(fs * 4.0): # need ~4s for a usable PSD
return 0.0
try:
# nperseg tuned for short windows while avoiding tiny segments
nper = int(min(max(64, fs * 2), min(512, x.size)))
f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
mask = (f >= lo) & (f <= hi)
if not np.any(mask):
return 0.0
f_band = f[mask]
p_band = pxx[mask]
if p_band.size == 0 or np.all(~np.isfinite(p_band)):
return 0.0
fpk = float(f_band[np.argmax(p_band)])
bpm = fpk * 60.0
# clip to plausible range
return float(np.clip(bpm, 30.0, 220.0))
except Exception:
return 0.0
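# Note: with the default 0.7-3.5 Hz band, the detectable HR range is 42-210 BPM
# (bpm = f_peak * 60); the final clip to 30-220 BPM only guards against outliers.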
def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
"""
HRV RMSSD from peaks; robust to short/flat segments.
"""
x = np.asarray(x, dtype=float)
if x.size < int(fs * 5.0):
return 0.0
try:
# peak distance ~ 0.5s minimum (avoid double counting)
peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
if len(peaks) < 3:
return 0.0
rr = np.diff(peaks) / fs * 1000.0 # ms
if rr.size < 2:
return 0.0
return float(np.sqrt(np.mean(np.diff(rr) ** 2)))
except Exception:
return 0.0
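# RMSSD = sqrt(mean(diff(RR)^2)) with RR intervals in milliseconds, derived from
# peak-to-peak spacing; at least ~5 s of signal and 3 detected peaks are required above.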
def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
"""
Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
Signature unchanged to avoid breaking callers.
"""
global _HR_SMOOTH
y = np.asarray(pred, dtype=float)
if y.size == 0:
return y, 0.0
# 1) band-limit
y_filt = bandpass_filter(y, fs=fs, low=0.7, high=3.5, order=4)
# 2) HR estimate
hr = hr_from_welch(y_filt, fs=fs, lo=0.7, hi=3.5)
# 3) gentle attraction to resting band (if way off)
if hr > 0:
lo, hi = REST_HR_RANGE
if hr < lo or hr > hi:
dist = abs(hr - REST_HR_TARGET)
# farther away -> stronger pull
alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
# 4) temporal smoothing to limit frame-to-frame jumps
if hr > 0:
if _HR_SMOOTH is None:
_HR_SMOOTH = hr
else:
step = float(np.clip(hr - _HR_SMOOTH, -MAX_JUMP_BPM, MAX_JUMP_BPM))
_HR_SMOOTH = _HR_SMOOTH + 0.6 * step
hr = float(_HR_SMOOTH)
return y_filt, float(hr)
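# Caveat: when the Welch estimate falls outside REST_HR_RANGE it is pulled toward
# REST_HR_TARGET (72 BPM), and _HR_SMOOTH limits per-update changes to MAX_JUMP_BPM;
# the processing entry points reset _HR_SMOOTH = None at the start of each run.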
def draw_face_and_roi(frame_bgr: np.ndarray,
face_bbox: Optional[Tuple[int, int, int, int]],
roi_bbox: Optional[Tuple[int, int, int, int]],
label: str = "ROI") -> np.ndarray:
"""
Draw face (green) and ROI (cyan) rectangles on a copy of the frame.
"""
vis = frame_bgr.copy()
if face_bbox is not None:
x, y, w, h = face_bbox
cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 230, 0), 2)
cv2.putText(vis, "FACE", (x, max(20, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 230, 0), 2)
if roi_bbox is not None:
rx, ry, rw, rh = roi_bbox
cv2.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (255, 220, 0), 2)
cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
return vis
def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
roi_type: str,
frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
"""
    Compute the ROI rectangle (x, y, w, h) from a face bbox and the selected ROI type.
    Matches the geometry used by crop_roi.
"""
x, y, w, h = face_bbox
H, W = frame_shape[:2]
if roi_type == "forehead":
rx = int(x + 0.25 * w); rw = int(0.5 * w)
ry = int(y + 0.10 * h); rh = int(0.20 * h)
elif roi_type == "cheeks":
rx = int(x + 0.15 * w); rw = int(0.70 * w)
ry = int(y + 0.55 * h); rh = int(0.30 * h)
else:
rx, ry, rw, rh = x, y, w, h
rx2 = min(W, rx + rw)
ry2 = min(H, ry + rh)
rx = max(0, rx); ry = max(0, ry)
if rx2 <= rx or ry2 <= ry:
return (0, 0, 0, 0)
return (rx, ry, rx2 - rx, ry2 - ry)
def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
"""
Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
Assumes chw = (3, H, W) with zero-mean, unit-var normalization per-frame.
"""
if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
return np.zeros((128, 128, 3), dtype=np.uint8)
# Undo channel-first & normalization to a viewable image
img = chw.copy()
# Re-normalize to 0..1 by min-max of the tensor to "show" contrast
vmin, vmax = float(img.min()), float(img.max())
if vmax <= vmin + 1e-6:
img = np.zeros_like(img)
else:
img = (img - vmin) / (vmax - vmin)
img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR
return img
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
if gt_len <= 1:
return None
if gt_fs and gt_fs > 0:
return np.arange(gt_len, dtype=float) / float(gt_fs)
return None # will fall back to length-matching overlay
def plot_signals_with_gt(time_axis: np.ndarray,
raw_signal: np.ndarray,
post_signal: np.ndarray,
fs: int,
out_path: str,
gt_time: Optional[np.ndarray] = None,
gt_bvp: Optional[np.ndarray] = None,
title: str = "rPPG Signals (Pred vs GT)") -> str:
"""
Save a 3-pane plot: (1) predicted raw, (2) predicted post, (3) overlay Pred vs GT (normalized).
If GT is provided, it is resampled to the prediction time grid and lag-aligned
(±5 s) using cross-correlation. The overlay includes Pearson r, lag, and HR stats.
"""
import numpy as _np
import matplotlib.pyplot as _plt
from matplotlib.gridspec import GridSpec as _GridSpec
def z(x):
x = _np.asarray(x, dtype=float)
if x.size == 0:
return x
m = float(_np.nanmean(x))
s = float(_np.nanstd(x)) + 1e-8
return (x - m) / s
def _bandpass(x, fs_local, lo=0.7, hi=3.5, order=4):
try:
return bandpass_filter(_np.asarray(x, float), fs=fs_local, low=lo, high=hi, order=order)
except Exception:
return _np.asarray(x, float)
def _welch_hr(x, fs_local):
try:
return float(hr_from_welch(_np.asarray(x, float), fs=fs_local, lo=0.7, hi=3.5))
except Exception:
return 0.0
def _safe_interp(x_t, y, t_new):
"""Monotonic-safe 1D interpolation with clipping to valid domain."""
x_t = _np.asarray(x_t, dtype=float).ravel()
y = _np.asarray(y, dtype=float).ravel()
t_new = _np.asarray(t_new, dtype=float).ravel()
if x_t.size < 2 or y.size != x_t.size:
# Fallback: length-based resize to t_new length
if y.size == 0 or t_new.size == 0:
return _np.zeros_like(t_new)
            src_idx = _np.arange(y.size, dtype=float)
            new_idx = _np.linspace(0, y.size - 1, num=t_new.size)
            return _np.interp(new_idx, src_idx, y)
# Enforce strictly increasing time (dedup if needed)
order = _np.argsort(x_t)
x_t = x_t[order]
y = y[order]
mask = _np.concatenate(([True], _np.diff(x_t) > 0))
x_t = x_t[mask]
y = y[mask]
# Clip t_new to the valid domain to avoid edge extrapolation artifacts
t_clip = _np.clip(t_new, x_t[0], x_t[-1])
return _np.interp(t_clip, x_t, y)
def _best_lag(pred, gt, fs_local, max_lag_s=5.0):
"""Return lag (sec) that best aligns GT to Pred using cross-correlation on z-scored signals."""
x = z(pred); y = z(gt)
if x.size < 8 or y.size < 8:
return 0.0
n = int(min(len(x), len(y)))
x = x[:n]; y = y[:n]
max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
# valid lags: negative means GT should be shifted left (advance) relative to Pred
lags = _np.arange(-max_lag, max_lag + 1)
# compute correlation for each lag
best_corr = -_np.inf
best_lag = 0
for L in lags:
if L < 0:
xx = x[-L:n]
yy = y[0:n+L]
elif L > 0:
xx = x[0:n-L]
yy = y[L:n]
else:
xx = x
yy = y
if xx.size < 8 or yy.size < 8:
continue
c = _np.corrcoef(xx, yy)[0, 1]
if _np.isfinite(c) and c > best_corr:
best_corr = c
best_lag = L
return float(best_lag / float(fs_local))
def _apply_lag(y, lag_sec, fs_local):
"""Shift y by lag_sec (positive => delay GT) using sample roll; edges set to NaN."""
y = _np.asarray(y, float)
if y.size == 0 or fs_local <= 0:
return y
shift = int(round(lag_sec * fs_local))
if shift == 0:
return y
out = _np.empty_like(y)
out[:] = _np.nan
if shift > 0:
# delay: move content right
out[shift:] = y[:-shift]
else:
# advance: move content left
out[:shift] = y[-shift:]
return out
t = _np.asarray(time_axis, dtype=float)
raw = _np.asarray(raw_signal, dtype=float)
post = _np.asarray(post_signal, dtype=float)
# guard
if t.size == 0:
t = _np.arange(post.size, dtype=float) / max(fs, 1)
have_gt = gt_bvp is not None and _np.asarray(gt_bvp).size > 0
gt_on_pred = None
lag_sec = 0.0
pearson_r = _np.nan
hr_pred = _welch_hr(_bandpass(post, fs), fs)
hr_gt = 0.0
if have_gt:
gt = _np.asarray(gt_bvp, dtype=float).ravel()
if gt_time is not None and _np.asarray(gt_time).size == gt.size:
gt_t = _np.asarray(gt_time, dtype=float).ravel()
gt_on_pred = _safe_interp(gt_t, gt, t)
else:
# No time vector: try length-based mapping to pred time grid
gt_on_pred = _safe_interp(_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
gt, t)
# Band-limit both before correlation/HR
pred_bp = _bandpass(post, fs)
gt_bp = _bandpass(gt_on_pred, fs)
# Estimate best lag (sec) of GT relative to Pred
lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
# Apply lag to GT for visualization and correlation
gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
# Compute Pearson r on overlapping valid samples
valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
if valid.sum() >= 16:
pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
else:
pearson_r = _np.nan
hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
_plt.figure(figsize=(13, 6), dpi=110)
gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
# (1) Raw Pred
ax1 = _plt.subplot(gs[0, 0])
ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
ax1.grid(True, alpha=0.3)
# (2) Post Pred
ax2 = _plt.subplot(gs[0, 1])
ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
ax2.set_title("Predicted (Post-processed)")
ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
ax2.grid(True, alpha=0.3)
# (3) Overlay Pred vs GT (z-scored) OR just post
ax3 = _plt.subplot(gs[1, :])
ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
if have_gt and gt_on_pred is not None:
gt_bp = _bandpass(gt_on_pred, fs)
gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
# metrics box
txt = [
f"HR_pred: {hr_pred:.1f} BPM",
f"HR_gt: {hr_gt:.1f} BPM",
f"Pearson r: {pearson_r:.3f}" if _np.isfinite(pearson_r) else "Pearson r: --",
f"Lag: {lag_sec:+.2f} s"
]
ax3.text(0.01, 0.98, "\n".join(txt), transform=ax3.transAxes,
va="top", ha="left", fontsize=9,
bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="0.8", alpha=0.9))
ax3.set_title("Pred vs GT (z-score overlay, lag-aligned)")
else:
ax3.set_title("Pred vs GT (no GT provided)")
ax3.set_xlabel("Time (s)"); ax3.set_ylabel("z")
ax3.grid(True, alpha=0.3)
ax3.legend(loc="upper right")
_plt.suptitle(title, fontweight="bold")
_plt.tight_layout(rect=[0, 0.02, 1, 0.97])
_plt.savefig(out_path, bbox_inches="tight")
_plt.close()
return out_path
def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
"""
Robust single-face detector with a few practical guards:
- converts to gray safely
- equalizes histogram (helps underexposure)
- tries multiple scales / minNeighbors
- returns the largest face bbox (x,y,w,h) or None
"""
if frame is None or frame.size == 0:
return None
try:
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
except Exception:
# If color conversion fails, assume already gray
gray = frame.copy() if frame.ndim == 2 else cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
# Light preproc to improve Haar performance
gray = cv2.equalizeHist(gray)
faces_all = []
# Try a couple of parameter combos to be more forgiving
params = [
dict(scaleFactor=1.05, minNeighbors=3),
dict(scaleFactor=1.10, minNeighbors=4),
dict(scaleFactor=1.20, minNeighbors=5),
]
for p in params:
try:
faces = FACE_CASCADE.detectMultiScale(gray, **p)
if faces is not None and len(faces) > 0:
faces_all.extend([tuple(map(int, f)) for f in faces])
except Exception:
continue
if not faces_all:
return None
# Return the largest (by area)
return max(faces_all, key=lambda f: f[2] * f[3])
def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
"""
Crop ROI from the frame based on a face bbox and the selected roi_type.
Returns the cropped BGR ROI or None if invalid.
"""
if face_bbox is None or frame is None or frame.size == 0:
return None
x, y, w, h = map(int, face_bbox)
H, W = frame.shape[:2]
if roi_type == "forehead":
rx = int(x + 0.25 * w); rw = int(0.50 * w)
ry = int(y + 0.10 * h); rh = int(0.20 * h)
elif roi_type == "cheeks":
rx = int(x + 0.15 * w); rw = int(0.70 * w)
ry = int(y + 0.55 * h); rh = int(0.30 * h)
else:
rx, ry, rw, rh = x, y, w, h
# clamp in-bounds
rx = max(0, rx); ry = max(0, ry)
rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
if rx2 <= rx or ry2 <= ry:
return None
roi = frame[ry:ry2, rx:rx2]
# Avoid empty or 1-pixel slivers
if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
return None
return roi
def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
roi_type: str,
frame: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
if face_bbox is None or frame is None or frame.size == 0:
return None, None
x, y, w, h = map(int, face_bbox)
H, W = frame.shape[:2]
if roi_type == "forehead":
rx = int(x + 0.25 * w); rw = int(0.50 * w)
ry = int(y + 0.10 * h); rh = int(0.20 * h)
elif roi_type == "cheeks":
rx = int(x + 0.15 * w); rw = int(0.70 * w)
ry = int(y + 0.55 * h); rh = int(0.30 * h)
else:
rx, ry, rw, rh = x, y, w, h
rx = max(0, rx); ry = max(0, ry)
rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
if rx2 <= rx or ry2 <= ry:
return None, None
roi = frame[ry:ry2, rx:rx2]
if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
return None, None
return roi, (rx, ry, rx2 - rx, ry2 - ry)
def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
"""
PhysMamba-compatible normalization with DiffNormalized support.
Returns (3, size, size).
"""
if face_bgr is None or face_bgr.size == 0:
return np.zeros((3, size, size), dtype=np.float32)
try:
face = cv2.resize(face_bgr, (size, size), interpolation=cv2.INTER_AREA)
except Exception:
face = cv2.resize(face_bgr, (size, size))
face = face.astype(np.float32) / 255.0
# Per-frame standardization
mean = face.mean(axis=(0, 1), keepdims=True)
std = face.std(axis=(0, 1), keepdims=True) + 1e-6
face = (face - mean) / std
# BGR -> RGB and HWC -> CHW
chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
return chw
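# normalize_frame output feeds the rolling clip buffer: stacking DEFAULT_T of these
# (3, H, W) arrays along axis=1 yields the (3, T, H, W) clip used for inference.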
def extract_attention_map(model, clip_tensor: torch.Tensor,
attention_viz) -> Optional[np.ndarray]:
"""Attention visualization disabled - model architecture incompatible."""
return None
def create_attention_overlay(frame: np.ndarray, attention: Optional[np.ndarray],
attention_viz: Optional[SimpleActivationAttention] = None) -> np.ndarray:
"""Create attention heatmap overlay."""
# Attention disabled - return original frame
return frame
def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
H, W = roi_bgr.shape[:2]
base_bvp = forward_bvp(model, torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
.unsqueeze(0).unsqueeze(2).to(DEVICE)) # fake T=1 path if needed
base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
heat = np.zeros((H, W), np.float32)
for y in range(0, H - patch + 1, stride):
for x in range(0, W - patch + 1, stride):
tmp = roi_bgr.copy()
tmp[y:y+patch, x:x+patch] = 127 # occlude
bvp = forward_bvp(model, torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
.unsqueeze(0).unsqueeze(2).to(DEVICE))
power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
drop = max(0.0, base_power - power)
heat[y:y+patch, x:x+patch] += drop
heat -= heat.min()
if heat.max() > 1e-8: heat /= heat.max()
return heat
def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
"""
Try common 5D layouts:
[B, C, T, H, W] then [B, T, C, H, W].
"""
last_err = None
try:
return model(clip_tensor)
except Exception as e:
last_err = e
try:
return model(clip_tensor.permute(0, 2, 1, 3, 4).contiguous())
except Exception as e:
last_err = e
raise last_err
def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
"""
Forward and extract a 1D time-like BVP vector with length T_clip.
Tolerant to dict/list/tuple heads and odd shapes.
"""
T_clip = int(clip_tensor.shape[2]) # intended time length for [B,C,T,H,W]
with torch.no_grad():
out = _call_model_try_orders(model, clip_tensor)
# unwrap common containers
if isinstance(out, dict):
for key in ("bvp", "ppg", "signal", "pred", "y", "out", "logits"):
if key in out and isinstance(out[key], torch.Tensor):
out = out[key]
break
if isinstance(out, (list, tuple)):
tensors = [t for t in out if isinstance(t, torch.Tensor)]
if not tensors:
return np.zeros(T_clip, dtype=np.float32)
def score(t: torch.Tensor):
has_T = 1 if T_clip in t.shape else 0
return (has_T, t.numel())
out = max(tensors, key=score)
if not isinstance(out, torch.Tensor):
return np.zeros(T_clip, dtype=np.float32)
out = out.detach().cpu().float()
# ---- 1D
if out.ndim == 1:
v = out
if v.shape[0] == T_clip:
return v.numpy()
if v.numel() == 1:
return np.full(T_clip, float(v.item()), dtype=np.float32)
return np.resize(v.numpy(), T_clip).astype(np.float32)
# ---- 2D
if out.ndim == 2:
B, K = out.shape
if B == 1:
v = out[0]
return (v.numpy() if v.shape[0] == T_clip else np.resize(v.numpy(), T_clip).astype(np.float32))
if B == T_clip:
return out[:, 0].numpy()
if K == T_clip:
return out[0, :].numpy()
return np.resize(out.flatten().numpy(), T_clip).astype(np.float32)
# ---- 3D
if out.ndim == 3:
B, D1, D2 = out.shape[0], out.shape[1], out.shape[2]
if D1 == T_clip: # [B, T, C]
return out[0, :, 0].numpy()
if D2 == T_clip: # [B, C, T]
return out[0, 0, :].numpy()
v = out.mean(dim=tuple(range(1, out.ndim))).squeeze(0)
return np.resize(v.numpy(), T_clip).astype(np.float32)
# ---- 4D
if out.ndim == 4:
B, A, H, W = out.shape
if A == T_clip: # [B, T, H, W]
return out[0].mean(dim=(1, 2)).numpy()
v = out[0].mean(dim=(1, 2))
return np.resize(v.numpy(), T_clip).astype(np.float32)
# ---- 5D+
if out.ndim >= 5:
shape = list(out.shape)
try:
t_idx = next(i for i, s in enumerate(shape) if (i != 0 and s == T_clip))
except StopIteration:
            pooled = out[0].mean(dim=tuple(range(1, out[0].ndim)))  # average away every dim of out[0] except the first
v = pooled.flatten()
return np.resize(v.numpy(), T_clip).astype(np.float32)
axes = list(range(out.ndim))
perm = [0, t_idx] + [i for i in axes[1:] if i != t_idx]
o2 = out.permute(*perm) # [B, T, ...]
pooled = o2.mean(dim=tuple(range(2, o2.ndim))) # -> [B, T]
return pooled[0].numpy()
# fallback: constant vector with mean value
val = float(out.mean().item()) if out.numel() else 0.0
return np.full(T_clip, val, dtype=np.float32)
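# Example (sketch): clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)  # [1, 3, T, H, W]
#   bvp = forward_bvp(model, clip_t)   # -> np.ndarray of length T, whatever shape the head returns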
def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
"""
Classical rPPG from green-channel means when the model yields nothing.
Detrend -> bandpass -> z-normalize.
"""
if means is None:
return np.array([], dtype=np.float32)
x = np.asarray(means, dtype=np.float32)
if x.size == 0:
return np.array([], dtype=np.float32)
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
try:
x = signal.detrend(x, type="linear")
except Exception:
pass
y = bandpass_filter(x, fs=fs, low=0.7, high=3.5, order=4)
std = float(np.std(y)) + 1e-6
return (y / std).astype(np.float32)
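# This fallback is a classical "green-channel" rPPG estimate: detrended, band-passed
# (0.7-3.5 Hz), z-normalized ROI green means, used when the model output is empty.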
def _to_floats(s: str) -> List[float]:
"""
Extract all real numbers from free-form text, including scientific notation.
Gracefully ignores 'nan', 'inf', units, and comments.
"""
if not isinstance(s, str) or not s:
return []
s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
s = s.replace(",", " ").replace(";", " ")
toks = re.findall(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", s)
out: List[float] = []
for t in toks:
try:
v = float(t)
if np.isfinite(v):
out.append(v)
except Exception:
continue
return out
def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
"""
Parse ground-truth files from:
• UBFC/UBFC-rPPG style TXT: 3 lines => PPG, HR, timestep(s)
• Generic TXT: one column (or free-form) numeric sequence
• JSON: keys like 'ppg' / 'bvp' (optionally nested), 'hr', 'fs'
• CSV: columns named BVP/PPG/Signal + optional HR + optional Time
• MAT: array under keys ['ppg','bvp','signal','wave']; optionally 'fs'/'hr'
• NPY: np.load() 1-D array; optionally sidecar .json with fs/hr (same stem)
Returns:
bvp : np.ndarray (float) — may be empty
hr : float (mean HR if available or estimated)
fs_hint : float — sampling rate if derivable (0.0 if unknown)
"""
if not gt_path or not os.path.exists(gt_path):
return np.array([]), 0.0, 0.0
p = Path(gt_path)
ext = p.suffix.lower()
def _fs_from_time_vector(tv: np.ndarray) -> float:
tv = np.asarray(tv, dtype=float)
if tv.ndim != 1 or tv.size < 2:
return 0.0
diffs = np.diff(tv)
diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
if bvp is None or bvp.size == 0:
return 0.0
fs_use = fs_hint if (fs_hint and fs_hint > 0) else 30.0
bp = bandpass_filter(bvp.astype(float), fs=fs_use)
return hr_from_welch(bp, fs=fs_use)
if p.name.lower() == "ground_truth.txt" or (ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2):
try:
lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines() if ln.strip()]
ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
bvp = np.asarray(ppg_vals, dtype=float) if ppg_vals else np.array([], dtype=float)
hr = float(np.nanmean(hr_vals)) if hr_vals else 0.0
fs_hint = 0.0
if t_vals:
if len(t_vals) == 1:
dt = float(t_vals[0])
if dt > 0:
fs_hint = 1.0 / dt
else:
fs_hint = _fs_from_time_vector(np.asarray(t_vals, dtype=float))
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
# Fall through to generic handlers
pass
if ext == ".txt":
try:
nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
bvp = np.asarray(nums, dtype=float) if nums else np.array([], dtype=float)
hr = _hr_from_bvp(bvp, fs_hint=0.0) if bvp.size else 0.0
return bvp, hr, 0.0
except Exception:
return np.array([]), 0.0, 0.0
if ext == ".json":
try:
data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
# Try several paths for BVP array
bvp = None
def _seek(obj, keys):
for k in keys:
if isinstance(obj, dict) and k in obj:
return obj[k]
return None
# Direct top-level
bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
# Common nested containers
if bvp is None:
for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
if container_key in data:
cand = _seek(data[container_key], ("ppg", "bvp", "signal", "wave"))
if cand is not None:
bvp = cand
break
if bvp is not None:
bvp = np.asarray(bvp, dtype=float).ravel()
else:
bvp = np.array([], dtype=float)
# fs / hr (accept scalar or array)
fs_hint = 0.0
if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
fs_hint = float(data["fs"])
hr = 0.0
if "hr" in data:
v = data["hr"]
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
if ext == ".csv":
try:
df = pd.read_csv(p)
# Normalize column names
cols = {str(c).strip().lower(): c for c in df.columns}
def _first_match(names):
for nm in names:
if nm in cols:
return cols[nm]
return None
bvp_col = _first_match(["bvp", "ppg", "wave", "signal", "bvp_signal", "ppg_signal"])
hr_col = _first_match(["hr", "heart_rate", "hr_bpm", "bpm"])
t_col = _first_match(["time", "t", "timestamp", "sec", "seconds", "time_s"])
bvp = np.asarray(df[bvp_col].values, dtype=float) if bvp_col else np.array([], dtype=float)
fs_hint = 0.0
if t_col is not None and len(df[t_col].values) >= 2:
fs_hint = _fs_from_time_vector(np.asarray(df[t_col].values, dtype=float))
hr = float(np.nanmean(df[hr_col].values)) if (hr_col and df[hr_col].notna().any()) else 0.0
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
if ext == ".mat":
try:
md = loadmat(str(p))
# look for most likely array
arr = None
for key in ("ppg", "bvp", "signal", "wave"):
if key in md and isinstance(md[key], np.ndarray):
arr = md[key]
break
if arr is None:
# fallback: first 1-D array
for v in md.values():
if isinstance(v, np.ndarray) and v.ndim == 1:
arr = v
break
bvp = np.asarray(arr, dtype=float).ravel() if arr is not None else np.array([], dtype=float)
fs_hint = 0.0
for k in ("fs", "Fs", "sampling_rate", "sr"):
if k in md:
try:
fs_hint = float(np.ravel(md[k])[0])
break
except Exception:
pass
hr = 0.0
if "hr" in md:
try:
hr = float(np.nanmean(np.ravel(md["hr"])))
except Exception:
hr = 0.0
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
# ================= NPY =================
if ext == ".npy":
try:
bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
fs_hint, hr = 0.0, 0.0
# optional sidecar JSON (same stem) with fs/hr
sidecar = p.with_suffix(".json")
if sidecar.exists():
try:
meta = json.loads(sidecar.read_text(encoding="utf-8", errors="ignore"))
if isinstance(meta.get("fs", None), (int, float)) and meta["fs"] > 0:
fs_hint = float(meta["fs"])
if "hr" in meta:
v = meta["hr"]
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
except Exception:
pass
if hr == 0.0 and bvp.size:
hr = _hr_from_bvp(bvp, fs_hint)
return bvp, hr, fs_hint
except Exception:
return np.array([]), 0.0, 0.0
# Fallback (unsupported extension)
return np.array([]), 0.0, 0.0
def scan_models() -> List[str]:
if not MODEL_DIR.exists():
return []
models = []
for f in sorted(MODEL_DIR.iterdir()):
if f.suffix.lower() == '.pth':
models.append(f.name)
return models
_GLOBAL_CONTROLS: Dict[str, Dict] = {}
def ensure_controls(control_id: str) -> Tuple[str, Dict]:
# Use a stable default so Pause/Resume/Stop work for the current run
if not control_id:
control_id = "default-session"
if control_id not in _GLOBAL_CONTROLS:
_GLOBAL_CONTROLS[control_id] = {
'pause': threading.Event(),
'stop': threading.Event()
}
return control_id, _GLOBAL_CONTROLS[control_id]
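# Each run shares a {'pause': Event, 'stop': Event} pair keyed by control_id; the
# Pause/Resume/Stop callbacks elsewhere in this file are expected to toggle these
# events to control the streaming generators below.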
def process_video_file(
video_path: str,
gt_file: Optional[str],
model_name: str,
fps_input: int,
max_seconds: int,
roi_type: str,
control_id: str
):
"""
    Generator: process a video file clip-by-clip and yield status, HR, plots, and output files.
    (Attention visualization is currently disabled, so the attention panel shows the plain frame.)
"""
global _HR_SMOOTH
_HR_SMOOTH = None
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
if gt_len <= 1:
return None
if gt_fs and gt_fs > 0:
return np.arange(gt_len, dtype=float) / float(gt_fs)
return None
control_id, controls = ensure_controls(control_id)
controls['stop'].clear()
if not model_name:
yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
return
if isinstance(model_name, int):
model_name = str(model_name)
model_path = MODEL_DIR / model_name
if not model_path.exists():
yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
return
try:
model, attention_viz = load_physmamba_model(model_path, DEVICE)
except Exception as e:
yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
return
gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
if not video_path or not os.path.exists(video_path):
yield ("ERROR: Video not found", None, None, None, None, None, None, None, None, None)
return
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
yield ("ERROR: Cannot open video", None, None, None, None, None, None, None, None, None)
return
fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
max_frames = int(max_seconds * fps) if max_seconds and max_seconds > 0 else total_frames
MAX_SIGNAL_LENGTH = max_frames if max_frames > 0 else fps * 600
frames_chw = deque(maxlen=DEFAULT_T)
raw_g_means = deque(maxlen=DEFAULT_T)
bvp_stream: List[float] = []
frame_idx = 0
last_infer = -1
last_bpm = 0.0
last_rmssd = 0.0
last_attention = None
start_time = time.time()
next_display = time.time()
tmpdir = Path(tempfile.gettempdir())
frame_path = tmpdir / "frame.jpg"
attention_path = tmpdir / "attention.jpg"
signal_path = tmpdir / "signal.png"
raw_path = tmpdir / "raw_signal.png"
post_path = tmpdir / "post_signal.png"
yield ("Starting… reading video frames", None, f"{gt_hr:.1f}" if gt_hr > 0 else "--",
None, None, None, None, None, None, None)
while True:
if controls['stop'].is_set():
break
while controls['pause'].is_set():
time.sleep(0.2)
if controls['stop'].is_set():
break
ret, frame = cap.read()
if not ret or (max_frames > 0 and frame_idx >= max_frames):
break
frame_idx += 1
face = detect_face(frame)
vis_frame = frame.copy()
if face is not None:
x, y, w, h = face
cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
if roi_type == "auto":
roi, _ = pick_auto_roi(face, frame, attn=last_attention)
else:
roi = crop_roi(face, roi_type, frame)
if roi is not None and roi.size > 0:
try:
g_mean = float(roi[..., 1].astype(np.float32).mean())
raw_g_means.append(g_mean)
except Exception:
pass
face_norm = normalize_frame(roi, DEFAULT_SIZE)
frames_chw.append(face_norm)
if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
try:
clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
except Exception as e:
print(f"[infer] clip stack failed: {e}")
clip = None
bvp_out = None
if clip is not None:
clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
try:
raw = forward_bvp(model, clip_t)
if isinstance(raw, np.ndarray):
raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
bvp_out = raw if raw.size > 0 else None
else:
bvp_out = None
except Exception as e:
print(f"[infer] forward_bvp error: {e}")
bvp_out = None
                        # Attention map (extract_attention_map currently returns None; overlay falls back to the raw frame)
try:
last_attention = extract_attention_map(model, clip_t, attention_viz)
except Exception as e:
print(f"⚠ Attention generation error: {e}")
last_attention = None
if bvp_out is None or bvp_out.size == 0:
gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
fb = _fallback_bvp_from_means(gbuf, fs=fps)
if isinstance(fb, np.ndarray) and fb.size > 0:
bvp_out = fb
else:
print("[infer] fallback produced empty output")
if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
tail = min(DEFAULT_STRIDE, bvp_out.size)
bvp_stream.extend(bvp_out[-tail:].tolist())
if len(bvp_stream) > MAX_SIGNAL_LENGTH:
bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
if len(bvp_stream) >= int(5 * fps):
seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
_, last_bpm = postprocess_bvp(seg, fs=fps)
last_rmssd = compute_rmssd(seg, fs=fps)
if frame_idx % (DEFAULT_STRIDE * 2) == 0:
print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
else:
print("[infer] no usable bvp_out after fallback")
last_infer = frame_idx
else:
cv2.putText(vis_frame, "No face detected", (20, 40),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
if last_bpm > 0:
color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
cv2.putText(vis_frame, f"HR: {last_bpm:.1f} BPM", (20, 48),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
cv2.imwrite(str(frame_path), vis_frame)
cv2.imwrite(str(attention_path), vis_attention)
now = time.time()
if now >= next_display:
if len(bvp_stream) >= max(10, int(1 * fps)):
try:
signal_array = np.array(bvp_stream, dtype=float)
time_axis = np.arange(len(signal_array)) / fps
raw_signal = bandpass_filter(signal_array, fs=fps)
post_signal, _ = postprocess_bvp(signal_array, fs=fps)
raw_vis = raw_signal - np.mean(raw_signal)
post_vis = post_signal - np.mean(post_signal)
plt.figure(figsize=(10, 4), dpi=100)
plt.plot(time_axis, raw_vis, linewidth=1.6)
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
plt.grid(True, alpha=0.3)
plt.tight_layout(); plt.savefig(str(raw_path), dpi=100, bbox_inches='tight'); plt.close()
plt.figure(figsize=(10, 4), dpi=100)
plt.plot(time_axis, post_vis, linewidth=1.6)
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
plt.grid(True, alpha=0.3)
plt.tight_layout(); plt.savefig(str(post_path), dpi=100, bbox_inches='tight'); plt.close()
if gt_bvp is not None and gt_bvp.size > 0:
try:
gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
plot_signals_with_gt(
time_axis=time_axis,
raw_signal=raw_signal,
post_signal=post_signal,
fs=fps,
out_path=str(signal_path),
gt_time=gt_time,
gt_bvp=gt_bvp,
title=f"Pred vs GT — HR: {last_bpm:.1f} BPM"
)
except Exception:
fig = plt.figure(figsize=(12, 5), dpi=100)
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
else:
fig = plt.figure(figsize=(12, 5), dpi=100)
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
except Exception:
pass
elapsed = now - start_time
status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
yield (
status,
f"{last_bpm:.1f}" if last_bpm > 0 else None,
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
str(frame_path),
str(attention_path),
str(signal_path) if signal_path.exists() else None,
str(raw_path) if raw_path.exists() else None,
str(post_path) if post_path.exists() else None,
None
)
next_display = now + (1.0 / DISPLAY_FPS)
cap.release()
    # Clean up attention hooks if a visualizer was created (currently always None)
if attention_viz: attention_viz.cleanup()
csv_path = None
if bvp_stream:
csv_path = Path(tempfile.gettempdir()) / "bvp_output.csv"
time_array = np.arange(len(bvp_stream)) / fps
signal_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
try:
pd.DataFrame({'time_s': time_array, 'bvp': signal_final}).to_csv(csv_path, index=False)
except Exception:
csv_path = None
try:
final_overlay = Path(tempfile.gettempdir()) / "signal_final_overlay.png"
if gt_bvp is not None and gt_bvp.size > 0:
gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
plot_signals_with_gt(
time_axis=time_array,
raw_signal=bandpass_filter(np.array(bvp_stream, dtype=float), fs=fps),
post_signal=signal_final,
fs=fps,
out_path=str(final_overlay),
gt_time=gt_time,
gt_bvp=gt_bvp,
title=f"Final Pred vs GT — HR: {last_bpm:.1f} BPM"
)
if final_overlay.exists():
signal_path = final_overlay
except Exception:
pass
elapsed = time.time() - start_time
final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
yield (
final_status,
f"{last_bpm:.1f}" if last_bpm > 0 else None,
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
str(frame_path),
str(attention_path),
str(signal_path) if signal_path.exists() else None,
str(raw_path) if raw_path.exists() else None,
str(post_path) if post_path.exists() else None,
str(csv_path) if csv_path else None
)
def process_live_webcam(
model_name: str,
fps_input: int,
roi_type: str,
control_id: str
):
"""Stream live webcam with Grad-CAM attention visualization."""
global _HR_SMOOTH
_HR_SMOOTH = None
def _perf_heartbeat(frame_idx, t0, bvp_len, frames_chw_len, fps):
if frame_idx == 1:
print(f"[run] device={DEVICE} target_fps={fps}")
if frame_idx % 60 == 0:
elapsed = time.time() - t0
cur_fps = frame_idx / max(elapsed, 1e-6)
print(f"[perf] frames={frame_idx} ({cur_fps:.1f} FPS) "
f"clip={frames_chw_len}/{DEFAULT_T} bvp_len={bvp_len}")
control_id, controls = ensure_controls(control_id)
controls['stop'].clear()
if not model_name:
yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
return
model_path = MODEL_DIR / model_name
if not model_path.exists():
yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
return
try:
model, attention_viz = load_physmamba_model(model_path, DEVICE)
except Exception as e:
yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
return
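# Probe camera indices 0 and 1, trying the AVFoundation backend first and then any backend, and keep the first device that actually returns a frame.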
cap = None
for camera_id in [0, 1]:
for backend in [cv2.CAP_AVFOUNDATION, cv2.CAP_ANY]:
try:
test_cap = cv2.VideoCapture(camera_id, backend)
if test_cap.isOpened():
test_cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
test_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
ret, frame = test_cap.read()
if ret and frame is not None:
cap = test_cap
break
test_cap.release()
except Exception:
pass
if cap is not None:
break
if cap is None:
yield ("ERROR: Cannot access webcam", None, None, None, None, None, None, None, None, None)
return
fps = int(fps_input) if fps_input else 30
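# Rolling buffers: frames_chw holds the last DEFAULT_T normalized ROI frames for model input; raw_g_means holds green-channel means used by the fallback estimator.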
frames_chw = deque(maxlen=DEFAULT_T)
raw_g_means = deque(maxlen=DEFAULT_T)
bvp_stream: List[float] = []
frame_idx = 0
last_infer = -1
last_bpm = 0.0
last_rmssd = 0.0
last_attention = None
t0 = time.time()
next_display = time.time()
DISPLAY_INTERVAL = 0.5
frame_path = Path(tempfile.gettempdir()) / "live_frame.jpg"
attention_path = Path(tempfile.gettempdir()) / "live_attention.jpg"
signal_path = Path(tempfile.gettempdir()) / "live_signal.png"
raw_path = Path(tempfile.gettempdir()) / "live_raw.png"
post_path = Path(tempfile.gettempdir()) / "live_post.png"
MAX_SIGNAL_LENGTH = fps * 60
yield ("Starting… waiting for frames", None, "--", None,
None, None, None, None, None, None)
while True:
if controls['stop'].is_set():
break
while controls['pause'].is_set():
time.sleep(0.2)
if controls['stop'].is_set():
break
ret, frame = cap.read()
if not ret:
time.sleep(0.05)
continue
frame_idx += 1
_perf_heartbeat(frame_idx, t0, len(bvp_stream), len(frames_chw), fps)
face = detect_face(frame)
vis_frame = frame.copy()
if face is not None:
x, y, w, h = face
cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 3)
cv2.putText(vis_frame, "FACE", (x, max(20, y - 10)),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
if roi_type == "auto":
roi, _ = pick_auto_roi(face, frame, attn=last_attention)
else:
roi = crop_roi(face, roi_type, frame)
if roi is not None and roi.size > 0:
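# Track the ROI's green-channel mean (index 1 in BGR) so a signal can still be derived if model inference fails.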
try:
g_mean = float(roi[..., 1].astype(np.float32).mean())
raw_g_means.append(g_mean)
except Exception:
pass
chw = normalize_frame(roi, DEFAULT_SIZE)
frames_chw.append(chw)
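# Run inference once the clip buffer is full and at least DEFAULT_STRIDE new frames have arrived since the last inference.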
if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
try:
clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
except Exception as e:
print(f"[infer] clip stack failed: {e}")
clip = None
bvp_out = None
if clip is not None:
clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
try:
raw = forward_bvp(model, clip_t)
if isinstance(raw, np.ndarray):
raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
bvp_out = raw if raw.size > 0 else None
else:
bvp_out = None
except Exception as e:
print(f"[infer] forward_bvp error: {e}")
bvp_out = None
try:
last_attention = extract_attention_map(model, clip_t, attention_viz)
except Exception as e:
print(f" Attention generation error: {e}")
last_attention = None
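# If the model produced no usable output, fall back to the estimate derived from the buffered green-channel means.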
if bvp_out is None or bvp_out.size == 0:
gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
fb = _fallback_bvp_from_means(gbuf, fs=fps)
if isinstance(fb, np.ndarray) and fb.size > 0:
bvp_out = fb
else:
print("[infer] fallback produced empty output")
if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
tail = min(DEFAULT_STRIDE, bvp_out.size)
bvp_stream.extend(bvp_out[-tail:].tolist())
if len(bvp_stream) > MAX_SIGNAL_LENGTH:
bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
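# Once at least 5 s of signal is buffered, estimate HR and RMSSD from the most recent ~10 s window.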
if len(bvp_stream) >= int(5 * fps):
seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
_, last_bpm = postprocess_bvp(seg, fs=fps)
last_rmssd = compute_rmssd(seg, fs=fps)
if frame_idx % (DEFAULT_STRIDE * 2) == 0:
print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
else:
print("[infer] no usable bvp_out after fallback")
last_infer = frame_idx
else:
cv2.putText(vis_frame, "No face detected", (20, 40),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
cv2.putText(vis_frame, f"Fill: {len(frames_chw)}/{DEFAULT_T}", (20, 25),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
cv2.putText(vis_frame, f"BVP: {len(bvp_stream)}", (20, 45),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
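# Heads-up display: semi-transparent box with the current HR, color-coded green (55-100 BPM), red (>100 BPM), or amber (<55 BPM).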
if last_bpm > 0:
color = (50, 255, 50) if 55 <= last_bpm <= 100 else (50, 50, 255) if last_bpm > 100 else (255, 200, 50)
overlay = vis_frame.copy()
cv2.rectangle(overlay, (10, 10), (450, 100), (0, 0, 0), -1)
vis_frame = cv2.addWeighted(vis_frame, 0.6, overlay, 0.4, 0)
cv2.putText(vis_frame, f"HR: {last_bpm:.0f} BPM", (20, 80),
cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 3)
cv2.circle(vis_frame, (30, 30), 10, (0, 255, 0), -1)
cv2.putText(vis_frame, "LIVE", (50, 38),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
else:
cv2.putText(vis_frame, "Collecting…", (20, 80),
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
cv2.imwrite(str(frame_path), vis_frame)
cv2.imwrite(str(attention_path), vis_attention)
now = time.time()
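# Throttle UI refreshes to DISPLAY_INTERVAL seconds; redraw the signal plots only once at least ~1 s of samples is available.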
if now >= next_display:
if len(bvp_stream) >= max(10, int(1 * fps)):
try:
sig = np.array(bvp_stream, dtype=float)
t_axis = np.arange(len(sig)) / fps
raw_sig = bandpass_filter(sig, fs=fps)
post_sig, _ = postprocess_bvp(sig, fs=fps)
rv = raw_sig - np.mean(raw_sig)
pv = post_sig - np.mean(post_sig)
plt.figure(figsize=(10, 4), dpi=100)
plt.plot(t_axis, rv, linewidth=2)
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
plt.grid(True, alpha=0.3)
plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
plt.tight_layout(); plt.savefig(str(raw_path)); plt.close()
plt.figure(figsize=(10, 4), dpi=100)
plt.plot(t_axis, pv, linewidth=2)
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
plt.grid(True, alpha=0.3)
plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
plt.tight_layout(); plt.savefig(str(post_path)); plt.close()
fig = plt.figure(figsize=(12, 5), dpi=100)
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(t_axis, rv, linewidth=2)
ax1.set_title('Raw Signal'); ax1.grid(True, alpha=0.3)
ax1.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(t_axis, pv, linewidth=2)
ax2.set_title('Post-processed'); ax2.grid(True, alpha=0.3)
ax2.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
plt.suptitle(f'LIVE rPPG - HR: {last_bpm:.1f} BPM')
plt.savefig(str(signal_path)); plt.close('all')
except Exception:
pass
elapsed = now - t0
status = (f"Frame {frame_idx} | Fill {len(frames_chw)}/{DEFAULT_T} | "
f"BVP {len(bvp_stream)} | HR {last_bpm:.1f} BPM | "
f"Time {int(elapsed)}s")
yield (
status,
f"{last_bpm:.1f}" if last_bpm > 0 else None,
"--",
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
str(frame_path),
str(attention_path),
str(signal_path) if signal_path.exists() else None,
str(raw_path) if raw_path.exists() else None,
str(post_path) if post_path.exists() else None,
None
)
next_display = now + DISPLAY_INTERVAL
time.sleep(0.01)
cap.release()
# Cleanup Grad-CAM
if attention_viz: attention_viz.cleanup()
csv_path = None
if bvp_stream:
csv_path = Path(tempfile.gettempdir()) / "live_bvp.csv"
t = np.arange(len(bvp_stream)) / fps
sig_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
# A failed CSV write should not prevent the final status yield below.
try:
    pd.DataFrame({"time_s": t, "bvp": sig_final}).to_csv(csv_path, index=False)
except Exception:
    csv_path = None
elapsed = time.time() - t0
final_status = f"Session ended | Frames {frame_idx} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
yield (
final_status,
f"{last_bpm:.1f}" if last_bpm > 0 else None,
"--",
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
str(frame_path),
str(attention_path),
str(signal_path) if signal_path.exists() else None,
str(raw_path) if raw_path.exists() else None,
str(post_path) if post_path.exists() else None,
str(csv_path) if csv_path else None
)
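# Dispatch to the live-webcam or file-based pipeline depending on the selected input source.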
def process_stream(
input_source: str,
video_path: Optional[str],
gt_file: Optional[str],
model_name: str,
fps_input: int,
max_seconds: int,
roi_type: str,
control_id: str
):
if input_source == "Live Webcam":
yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
else:
yield from process_video_file(video_path, gt_file, model_name, fps_input,
max_seconds, roi_type, control_id)
def pause_processing(control_id: str) -> str:
_, controls = ensure_controls(control_id)
controls['pause'].set()
return "Paused"
def resume_processing(control_id: str) -> str:
_, controls = ensure_controls(control_id)
controls['pause'].clear()
return "Resumed"
def stop_processing(control_id: str) -> str:
_, controls = ensure_controls(control_id)
controls['stop'].set()
controls['pause'].clear()
return "Stopped"
def reset_ui():
return ("Ready", None, None, None, None, None, None, None, None, None)
def handle_folder_upload(files):
if not files:
return None, None, "No files uploaded"
if not isinstance(files, list):
files = [files]
video_path = None
for file_obj in files:
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
if file_path.suffix.lower() in VIDEO_EXTENSIONS:
video_path = str(file_path)
break
gt_path = None
gt_patterns = ['gtdump.txt', 'ground_truth.txt', 'gt.txt']
for file_obj in files:
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
if file_path.name.lower() in [p.lower() for p in gt_patterns]:
gt_path = str(file_path)
break
if not gt_path:
for file_obj in files:
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
if file_path.suffix.lower() in ['.txt', '.json', '.csv']:
if 'readme' not in file_path.name.lower():
gt_path = str(file_path)
break
status = []
if video_path:
status.append(f"Video: {Path(video_path).name}")
else:
status.append("No video found")
if gt_path:
status.append(f"GT: {Path(gt_path).name}")
else:
status.append("No ground truth")
return video_path, gt_path, " | ".join(status)
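# Gradio UI: input source selection, model/ROI controls, run/pause/resume/stop, live outputs, and CSV download.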
with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as demo:
gr.Markdown("# rPPG Analysis Tool with Attention Visualization")
with gr.Row():
input_source = gr.Radio(
choices=["Video File", "Subject Folder", "Live Webcam"],
value="Live Webcam",
label="Input Source"
)
with gr.Row():
target_layer_input = gr.Textbox(
label="Grad-CAM Target Layer (optional - leave empty for auto)",
placeholder="e.g., backbone.conv3, encoder.layer2",
value=""
)
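# NOTE: target_layer_input is not currently passed to run_processing; the Grad-CAM target layer is auto-selected.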
with gr.Row(visible=False) as video_inputs:
with gr.Column():
video_upload = gr.Video(label="Upload Video", sources=["upload"])
with gr.Column():
gt_upload = gr.File(label="Ground Truth (optional)",
file_types=[".txt", ".csv", ".json"])
with gr.Row(visible=False) as folder_inputs:
with gr.Column():
folder_upload = gr.File(
label="Upload Subject Folder",
file_count="directory",
file_types=None,
type="filepath"
)
folder_status = gr.Textbox(label="Folder Status", interactive=False)
with gr.Row(visible=True) as webcam_inputs:
gr.Markdown("### Live Webcam - Click Run to start")
def toggle_input_source(source):
return [
gr.update(visible=(source == "Video File")),
gr.update(visible=(source == "Subject Folder")),
gr.update(visible=(source == "Live Webcam"))
]
input_source.change(
toggle_input_source,
inputs=[input_source],
outputs=[video_inputs, folder_inputs, webcam_inputs]
)
folder_video = gr.State(None)
folder_gt = gr.State(None)
folder_upload.upload(
handle_folder_upload,
inputs=[folder_upload],
outputs=[folder_video, folder_gt, folder_status]
)
with gr.Row():
with gr.Column(scale=3):
model_choices = scan_models()
model_dropdown = gr.Dropdown(
choices=model_choices,
value=model_choices[0] if model_choices else None,
label="PhysMamba Model",
interactive=True
)
with gr.Column(scale=1):
refresh_models_btn = gr.Button("Refresh", variant="secondary")
refresh_models_btn.click(
lambda: gr.update(choices=scan_models()),
inputs=None,
outputs=[model_dropdown]
)
with gr.Row():
fps_slider = gr.Slider(
minimum=10, maximum=120, value=30, step=5,
label="FPS"
)
max_seconds_slider = gr.Slider(
minimum=10, maximum=600, value=180, step=10,
label="Max Duration (s)"
)
with gr.Row():
roi_dropdown = gr.Dropdown(choices=["auto","forehead","cheeks","face"], value="auto", label="ROI")
control_state = gr.State(value="")
placeholder_state = gr.State(value=None)
with gr.Row():
run_btn = gr.Button("Run", variant="primary")
pause_btn = gr.Button("Pause", variant="secondary")
resume_btn = gr.Button("Resume", variant="secondary")
stop_btn = gr.Button("Stop", variant="stop")
status_text = gr.Textbox(label="Status", lines=2, value="Ready")
with gr.Row():
hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
with gr.Row():
with gr.Column():
frame_output = gr.Image(label="Video Feed", type="filepath")
with gr.Column():
attention_output = gr.Image(label="Attention Map", type="filepath")
with gr.Row():
signal_output = gr.Image(label="Signal Comparison", type="filepath")
with gr.Row():
with gr.Column():
raw_signal_output = gr.Image(label="Raw Signal", type="filepath")
with gr.Column():
post_signal_output = gr.Image(label="Post-processed Signal", type="filepath")
with gr.Row():
csv_output = gr.File(label="Download CSV")
pause_btn.click(
pause_processing,
inputs=[control_state],
outputs=[status_text]
)
resume_btn.click(
resume_processing,
inputs=[control_state],
outputs=[status_text]
)
stop_btn.click(
stop_processing,
inputs=[control_state],
outputs=[status_text]
).then(
reset_ui,
inputs=None,
outputs=[status_text, hr_output, gt_hr_output, rmssd_output,
frame_output, attention_output, signal_output,
raw_signal_output, post_signal_output, csv_output]
)
def run_processing(input_source, video_upload, gt_upload, folder_video, folder_gt,
model_name, fps, max_sec, roi, ctrl_id):
"""Fixed version that handles model_name type conversion."""
if isinstance(model_name, int):
model_name = str(model_name)
if not model_name:
yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
return
if input_source == "Video File":
video_path = _as_path(video_upload)
gt_file = _as_path(gt_upload)
elif input_source == "Subject Folder":
video_path = _as_path(folder_video)
gt_file = _as_path(folder_gt)
else: # Live Webcam
video_path, gt_file = None, None
yield from process_stream(
input_source, video_path, gt_file,
model_name, fps, max_sec, roi, ctrl_id
)
run_btn.click(
fn=run_processing,
inputs=[
input_source,
video_upload,
gt_upload,
folder_video,
folder_gt,
model_dropdown,
fps_slider,
max_seconds_slider,
roi_dropdown,
control_state
],
outputs=[
status_text,
hr_output,
gt_hr_output,
rmssd_output,
frame_output,
attention_output,
signal_output,
raw_signal_output,
post_signal_output,
csv_output
]
)
if __name__ == "__main__":
demo.queue(max_size=10).launch(
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860)),
share=False,
)