Update app.py
app.py
CHANGED
@@ -24,7 +24,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-
 from scipy import signal
 from scipy.signal import find_peaks, welch, get_window

@@ -43,7 +42,8 @@ from matplotlib.gridspec import GridSpec

 import gradio as gr

-

 class PhysMambaattention_viz:
     """Simplified Grad-CAM for PhysMamba."""
@@ -228,55 +228,86 @@ class PhysMambaattention_viz:


 def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
-    """
     if len(frames) < 2:
-
-
     diff_frames = []
-
     for i in range(len(frames)):
         if i == 0:
             diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
         else:
             curr = frames[i].astype(np.float32)
-            prev = frames[i-1].astype(np.float32)
-
-            diff = (curr - prev) /
             diff_frames.append(diff)
-
-    diff_array = np.stack(diff_frames)
-    std = diff_array.std()
-
-    diff_array = diff_array / std
-
     return diff_array


-def preprocess_for_physmamba(
-
-
-
     if len(frames) < target_frames:
         frames = frames + [frames[-1]] * (target_frames - len(frames))
     elif len(frames) > target_frames:
-
-        frames = [frames[i] for i in
-
-
     frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
-    frames_diff = apply_diff_normalized(frames_resized)
-    frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2))
-    frames_batched = np.expand_dims(frames_transposed, axis=0)
-    tensor = torch.from_numpy(frames_batched.astype(np.float32))
-
-    return tensor

-
 MODEL_DIR = HERE / "final_model_release"
 LOG_DIR = HERE / "logs"
 ANALYSIS_DIR = HERE / "analysis"
-for d in
-    d.mkdir(exist_ok=True)

 DEVICE = (
     torch.device("cuda") if torch.cuda.is_available()
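The replacement (shown in full further down) computes the DiffNormalized representation described in its new docstring: each frame pair is mapped to diff_t = (I_t - I_{t-1}) / (I_t + I_{t-1} + eps) and the stacked result is divided by its global standard deviation. A minimal, hedged sketch of that recipe on dummy data (standalone, not the app's exact function):

import numpy as np

def diff_normalize(frames: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # frames: (T, H, W, C) float array
    curr, prev = frames[1:], frames[:-1]
    diff = (curr - prev) / (curr + prev + eps)            # frame-pair contrast
    diff = np.concatenate([np.zeros_like(frames[:1]), diff], axis=0)
    return diff / (diff.std() + eps)                      # global std-normalize

clip = np.random.rand(8, 4, 4, 3).astype(np.float32)      # tiny dummy clip
print(diff_normalize(clip).shape)                         # (8, 4, 4, 3)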
@@ -284,7 +315,9 @@ DEVICE = (
     else torch.device("cpu")
 )

-FACE_CASCADE = cv2.CascadeClassifier(

 DEFAULT_SIZE = 128  # input H=W to model
 DEFAULT_T = 128     # clip length
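The new version builds the cascade from OpenCV's bundled data directory (see the added lines in the second rendering below). For reference, a hedged sketch of how that classifier is typically constructed and used; the parameter values mirror one of the detect_face() combinations shown later in this diff:

import cv2

cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)

def largest_face(frame_bgr):
    gray = cv2.equalizeHist(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY))
    faces = cascade.detectMultiScale(gray, scaleFactor=1.05, minNeighbors=3)
    if len(faces) == 0:
        return None
    return max(faces, key=lambda f: f[2] * f[3])   # (x, y, w, h) with largest area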
@@ -302,22 +335,52 @@ MAX_JUMP_BPM = 8.0
 GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
 GT_EXTS = {".txt", ".csv", ".json"}

 def _as_path(maybe) -> Optional[str]:
-    """
     if maybe is None:
         return None
     if isinstance(maybe, str):
         return maybe
     if isinstance(maybe, dict):
-
-
-
         return name
     try:
         return str(maybe)
     except Exception:
         return None

 def _import_from_file(py_path: Path):
     spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
     if not spec or not spec.loader:
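The rewritten _as_path (shown in full in the second rendering of this hunk) normalizes whatever Gradio hands back into a filesystem path. A hedged, standalone sketch of the same idea:

from pathlib import Path
from typing import Optional

def as_path(maybe) -> Optional[str]:
    """Best-effort conversion of Gradio upload values to a path string."""
    if maybe is None:
        return None
    if isinstance(maybe, (list, tuple)):            # multi-file upload: take first
        return as_path(maybe[0]) if maybe else None
    if isinstance(maybe, (str, Path)):
        return str(maybe)
    if isinstance(maybe, dict):                     # Gradio dict payloads
        for key in ("name", "path", "file"):
            v = maybe.get(key)
            if isinstance(v, str) and v:
                return v
        return None
    name = getattr(maybe, "name", None)             # tempfile-like objects
    return name if isinstance(name, str) else None

print(as_path({"path": "/tmp/video.mp4"}))          # /tmp/video.mp4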
@@ -326,27 +389,35 @@ def _import_from_file(py_path: Path):
     spec.loader.exec_module(mod)
     return mod

 def _looks_like_video(p: Path) -> bool:
     if p.suffix.lower() == ".mat":
         return True
     return p.suffix.lower() in VIDEO_EXTENSIONS

 class SimpleActivationAttention:
-    """Lightweight attention visualization

     def __init__(self, model: nn.Module, device: torch.device):
         self.model = model
         self.device = device
-        self.activations = None
-        self.hook_handle = None

     def _activation_hook(self, module, input, output):
         """Capture activations during forward pass."""
-

     def register_hook(self):
-        """Register hook on a suitable layer."""
-        # Find the last convolutional layer before Mamba
         target = None
         target_name = None
@@ -354,7 +425,7 @@ class SimpleActivationAttention:
         if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
             target = module
             target_name = name
-
         if target is None:
             print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
             return
@@ -363,30 +434,36 @@ class SimpleActivationAttention:
         print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")

     def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
-        """
         try:
             if self.activations is None:
                 return None

-            # Process activations to create spatial attention
             act = self.activations
-
             # Handle different tensor shapes
             if act.dim() == 5:  # [B, C, T, H, W]
-                # Average over
-                attention = act.mean(dim=[1, 2])
             elif act.dim() == 4:  # [B, C, H, W]
                 attention = act.mean(dim=1)  # -> [B, H, W]
             else:
                 print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
                 return None

-            #
-            attention = attention.
-
             # Normalize to [0, 1]
-
-

             return attention

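What generate() does with the hooked activations: collapse the channel (and time) dimensions into a spatial map, then min-max it into [0, 1]. A minimal hedged sketch with a dummy tensor (mirrors the reduction shown in the updated version further down):

import torch
import numpy as np

def activations_to_heatmap(act: torch.Tensor) -> np.ndarray:
    # act: [B, C, T, H, W] or [B, C, H, W] captured by a forward hook
    if act.dim() == 5:
        attn = act.mean(dim=[1, 2])     # average channels + time -> [B, H, W]
    elif act.dim() == 4:
        attn = act.mean(dim=1)          # average channels -> [B, H, W]
    else:
        raise ValueError(f"unexpected shape {tuple(act.shape)}")
    attn = attn[0].detach().cpu().numpy()
    lo, hi = attn.min(), attn.max()
    return (attn - lo) / (hi - lo) if hi > lo else np.zeros_like(attn)

heat = activations_to_heatmap(torch.randn(1, 32, 8, 16, 16))
print(heat.shape, heat.min(), heat.max())   # (16, 16) 0.0 1.0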
@@ -396,16 +473,22 @@ class SimpleActivationAttention:

     def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
         """Overlay heatmap on frame."""
         h, w = frame.shape[:2]
         heatmap_resized = cv2.resize(heatmap, (w, h))
         heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
         heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
-        overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
         return overlay

     def cleanup(self):
         if self.hook_handle is not None:
             self.hook_handle.remove()

 class VideoReader:
     """
@@ -421,6 +504,7 @@ class VideoReader:
|
|
| 421 |
self._idx = 0
|
| 422 |
self._len = 0
|
| 423 |
self._shape = None
|
|
|
|
| 424 |
|
| 425 |
if self.path.lower().endswith(".mat") and MAT_SUPPORT:
|
| 426 |
self._open_mat(self.path)
|
|
@@ -433,6 +517,7 @@ class VideoReader:
|
|
| 433 |
raise RuntimeError("Cannot open video")
|
| 434 |
self._cap = cap
|
| 435 |
self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
|
|
|
|
| 436 |
|
| 437 |
def _open_mat(self, path: str):
|
| 438 |
try:
|
|
@@ -444,6 +529,7 @@ class VideoReader:
|
|
| 444 |
break
|
| 445 |
else:
|
| 446 |
arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
|
|
|
|
| 447 |
if arr is None:
|
| 448 |
raise RuntimeError("No ndarray found in .mat")
|
| 449 |
|
|
@@ -451,16 +537,18 @@ class VideoReader:
|
|
| 451 |
# Normalize to (T,H,W,3)
|
| 452 |
if a.ndim == 4:
|
| 453 |
if a.shape[-1] == 3:
|
| 454 |
-
if
|
|
|
|
| 455 |
v = a
|
| 456 |
-
else:
|
| 457 |
v = np.transpose(a, (2, 0, 1, 3))
|
| 458 |
else:
|
| 459 |
-
v = a[..., :1]
|
| 460 |
elif a.ndim == 3:
|
| 461 |
-
|
|
|
|
| 462 |
v = a
|
| 463 |
-
else:
|
| 464 |
v = np.transpose(a, (2, 0, 1))
|
| 465 |
v = v[..., None]
|
| 466 |
else:
|
|
@@ -473,6 +561,7 @@ class VideoReader:
|
|
| 473 |
self._mat = v
|
| 474 |
self._len = v.shape[0]
|
| 475 |
self._shape = v.shape[1:3]
|
|
|
|
| 476 |
except Exception as e:
|
| 477 |
raise RuntimeError(f"Failed to open .mat video: {e}")
|
| 478 |
|
|
@@ -489,9 +578,11 @@ class VideoReader:
|
|
| 489 |
|
| 490 |
def fps(self, fallback: int = 30) -> int:
|
| 491 |
if self._mat is not None:
|
| 492 |
-
return fallback
|
|
|
|
|
|
|
| 493 |
f = self._cap.get(cv2.CAP_PROP_FPS)
|
| 494 |
-
return int(f) if f and f > 0 else fallback
|
| 495 |
|
| 496 |
def length(self) -> int:
|
| 497 |
return self._len
|
|
@@ -500,6 +591,7 @@ class VideoReader:
|
|
| 500 |
if self._cap is not None:
|
| 501 |
self._cap.release()
|
| 502 |
|
|
|
|
| 503 |
def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
|
| 504 |
x, y, w, h = face
|
| 505 |
# forehead
|
|
@@ -510,23 +602,25 @@ def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[s
|
|
| 510 |
ff = frame[y:y + h, x:x + w]
|
| 511 |
return {"forehead": fh, "cheeks": ck, "face": ff}
|
| 512 |
|
|
|
|
| 513 |
def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
|
| 514 |
if patch is None or patch.size == 0:
|
| 515 |
return -1e9
|
| 516 |
g = patch[..., 1].astype(np.float32) / 255.0 # green channel
|
| 517 |
g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling
|
| 518 |
g = g - g.mean()
|
| 519 |
-
b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
|
| 520 |
try:
|
|
|
|
| 521 |
y = signal.filtfilt(b, a, g, method="gust")
|
| 522 |
except Exception:
|
| 523 |
y = g
|
| 524 |
return float((y ** 2).mean())
|
| 525 |
|
|
|
|
| 526 |
def pick_auto_roi(face: Tuple[int, int, int, int],
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
"""Simple ROI selection."""
|
| 530 |
cands = roi_candidates(face, frame)
|
| 531 |
scores = {k: roi_quality_score(v) for k, v in cands.items()}
|
| 532 |
|
|
@@ -539,14 +633,15 @@ def pick_auto_roi(face: Tuple[int, int, int, int],
|
|
| 539 |
ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
|
| 540 |
ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
|
| 541 |
scores['forehead'] += fh_attn * 0.2
|
| 542 |
-
scores['cheeks']
|
| 543 |
-
scores['face']
|
| 544 |
except Exception:
|
| 545 |
pass
|
| 546 |
|
| 547 |
best = max(scores, key=scores.get)
|
| 548 |
return cands[best], best
|
| 549 |
|
|
|
|
| 550 |
def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
|
| 551 |
"""
|
| 552 |
Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
|
|
@@ -601,6 +696,7 @@ def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
|
|
| 601 |
uniq.append((v, g))
|
| 602 |
return uniq
|
| 603 |
|
|
|
|
| 604 |
def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
|
| 605 |
import inspect
|
| 606 |
|
|
@@ -639,7 +735,8 @@ def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: s
|
|
| 639 |
except Exception:
|
| 640 |
continue
|
| 641 |
|
| 642 |
-
raise ImportError(
|
|
|
|
| 643 |
|
| 644 |
def load_physmamba_model(ckpt_path: Path, device: torch.device,
|
| 645 |
model_file: str = "", model_class: str = "PhysMamba"):
|
|
@@ -680,6 +777,7 @@ def load_physmamba_model(ckpt_path: Path, device: torch.device,
|
|
| 680 |
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
| 681 |
model.load_state_dict(state_dict, strict=False)
|
| 682 |
except Exception:
|
|
|
|
| 683 |
pass
|
| 684 |
|
| 685 |
model.to(device).eval()
|
|
@@ -688,13 +786,15 @@ def load_physmamba_model(ckpt_path: Path, device: torch.device,
|
|
| 688 |
with torch.no_grad():
|
| 689 |
_ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
|
| 690 |
except Exception:
|
|
|
|
| 691 |
pass
|
| 692 |
|
| 693 |
-
#
|
| 694 |
attention_viz = None
|
| 695 |
|
| 696 |
return model, attention_viz
|
| 697 |
|
|
|
|
| 698 |
def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
|
| 699 |
"""
|
| 700 |
Stable band-pass with edge-safety and parameter clipping.
|
|
@@ -712,12 +812,12 @@ def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float =
|
|
| 712 |
|
| 713 |
try:
|
| 714 |
b, a = signal.butter(order, [lo, hi], btype="band")
|
| 715 |
-
# padlen must be < len(x); reduce when short
|
| 716 |
padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
|
| 717 |
return signal.filtfilt(b, a, x, padlen=padlen)
|
| 718 |
except Exception:
|
| 719 |
return x
|
| 720 |
|
|
|
|
| 721 |
def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
|
| 722 |
"""
|
| 723 |
HR (BPM) via Welch PSD peak in [lo, hi] Hz.
|
|
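hr_from_welch() picks the Welch PSD peak inside the cardiac band and converts it to BPM. A hedged, self-contained sanity check on a synthetic 1.4 Hz (84 BPM) pulse, using the same nperseg rule and band limits as this function:

import numpy as np
from scipy.signal import welch, get_window

fs = 30                                   # frames per second
t = np.arange(0, 20, 1.0 / fs)            # 20 s of signal
bvp = np.sin(2 * np.pi * 1.4 * t) + 0.1 * np.random.randn(t.size)

nper = int(min(max(64, fs * 2), min(512, bvp.size)))
f, pxx = welch(bvp, fs=fs, window=get_window("hann", nper),
               nperseg=nper, detrend="constant")
band = (f >= 0.7) & (f <= 3.5)            # 42-210 BPM cardiac band
hr_bpm = f[band][np.argmax(pxx[band])] * 60.0
print(round(hr_bpm, 1))                    # ~84 BPM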
@@ -726,7 +826,6 @@ def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5)
|
|
| 726 |
if x.size < int(fs * 4.0): # need ~4s for a usable PSD
|
| 727 |
return 0.0
|
| 728 |
try:
|
| 729 |
-
# nperseg tuned for short windows while avoiding tiny segments
|
| 730 |
nper = int(min(max(64, fs * 2), min(512, x.size)))
|
| 731 |
f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
|
| 732 |
|
|
@@ -741,11 +840,11 @@ def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5)
|
|
| 741 |
|
| 742 |
fpk = float(f_band[np.argmax(p_band)])
|
| 743 |
bpm = fpk * 60.0
|
| 744 |
-
# clip to plausible range
|
| 745 |
return float(np.clip(bpm, 30.0, 220.0))
|
| 746 |
except Exception:
|
| 747 |
return 0.0
|
| 748 |
|
|
|
|
| 749 |
def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
|
| 750 |
"""
|
| 751 |
HRV RMSSD from peaks; robust to short/flat segments.
|
|
@@ -754,7 +853,6 @@ def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
|
|
| 754 |
if x.size < int(fs * 5.0):
|
| 755 |
return 0.0
|
| 756 |
try:
|
| 757 |
-
# peak distance ~ 0.5s minimum (avoid double counting)
|
| 758 |
peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
|
| 759 |
if len(peaks) < 3:
|
| 760 |
return 0.0
|
|
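compute_rmssd() derives HRV from successive peak-to-peak intervals. A hedged sketch of the same calculation from detected peaks (not the app's exact helper):

import numpy as np
from scipy.signal import find_peaks

def rmssd_ms(bvp: np.ndarray, fs: int = 30) -> float:
    # Peaks at least ~0.5 s apart to avoid double-counting beats
    peaks, _ = find_peaks(bvp, distance=max(1, int(0.5 * fs)))
    if len(peaks) < 3:
        return 0.0
    ibi = np.diff(peaks) / fs * 1000.0     # inter-beat intervals in ms
    return float(np.sqrt(np.mean(np.diff(ibi) ** 2)))

t = np.arange(0, 30, 1.0 / 30)
print(rmssd_ms(np.sin(2 * np.pi * 1.2 * t), fs=30))   # ~0 for a perfectly regular pulse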
@@ -765,6 +863,7 @@ def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
|
|
| 765 |
except Exception:
|
| 766 |
return 0.0
|
| 767 |
|
|
|
|
| 768 |
def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
|
| 769 |
"""
|
| 770 |
Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
|
|
@@ -787,7 +886,6 @@ def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
|
|
| 787 |
lo, hi = REST_HR_RANGE
|
| 788 |
if hr < lo or hr > hi:
|
| 789 |
dist = abs(hr - REST_HR_TARGET)
|
| 790 |
-
# farther away -> stronger pull
|
| 791 |
alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
|
| 792 |
hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
|
| 793 |
|
|
@@ -802,6 +900,7 @@ def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
|
|
| 802 |
|
| 803 |
return y_filt, float(hr)
|
| 804 |
|
|
|
|
| 805 |
def draw_face_and_roi(frame_bgr: np.ndarray,
|
| 806 |
face_bbox: Optional[Tuple[int, int, int, int]],
|
| 807 |
roi_bbox: Optional[Tuple[int, int, int, int]],
|
|
@@ -820,6 +919,7 @@ def draw_face_and_roi(frame_bgr: np.ndarray,
|
|
| 820 |
cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
|
| 821 |
return vis
|
| 822 |
|
|
|
|
| 823 |
def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
|
| 824 |
roi_type: str,
|
| 825 |
frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
|
|
@@ -845,6 +945,7 @@ def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
|
|
| 845 |
return (0, 0, 0, 0)
|
| 846 |
return (rx, ry, rx2 - rx, ry2 - ry)
|
| 847 |
|
|
|
|
| 848 |
def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
|
| 849 |
"""
|
| 850 |
Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
|
|
@@ -853,9 +954,7 @@ def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
|
|
| 853 |
if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
|
| 854 |
return np.zeros((128, 128, 3), dtype=np.uint8)
|
| 855 |
|
| 856 |
-
# Undo channel-first & normalization to a viewable image
|
| 857 |
img = chw.copy()
|
| 858 |
-
# Re-normalize to 0..1 by min-max of the tensor to "show" contrast
|
| 859 |
vmin, vmax = float(img.min()), float(img.max())
|
| 860 |
if vmax <= vmin + 1e-6:
|
| 861 |
img = np.zeros_like(img)
|
|
@@ -865,6 +964,7 @@ def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
|
|
| 865 |
img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR
|
| 866 |
return img
|
| 867 |
|
|
|
|
| 868 |
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
|
| 869 |
if gt_len <= 1:
|
| 870 |
return None
|
|
@@ -872,6 +972,7 @@ def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
|
|
| 872 |
return np.arange(gt_len, dtype=float) / float(gt_fs)
|
| 873 |
return None # will fall back to length-matching overlay
|
| 874 |
|
|
|
|
| 875 |
def plot_signals_with_gt(time_axis: np.ndarray,
|
| 876 |
raw_signal: np.ndarray,
|
| 877 |
post_signal: np.ndarray,
|
|
@@ -916,20 +1017,17 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 916 |
t_new = _np.asarray(t_new, dtype=float).ravel()
|
| 917 |
|
| 918 |
if x_t.size < 2 or y.size != x_t.size:
|
| 919 |
-
# Fallback: length-based resize to t_new length
|
| 920 |
if y.size == 0 or t_new.size == 0:
|
| 921 |
return _np.zeros_like(t_new)
|
| 922 |
idx = _np.linspace(0, y.size - 1, num=t_new.size)
|
| 923 |
return _np.interp(_np.arange(t_new.size), idx, y)
|
| 924 |
|
| 925 |
-
# Enforce strictly increasing time (dedup if needed)
|
| 926 |
order = _np.argsort(x_t)
|
| 927 |
x_t = x_t[order]
|
| 928 |
y = y[order]
|
| 929 |
mask = _np.concatenate(([True], _np.diff(x_t) > 0))
|
| 930 |
x_t = x_t[mask]
|
| 931 |
y = y[mask]
|
| 932 |
-
# Clip t_new to the valid domain to avoid edge extrapolation artifacts
|
| 933 |
t_clip = _np.clip(t_new, x_t[0], x_t[-1])
|
| 934 |
return _np.interp(t_clip, x_t, y)
|
| 935 |
|
|
@@ -941,9 +1039,7 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 941 |
n = int(min(len(x), len(y)))
|
| 942 |
x = x[:n]; y = y[:n]
|
| 943 |
max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
|
| 944 |
-
# valid lags: negative means GT should be shifted left (advance) relative to Pred
|
| 945 |
lags = _np.arange(-max_lag, max_lag + 1)
|
| 946 |
-
# compute correlation for each lag
|
| 947 |
best_corr = -_np.inf
|
| 948 |
best_lag = 0
|
| 949 |
for L in lags:
|
|
@@ -975,10 +1071,8 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 975 |
out = _np.empty_like(y)
|
| 976 |
out[:] = _np.nan
|
| 977 |
if shift > 0:
|
| 978 |
-
# delay: move content right
|
| 979 |
out[shift:] = y[:-shift]
|
| 980 |
else:
|
| 981 |
-
# advance: move content left
|
| 982 |
out[:shift] = y[-shift:]
|
| 983 |
return out
|
| 984 |
|
|
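The _best_lag/_apply_lag helpers in the hunks above align ground truth to the prediction by scanning a ±max_lag window for the highest correlation and then shifting GT by that lag (negative lag means GT trails the prediction and must be advanced). A hedged sketch of the same idea:

import numpy as np

def best_lag(pred: np.ndarray, gt: np.ndarray, fs: int, max_lag_s: float = 5.0) -> float:
    n = min(len(pred), len(gt))
    pred, gt = pred[:n], gt[:n]
    max_lag = int(max(1, min(n - 1, round(max_lag_s * fs))))
    best_r, best_l = -np.inf, 0
    for lag in range(-max_lag, max_lag + 1):
        if lag >= 0:
            a, b = pred[lag:], gt[:n - lag]
        else:
            a, b = pred[:lag], gt[-lag:]
        if len(a) < 8:
            continue
        r = np.corrcoef(a, b)[0, 1]
        if np.isfinite(r) and r > best_r:
            best_r, best_l = r, lag
    return best_l / fs   # seconds; negative => GT lags and should be advanced

rng = np.random.default_rng(0)
x = np.cumsum(rng.standard_normal(900))    # non-periodic smooth signal, 30 s @ 30 fps
pred, gt = x[15:], x[:-15]                 # gt is pred delayed by 15 samples (0.5 s)
print(best_lag(pred, gt, fs=30))           # ≈ -0.5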
@@ -986,7 +1080,6 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 986 |
raw = _np.asarray(raw_signal, dtype=float)
|
| 987 |
post = _np.asarray(post_signal, dtype=float)
|
| 988 |
|
| 989 |
-
# guard
|
| 990 |
if t.size == 0:
|
| 991 |
t = _np.arange(post.size, dtype=float) / max(fs, 1)
|
| 992 |
|
|
@@ -1003,21 +1096,17 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 1003 |
gt_t = _np.asarray(gt_time, dtype=float).ravel()
|
| 1004 |
gt_on_pred = _safe_interp(gt_t, gt, t)
|
| 1005 |
else:
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
|
|
|
| 1009 |
|
| 1010 |
-
# Band-limit both before correlation/HR
|
| 1011 |
pred_bp = _bandpass(post, fs)
|
| 1012 |
gt_bp = _bandpass(gt_on_pred, fs)
|
| 1013 |
|
| 1014 |
-
# Estimate best lag (sec) of GT relative to Pred
|
| 1015 |
lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
|
| 1016 |
-
|
| 1017 |
-
# Apply lag to GT for visualization and correlation
|
| 1018 |
gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
|
| 1019 |
|
| 1020 |
-
# Compute Pearson r on overlapping valid samples
|
| 1021 |
valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
|
| 1022 |
if valid.sum() >= 16:
|
| 1023 |
pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
|
|
@@ -1026,25 +1115,21 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 1026 |
|
| 1027 |
hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
|
| 1028 |
|
| 1029 |
-
|
| 1030 |
_plt.figure(figsize=(13, 6), dpi=110)
|
| 1031 |
gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
|
| 1032 |
|
| 1033 |
-
# (1) Raw Pred
|
| 1034 |
ax1 = _plt.subplot(gs[0, 0])
|
| 1035 |
ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
|
| 1036 |
ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
|
| 1037 |
ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
|
| 1038 |
ax1.grid(True, alpha=0.3)
|
| 1039 |
|
| 1040 |
-
# (2) Post Pred
|
| 1041 |
ax2 = _plt.subplot(gs[0, 1])
|
| 1042 |
ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
|
| 1043 |
ax2.set_title("Predicted (Post-processed)")
|
| 1044 |
ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
|
| 1045 |
ax2.grid(True, alpha=0.3)
|
| 1046 |
|
| 1047 |
-
# (3) Overlay Pred vs GT (z-scored) OR just post
|
| 1048 |
ax3 = _plt.subplot(gs[1, :])
|
| 1049 |
ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
|
| 1050 |
|
|
@@ -1053,7 +1138,6 @@ def plot_signals_with_gt(time_axis: np.ndarray,
|
|
| 1053 |
gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
|
| 1054 |
ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
|
| 1055 |
|
| 1056 |
-
# metrics box
|
| 1057 |
txt = [
|
| 1058 |
f"HR_pred: {hr_pred:.1f} BPM",
|
| 1059 |
f"HR_gt: {hr_gt:.1f} BPM",
|
|
@@ -1088,16 +1172,23 @@ def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
|
|
| 1088 |
if frame is None or frame.size == 0:
|
| 1089 |
return None
|
| 1090 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1091 |
try:
|
| 1092 |
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
| 1093 |
except Exception:
|
| 1094 |
# If color conversion fails, assume already gray
|
| 1095 |
-
|
|
|
|
|
|
|
|
|
|
| 1096 |
|
| 1097 |
# Light preproc to improve Haar performance
|
| 1098 |
gray = cv2.equalizeHist(gray)
|
| 1099 |
|
| 1100 |
-
faces_all = []
|
| 1101 |
# Try a couple of parameter combos to be more forgiving
|
| 1102 |
params = [
|
| 1103 |
dict(scaleFactor=1.05, minNeighbors=3),
|
|
@@ -1118,6 +1209,7 @@ def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
|
|
| 1118 |
# Return the largest (by area)
|
| 1119 |
return max(faces_all, key=lambda f: f[2] * f[3])
|
| 1120 |
|
|
|
|
| 1121 |
def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
|
| 1122 |
"""
|
| 1123 |
Crop ROI from the frame based on a face bbox and the selected roi_type.
|
|
@@ -1151,9 +1243,15 @@ def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndar
|
|
| 1151 |
return None
|
| 1152 |
return roi
|
| 1153 |
|
| 1154 |
-
|
| 1155 |
-
|
| 1156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1157 |
if face_bbox is None or frame is None or frame.size == 0:
|
| 1158 |
return None, None
|
| 1159 |
|
|
@@ -1180,6 +1278,7 @@ def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
|
|
| 1180 |
|
| 1181 |
return roi, (rx, ry, rx2 - rx, ry2 - ry)
|
| 1182 |
|
|
|
|
| 1183 |
def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
|
| 1184 |
"""
|
| 1185 |
PhysMamba-compatible normalization with DiffNormalized support.
|
|
@@ -1204,21 +1303,28 @@ def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
|
|
| 1204 |
chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
|
| 1205 |
return chw
|
| 1206 |
|
| 1207 |
-
|
| 1208 |
-
|
| 1209 |
"""Attention visualization disabled - model architecture incompatible."""
|
| 1210 |
return None
|
| 1211 |
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
|
| 1215 |
-
|
|
|
|
|
|
|
|
|
|
| 1216 |
return frame
|
| 1217 |
|
|
|
|
| 1218 |
def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
|
| 1219 |
H, W = roi_bgr.shape[:2]
|
| 1220 |
-
base_bvp = forward_bvp(
|
| 1221 |
-
|
|
|
|
|
|
|
|
|
|
| 1222 |
base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
|
| 1223 |
|
| 1224 |
heat = np.zeros((H, W), np.float32)
|
|
@@ -1226,15 +1332,20 @@ def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
|
|
| 1226 |
for x in range(0, W - patch + 1, stride):
|
| 1227 |
tmp = roi_bgr.copy()
|
| 1228 |
tmp[y:y+patch, x:x+patch] = 127 # occlude
|
| 1229 |
-
bvp = forward_bvp(
|
| 1230 |
-
|
|
|
|
|
|
|
|
|
|
| 1231 |
power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
|
| 1232 |
drop = max(0.0, base_power - power)
|
| 1233 |
heat[y:y+patch, x:x+patch] += drop
|
| 1234 |
heat -= heat.min()
|
| 1235 |
-
if heat.max() > 1e-8:
|
|
|
|
| 1236 |
return heat
|
| 1237 |
|
|
|
|
| 1238 |
def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
|
| 1239 |
"""
|
| 1240 |
Try common 5D layouts:
|
|
@@ -1251,6 +1362,7 @@ def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
|
|
| 1251 |
last_err = e
|
| 1252 |
raise last_err
|
| 1253 |
|
|
|
|
| 1254 |
def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
|
| 1255 |
"""
|
| 1256 |
Forward and extract a 1D time-like BVP vector with length T_clip.
|
|
@@ -1297,7 +1409,8 @@ def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
|
|
| 1297 |
B, K = out.shape
|
| 1298 |
if B == 1:
|
| 1299 |
v = out[0]
|
| 1300 |
-
return (v.numpy() if v.shape[0] == T_clip
|
|
|
|
| 1301 |
if B == T_clip:
|
| 1302 |
return out[:, 0].numpy()
|
| 1303 |
if K == T_clip:
|
|
@@ -1342,6 +1455,7 @@ def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
|
|
| 1342 |
val = float(out.mean().item()) if out.numel() else 0.0
|
| 1343 |
return np.full(T_clip, val, dtype=np.float32)
|
| 1344 |
|
|
|
|
| 1345 |
def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
|
| 1346 |
"""
|
| 1347 |
Classical rPPG from green-channel means when the model yields nothing.
|
|
@@ -1366,19 +1480,25 @@ def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
|
|
| 1366 |
std = float(np.std(y)) + 1e-6
|
| 1367 |
return (y / std).astype(np.float32)
|
| 1368 |
|
|
|
|
| 1369 |
def _to_floats(s: str) -> List[float]:
|
| 1370 |
"""
|
| 1371 |
Extract all real numbers from free-form text, including scientific notation.
|
| 1372 |
-
Gracefully ignores
|
| 1373 |
"""
|
| 1374 |
if not isinstance(s, str) or not s:
|
| 1375 |
return []
|
| 1376 |
|
|
|
|
| 1377 |
s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
|
| 1378 |
|
|
|
|
| 1379 |
s = s.replace(",", " ").replace(";", " ")
|
| 1380 |
|
| 1381 |
-
toks = re.findall(
|
|
|
|
|
|
|
|
|
|
| 1382 |
out: List[float] = []
|
| 1383 |
for t in toks:
|
| 1384 |
try:
|
|
@@ -1418,7 +1538,6 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1418 |
diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
|
| 1419 |
return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
|
| 1420 |
|
| 1421 |
-
|
| 1422 |
def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
|
| 1423 |
if bvp is None or bvp.size == 0:
|
| 1424 |
return 0.0
|
|
@@ -1426,9 +1545,16 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1426 |
bp = bandpass_filter(bvp.astype(float), fs=fs_use)
|
| 1427 |
return hr_from_welch(bp, fs=fs_use)
|
| 1428 |
|
| 1429 |
-
|
|
|
|
|
|
|
|
|
|
| 1430 |
try:
|
| 1431 |
-
lines = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1432 |
ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
|
| 1433 |
hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
|
| 1434 |
t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
|
|
@@ -1452,7 +1578,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1452 |
# Fall through to generic handlers
|
| 1453 |
pass
|
| 1454 |
|
| 1455 |
-
|
| 1456 |
if ext == ".txt":
|
| 1457 |
try:
|
| 1458 |
nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
|
|
@@ -1462,12 +1588,10 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1462 |
except Exception:
|
| 1463 |
return np.array([]), 0.0, 0.0
|
| 1464 |
|
| 1465 |
-
|
| 1466 |
if ext == ".json":
|
| 1467 |
try:
|
| 1468 |
data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
|
| 1469 |
-
# Try several paths for BVP array
|
| 1470 |
-
bvp = None
|
| 1471 |
|
| 1472 |
def _seek(obj, keys):
|
| 1473 |
for k in keys:
|
|
@@ -1475,9 +1599,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1475 |
return obj[k]
|
| 1476 |
return None
|
| 1477 |
|
| 1478 |
-
# Direct top-level
|
| 1479 |
bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
|
| 1480 |
-
# Common nested containers
|
| 1481 |
if bvp is None:
|
| 1482 |
for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
|
| 1483 |
if container_key in data:
|
|
@@ -1491,7 +1613,6 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1491 |
else:
|
| 1492 |
bvp = np.array([], dtype=float)
|
| 1493 |
|
| 1494 |
-
# fs / hr (accept scalar or array)
|
| 1495 |
fs_hint = 0.0
|
| 1496 |
if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
|
| 1497 |
fs_hint = float(data["fs"])
|
|
@@ -1507,11 +1628,10 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1507 |
except Exception:
|
| 1508 |
return np.array([]), 0.0, 0.0
|
| 1509 |
|
| 1510 |
-
|
| 1511 |
if ext == ".csv":
|
| 1512 |
try:
|
| 1513 |
df = pd.read_csv(p)
|
| 1514 |
-
# Normalize column names
|
| 1515 |
cols = {str(c).strip().lower(): c for c in df.columns}
|
| 1516 |
|
| 1517 |
def _first_match(names):
|
|
@@ -1537,18 +1657,16 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1537 |
except Exception:
|
| 1538 |
return np.array([]), 0.0, 0.0
|
| 1539 |
|
| 1540 |
-
|
| 1541 |
if ext == ".mat":
|
| 1542 |
try:
|
| 1543 |
md = loadmat(str(p))
|
| 1544 |
-
# look for most likely array
|
| 1545 |
arr = None
|
| 1546 |
for key in ("ppg", "bvp", "signal", "wave"):
|
| 1547 |
if key in md and isinstance(md[key], np.ndarray):
|
| 1548 |
arr = md[key]
|
| 1549 |
break
|
| 1550 |
if arr is None:
|
| 1551 |
-
# fallback: first 1-D array
|
| 1552 |
for v in md.values():
|
| 1553 |
if isinstance(v, np.ndarray) and v.ndim == 1:
|
| 1554 |
arr = v
|
|
@@ -1582,7 +1700,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1582 |
try:
|
| 1583 |
bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
|
| 1584 |
fs_hint, hr = 0.0, 0.0
|
| 1585 |
-
|
| 1586 |
sidecar = p.with_suffix(".json")
|
| 1587 |
if sidecar.exists():
|
| 1588 |
try:
|
|
@@ -1594,6 +1712,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1594 |
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
|
| 1595 |
except Exception:
|
| 1596 |
pass
|
|
|
|
| 1597 |
if hr == 0.0 and bvp.size:
|
| 1598 |
hr = _hr_from_bvp(bvp, fs_hint)
|
| 1599 |
return bvp, hr, fs_hint
|
|
@@ -1603,6 +1722,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
|
| 1603 |
# Fallback (unsupported extension)
|
| 1604 |
return np.array([]), 0.0, 0.0
|
| 1605 |
|
|
|
|
| 1606 |
def scan_models() -> List[str]:
|
| 1607 |
if not MODEL_DIR.exists():
|
| 1608 |
return []
|
|
@@ -1614,8 +1734,10 @@ def scan_models() -> List[str]:
|
|
| 1614 |
|
| 1615 |
return models
|
| 1616 |
|
|
|
|
| 1617 |
_GLOBAL_CONTROLS: Dict[str, Dict] = {}
|
| 1618 |
|
|
|
|
| 1619 |
def ensure_controls(control_id: str) -> Tuple[str, Dict]:
|
| 1620 |
# Use a stable default so Pause/Resume/Stop work for the current run
|
| 1621 |
if not control_id:
|
|
@@ -1627,6 +1749,7 @@ def ensure_controls(control_id: str) -> Tuple[str, Dict]:
|
|
| 1627 |
}
|
| 1628 |
return control_id, _GLOBAL_CONTROLS[control_id]
|
| 1629 |
|
|
|
|
| 1630 |
def process_video_file(
|
| 1631 |
video_path: str,
|
| 1632 |
gt_file: Optional[str],
|
|
@@ -1639,9 +1762,13 @@ def process_video_file(
|
|
| 1639 |
"""
|
| 1640 |
Enhanced video processing with Grad-CAM attention visualization,
|
| 1641 |
plus per-frame illumination and motion logging.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1642 |
"""
|
| 1643 |
global _HR_SMOOTH
|
| 1644 |
-
global FRAME_METRICS #
|
| 1645 |
_HR_SMOOTH = None
|
| 1646 |
FRAME_METRICS = [] # reset per run
|
| 1647 |
|
|
@@ -1655,8 +1782,19 @@ def process_video_file(
|
|
| 1655 |
control_id, controls = ensure_controls(control_id)
|
| 1656 |
controls['stop'].clear()
|
| 1657 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1658 |
if not model_name:
|
| 1659 |
-
yield ("ERROR: No model selected"
|
| 1660 |
return
|
| 1661 |
|
| 1662 |
if isinstance(model_name, int):
|
|
@@ -1664,24 +1802,24 @@ def process_video_file(
|
|
| 1664 |
|
| 1665 |
model_path = MODEL_DIR / model_name
|
| 1666 |
if not model_path.exists():
|
| 1667 |
-
yield ("ERROR: Model not found"
|
| 1668 |
return
|
| 1669 |
|
| 1670 |
try:
|
| 1671 |
model, attention_viz = load_physmamba_model(model_path, DEVICE)
|
| 1672 |
except Exception as e:
|
| 1673 |
-
yield (f"ERROR loading model: {str(e)}"
|
| 1674 |
return
|
| 1675 |
|
| 1676 |
gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
|
| 1677 |
|
| 1678 |
if not video_path or not os.path.exists(video_path):
|
| 1679 |
-
yield ("ERROR: Video not found"
|
| 1680 |
return
|
| 1681 |
|
| 1682 |
cap = cv2.VideoCapture(video_path)
|
| 1683 |
if not cap.isOpened():
|
| 1684 |
-
yield ("ERROR: Cannot open video"
|
| 1685 |
return
|
| 1686 |
|
| 1687 |
fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
|
|
@@ -1698,7 +1836,7 @@ def process_video_file(
|
|
| 1698 |
last_rmssd = 0.0
|
| 1699 |
last_attention = None
|
| 1700 |
|
| 1701 |
-
#
|
| 1702 |
prev_gray = None
|
| 1703 |
prev_roi_gray = None
|
| 1704 |
|
|
@@ -1712,8 +1850,15 @@ def process_video_file(
|
|
| 1712 |
raw_path = tmpdir / "raw_signal.png"
|
| 1713 |
post_path = tmpdir / "post_signal.png"
|
| 1714 |
|
| 1715 |
-
|
| 1716 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1717 |
|
| 1718 |
while True:
|
| 1719 |
if controls['stop'].is_set():
|
|
@@ -1730,13 +1875,13 @@ def process_video_file(
|
|
| 1730 |
|
| 1731 |
frame_idx += 1
|
| 1732 |
|
| 1733 |
-
#
|
| 1734 |
global_brightness = None
|
| 1735 |
global_motion = None
|
| 1736 |
roi_brightness = None
|
| 1737 |
roi_motion = None
|
| 1738 |
|
| 1739 |
-
#
|
| 1740 |
try:
|
| 1741 |
frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
| 1742 |
except Exception:
|
|
@@ -1762,7 +1907,7 @@ def process_video_file(
|
|
| 1762 |
roi = crop_roi(face, roi_type, frame)
|
| 1763 |
|
| 1764 |
if roi is not None and roi.size > 0:
|
| 1765 |
-
#
|
| 1766 |
try:
|
| 1767 |
roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
|
| 1768 |
except Exception:
|
|
@@ -1799,7 +1944,9 @@ def process_video_file(
|
|
| 1799 |
try:
|
| 1800 |
raw = forward_bvp(model, clip_t)
|
| 1801 |
if isinstance(raw, np.ndarray):
|
| 1802 |
-
raw = np.nan_to_num(
|
|
|
|
|
|
|
| 1803 |
bvp_out = raw if raw.size > 0 else None
|
| 1804 |
else:
|
| 1805 |
bvp_out = None
|
|
@@ -1807,7 +1954,7 @@ def process_video_file(
|
|
| 1807 |
print(f"[infer] forward_bvp error: {e}")
|
| 1808 |
bvp_out = None
|
| 1809 |
|
| 1810 |
-
#
|
| 1811 |
try:
|
| 1812 |
last_attention = extract_attention_map(model, clip_t, attention_viz)
|
| 1813 |
except Exception as e:
|
|
@@ -1815,7 +1962,9 @@ def process_video_file(
|
|
| 1815 |
last_attention = None
|
| 1816 |
|
| 1817 |
if bvp_out is None or bvp_out.size == 0:
|
| 1818 |
-
gbuf = np.nan_to_num(
|
|
|
|
|
|
|
| 1819 |
fb = _fallback_bvp_from_means(gbuf, fs=fps)
|
| 1820 |
if isinstance(fb, np.ndarray) and fb.size > 0:
|
| 1821 |
bvp_out = fb
|
|
@@ -1829,7 +1978,9 @@ def process_video_file(
|
|
| 1829 |
bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
|
| 1830 |
|
| 1831 |
if len(bvp_stream) >= int(5 * fps):
|
| 1832 |
-
seg = np.asarray(
|
|
|
|
|
|
|
| 1833 |
_, last_bpm = postprocess_bvp(seg, fs=fps)
|
| 1834 |
last_rmssd = compute_rmssd(seg, fs=fps)
|
| 1835 |
|
|
@@ -1841,10 +1992,17 @@ def process_video_file(
|
|
| 1841 |
last_infer = frame_idx
|
| 1842 |
|
| 1843 |
else:
|
| 1844 |
-
cv2.putText(
|
| 1845 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1846 |
|
| 1847 |
-
#
|
| 1848 |
try:
|
| 1849 |
time_s = frame_idx / float(fps) if fps > 0 else float(frame_idx)
|
| 1850 |
except Exception:
|
|
@@ -1859,11 +2017,24 @@ def process_video_file(
|
|
| 1859 |
"roi_motion": float(roi_motion) if roi_motion is not None else None,
|
| 1860 |
})
|
| 1861 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1862 |
if last_bpm > 0:
|
| 1863 |
color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
|
| 1864 |
cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
|
| 1865 |
-
cv2.putText(
|
| 1866 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1867 |
|
| 1868 |
vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
|
| 1869 |
|
|
@@ -1934,6 +2105,7 @@ def process_video_file(
|
|
| 1934 |
elapsed = now - start_time
|
| 1935 |
status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
|
| 1936 |
|
|
|
|
| 1937 |
yield (
|
| 1938 |
status,
|
| 1939 |
f"{last_bpm:.1f}" if last_bpm > 0 else None,
|
|
@@ -1944,7 +2116,11 @@ def process_video_file(
|
|
| 1944 |
str(signal_path) if signal_path.exists() else None,
|
| 1945 |
str(raw_path) if raw_path.exists() else None,
|
| 1946 |
str(post_path) if post_path.exists() else None,
|
| 1947 |
-
None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1948 |
)
|
| 1949 |
|
| 1950 |
next_display = now + (1.0 / DISPLAY_FPS)
|
|
@@ -1984,7 +2160,7 @@ def process_video_file(
|
|
| 1984 |
except Exception:
|
| 1985 |
pass
|
| 1986 |
|
| 1987 |
-
#
|
| 1988 |
frame_metrics_path = None
|
| 1989 |
if FRAME_METRICS:
|
| 1990 |
try:
|
|
@@ -1995,9 +2171,13 @@ def process_video_file(
|
|
| 1995 |
print(f"[metrics] Failed to save frame metrics CSV: {e}")
|
| 1996 |
frame_metrics_path = None
|
| 1997 |
|
|
|
|
|
|
|
|
|
|
| 1998 |
elapsed = time.time() - start_time
|
| 1999 |
final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
|
| 2000 |
|
|
|
|
| 2001 |
yield (
|
| 2002 |
final_status,
|
| 2003 |
f"{last_bpm:.1f}" if last_bpm > 0 else None,
|
|
@@ -2008,9 +2188,14 @@ def process_video_file(
|
|
| 2008 |
str(signal_path) if signal_path.exists() else None,
|
| 2009 |
str(raw_path) if raw_path.exists() else None,
|
| 2010 |
str(post_path) if post_path.exists() else None,
|
| 2011 |
-
str(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2012 |
)
|
| 2013 |
|
|
|
|
| 2014 |
def process_stream(
|
| 2015 |
input_source: str,
|
| 2016 |
video_path: Optional[str],
|
|
@@ -2024,8 +2209,10 @@ def process_stream(
|
|
| 2024 |
if input_source == "Live Webcam":
|
| 2025 |
yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
|
| 2026 |
else:
|
| 2027 |
-
yield from process_video_file(
|
| 2028 |
-
|
|
|
|
|
|
|
| 2029 |
|
| 2030 |
def pause_processing(control_id: str) -> str:
|
| 2031 |
_, controls = ensure_controls(control_id)
|
|
@@ -2044,7 +2231,23 @@ def stop_processing(control_id: str) -> str:
|
|
| 2044 |
return "Stopped"
|
| 2045 |
|
| 2046 |
def reset_ui():
|
| 2047 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2048 |
|
| 2049 |
def handle_folder_upload(files):
|
| 2050 |
if not files:
|
|
@@ -2111,8 +2314,10 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
|
|
| 2111 |
with gr.Column():
|
| 2112 |
video_upload = gr.Video(label="Upload Video", sources=["upload"])
|
| 2113 |
with gr.Column():
|
| 2114 |
-
gt_upload = gr.File(
|
| 2115 |
-
|
|
|
|
|
|
|
| 2116 |
|
| 2117 |
with gr.Row(visible=False) as folder_inputs:
|
| 2118 |
with gr.Column():
|
|
@@ -2177,7 +2382,11 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
|
|
| 2177 |
)
|
| 2178 |
|
| 2179 |
with gr.Row():
|
| 2180 |
-
roi_dropdown = gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2181 |
|
| 2182 |
control_state = gr.State(value="")
|
| 2183 |
placeholder_state = gr.State(value=None)
|
|
@@ -2194,6 +2403,21 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
|
|
| 2194 |
hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
|
| 2195 |
gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
|
| 2196 |
rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2197 |
|
| 2198 |
with gr.Row():
|
| 2199 |
with gr.Column():
|
|
@@ -2232,20 +2456,51 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
|
|
| 2232 |
).then(
|
| 2233 |
reset_ui,
|
| 2234 |
inputs=None,
|
| 2235 |
-
outputs=[
|
| 2236 |
-
|
| 2237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2238 |
)
|
| 2239 |
|
| 2240 |
-
def run_processing(
|
| 2241 |
-
|
| 2242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2243 |
|
| 2244 |
if isinstance(model_name, int):
|
| 2245 |
model_name = str(model_name)
|
| 2246 |
|
| 2247 |
if not model_name:
|
| 2248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2249 |
return
|
| 2250 |
|
| 2251 |
if input_source == "Video File":
|
|
@@ -2257,12 +2512,12 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
|
|
| 2257 |
else: # Live Webcam
|
| 2258 |
video_path, gt_file = None, None
|
| 2259 |
|
|
|
|
| 2260 |
yield from process_stream(
|
| 2261 |
input_source, video_path, gt_file,
|
| 2262 |
model_name, fps, max_sec, roi, ctrl_id
|
| 2263 |
)
|
| 2264 |
|
| 2265 |
-
|
| 2266 |
run_btn.click(
|
| 2267 |
fn=run_processing,
|
| 2268 |
inputs=[
|
|
@@ -2287,35 +2542,11 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
|
|
| 2287 |
signal_output,
|
| 2288 |
raw_signal_output,
|
| 2289 |
post_signal_output,
|
| 2290 |
-
csv_output
|
| 2291 |
-
|
| 2292 |
-
|
| 2293 |
-
|
| 2294 |
-
|
| 2295 |
-
run_btn.click(
|
| 2296 |
-
fn=run_processing,
|
| 2297 |
-
inputs=[
|
| 2298 |
-
input_source,
|
| 2299 |
-
video_upload,
|
| 2300 |
-
folder_video,
|
| 2301 |
-
folder_gt,
|
| 2302 |
-
model_dropdown,
|
| 2303 |
-
fps_slider,
|
| 2304 |
-
max_seconds_slider,
|
| 2305 |
-
roi_dropdown,
|
| 2306 |
-
control_state
|
| 2307 |
-
],
|
| 2308 |
-
outputs=[
|
| 2309 |
-
status_text,
|
| 2310 |
-
hr_output,
|
| 2311 |
-
gt_hr_output,
|
| 2312 |
-
rmssd_output,
|
| 2313 |
-
frame_output,
|
| 2314 |
-
attention_output,
|
| 2315 |
-
signal_output,
|
| 2316 |
-
raw_signal_output,
|
| 2317 |
-
post_signal_output,
|
| 2318 |
-
csv_output
|
| 2319 |
]
|
| 2320 |
)
|
| 2321 |
|
|
|
|
The same hunks rendered from the updated file, with added lines marked +:

 import torch.nn as nn
 import torch.nn.functional as F

 from scipy import signal
 from scipy.signal import find_peaks, welch, get_window

 ...

 import gradio as gr

+# Global buffer for per-frame illumination/motion metrics
+FRAME_METRICS: List[Dict] = []

 class PhysMambaattention_viz:
     """Simplified Grad-CAM for PhysMamba."""


 def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
+    """
+    Apply DiffNormalized preprocessing from the PhysMamba paper:
+
+        diff_t = (I_t - I_{t-1}) / (I_t + I_{t-1} + eps)
+
+    then global std-normalize.
+
+    frames: list of HxWx3 uint8 or float32 arrays (RGB or BGR, consistent).
+    Returns: (T, H, W, C) float32.
+    """
+    if not frames:
+        return np.zeros((0,), dtype=np.float32)
+
     if len(frames) < 2:
+        f0 = frames[0].astype(np.float32)
+        return np.stack([np.zeros_like(f0, dtype=np.float32)], axis=0)
+
     diff_frames = []
     for i in range(len(frames)):
         if i == 0:
             diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
         else:
             curr = frames[i].astype(np.float32)
+            prev = frames[i - 1].astype(np.float32)
+            denom = curr + prev + 1e-8
+            diff = (curr - prev) / denom
             diff_frames.append(diff)
+
+    diff_array = np.stack(diff_frames).astype(np.float32)
+    std = float(diff_array.std()) + 1e-8
+    diff_array /= std
     return diff_array

+def preprocess_for_physmamba(
+    frames: List[np.ndarray],
+    target_frames: int = 128,
+    target_size: int = 128
+) -> torch.Tensor:
+    """
+    Complete DiffNormalized preprocessing pipeline to produce
+    a PhysMamba-compatible clip tensor of shape [1, 3, T, H, W].
+
+    NOTE: This path is *not* used in the current live demo, which instead
+    uses normalize_frame() + forward_bvp(). Keep for future experiments.
+    """
+    if not frames:
+        # Dummy tensor; caller should guard length > 0
+        return torch.zeros(1, 3, target_frames, target_size, target_size, dtype=torch.float32)
+
+    # Temporal sampling / padding to target_frames
     if len(frames) < target_frames:
         frames = frames + [frames[-1]] * (target_frames - len(frames))
     elif len(frames) > target_frames:
+        idx = np.linspace(0, len(frames) - 1, target_frames).astype(int)
+        frames = [frames[i] for i in idx]
+
+    # Convert to RGB and resize
+    frames_rgb = [f[..., ::-1].copy() for f in frames]  # BGR->RGB
     frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]

+    # DiffNormalized
+    diff_array = apply_diff_normalized(frames_resized)   # (T, H, W, C)
+
+    # To [B, C, T, H, W]
+    diff_array = np.transpose(diff_array, (3, 0, 1, 2))  # (C, T, H, W)
+    diff_array = np.expand_dims(diff_array, axis=0)      # (1, C, T, H, W)
+
+    return torch.from_numpy(diff_array.astype(np.float32))
+
+
+# ---------------------------------------------------------------------------
+# Paths, device, constants
+# ---------------------------------------------------------------------------
+
+HERE = Path(__file__).resolve().parent
 MODEL_DIR = HERE / "final_model_release"
 LOG_DIR = HERE / "logs"
 ANALYSIS_DIR = HERE / "analysis"
+for d in (MODEL_DIR, LOG_DIR, ANALYSIS_DIR):
+    d.mkdir(exist_ok=True, parents=True)

 DEVICE = (
     torch.device("cuda") if torch.cuda.is_available()
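A hedged usage sketch of this preprocessing path (the live demo does not call it, per the NOTE above). It assumes the functions and constants from app.py are in scope; the video path is a placeholder:

import cv2

cap = cv2.VideoCapture("subject1.avi")      # placeholder path
frames = []
while len(frames) < 128:
    ok, frame = cap.read()
    if not ok:
        break
    frames.append(frame)
cap.release()

clip = preprocess_for_physmamba(frames, target_frames=128, target_size=128)
print(clip.shape)                            # torch.Size([1, 3, 128, 128, 128])
# bvp = forward_bvp(model, clip.to(DEVICE))  # then band-pass + Welch for HR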
| 315 |
else torch.device("cpu")
|
| 316 |
)
|
| 317 |
|
| 318 |
+
FACE_CASCADE = cv2.CascadeClassifier(
|
| 319 |
+
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
|
| 320 |
+
)
|
| 321 |
|
| 322 |
DEFAULT_SIZE = 128 # input H=W to model
|
| 323 |
DEFAULT_T = 128 # clip length
|
|
|
|
| 335 |
GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
|
| 336 |
GT_EXTS = {".txt", ".csv", ".json"}
|
| 337 |
|
| 338 |
+
|
| 339 |
def _as_path(maybe) -> Optional[str]:
|
| 340 |
+
"""
|
| 341 |
+
Return a filesystem path from Gradio values (str, dict, Path, tempfile objects, lists).
|
| 342 |
+
|
| 343 |
+
Handles:
|
| 344 |
+
- plain strings
|
| 345 |
+
- pathlib.Path
|
| 346 |
+
- Gradio dicts with keys: 'name', 'path', 'file'
|
| 347 |
+
- file-like objects with .name
|
| 348 |
+
- lists (takes first element)
|
| 349 |
+
"""
|
| 350 |
if maybe is None:
|
| 351 |
return None
|
| 352 |
+
|
| 353 |
+
# Gradio can pass a list (e.g., multiple files / directory upload)
|
| 354 |
+
if isinstance(maybe, list):
|
| 355 |
+
if not maybe:
|
| 356 |
+
return None
|
| 357 |
+
return _as_path(maybe[0])
|
| 358 |
+
|
| 359 |
if isinstance(maybe, str):
|
| 360 |
return maybe
|
| 361 |
+
|
| 362 |
+
if isinstance(maybe, Path):
|
| 363 |
+
return str(maybe)
|
| 364 |
+
|
| 365 |
+
# Gradio v4 File/Video components often pass a dict
|
| 366 |
if isinstance(maybe, dict):
|
| 367 |
+
for key in ("name", "path", "file"):
|
| 368 |
+
v = maybe.get(key)
|
| 369 |
+
if isinstance(v, str) and v:
|
| 370 |
+
return v
|
| 371 |
+
return None
|
| 372 |
+
|
| 373 |
+
# tempfile-like / UploadedFile objects
|
| 374 |
+
name = getattr(maybe, "name", None)
|
| 375 |
+
if isinstance(name, str) and name:
|
| 376 |
return name
|
| 377 |
+
|
| 378 |
try:
|
| 379 |
return str(maybe)
|
| 380 |
except Exception:
|
| 381 |
return None
|
| 382 |
|
| 383 |
+
|
| 384 |
def _import_from_file(py_path: Path):
|
| 385 |
spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
|
| 386 |
if not spec or not spec.loader:
|
|
|
|
| 389 |
spec.loader.exec_module(mod)
|
| 390 |
return mod
|
| 391 |
|
| 392 |
+
|
| 393 |
def _looks_like_video(p: Path) -> bool:
|
| 394 |
+
"""
|
| 395 |
+
Heuristic for 'video-like' files used in subject-folder discovery.
|
| 396 |
+
Treat .mat as video, plus common video extensions.
|
| 397 |
+
"""
|
| 398 |
if p.suffix.lower() == ".mat":
|
| 399 |
return True
|
| 400 |
return p.suffix.lower() in VIDEO_EXTENSIONS
|
| 401 |
|
| 402 |
+
|
| 403 |
class SimpleActivationAttention:
|
| 404 |
+
"""Lightweight attention visualization using forward activations (no gradients)."""
|
| 405 |
|
| 406 |
def __init__(self, model: nn.Module, device: torch.device):
|
| 407 |
self.model = model
|
| 408 |
self.device = device
|
| 409 |
+
self.activations: Optional[torch.Tensor] = None
|
| 410 |
+
self.hook_handle: Optional[Any] = None
|
| 411 |
|
| 412 |
def _activation_hook(self, module, input, output):
|
| 413 |
"""Capture activations during forward pass."""
|
| 414 |
+
try:
|
| 415 |
+
self.activations = output.detach()
|
| 416 |
+
except Exception:
|
| 417 |
+
self.activations = None
|
| 418 |
|
| 419 |
def register_hook(self):
|
| 420 |
+
"""Register hook on a suitable conv layer (last conv before Mamba if possible)."""
|
|
|
|
| 421 |
target = None
|
| 422 |
target_name = None
|
| 423 |
|
|
|
|
| 425 |
if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
|
| 426 |
target = module
|
| 427 |
target_name = name
|
| 428 |
+
|
| 429 |
if target is None:
|
| 430 |
print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
|
| 431 |
return
|
|
|
|
| 434 |
print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")
|
| 435 |
|
| 436 |
def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
|
| 437 |
+
"""
|
| 438 |
+
Generate attention map from stored activations (call AFTER the forward pass).
|
| 439 |
+
|
| 440 |
+
Returns a 2D numpy array in [0,1] or None if unavailable.
|
| 441 |
+
"""
|
| 442 |
try:
|
| 443 |
if self.activations is None:
|
| 444 |
return None
|
| 445 |
|
|
|
|
| 446 |
act = self.activations
|
| 447 |
+
|
| 448 |
# Handle different tensor shapes
|
| 449 |
if act.dim() == 5: # [B, C, T, H, W]
|
| 450 |
+
# Average over channels and time: -> [B, H, W]
|
| 451 |
+
attention = act.mean(dim=[1, 2])
|
| 452 |
elif act.dim() == 4: # [B, C, H, W]
|
| 453 |
attention = act.mean(dim=1) # -> [B, H, W]
|
| 454 |
else:
|
| 455 |
print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
|
| 456 |
return None
|
| 457 |
|
| 458 |
+
# Take first batch
|
| 459 |
+
attention = attention[0].detach().cpu().numpy()
|
| 460 |
+
|
| 461 |
# Normalize to [0, 1]
|
| 462 |
+
a_min, a_max = attention.min(), attention.max()
|
| 463 |
+
if a_max > a_min:
|
| 464 |
+
attention = (attention - a_min) / (a_max - a_min)
|
| 465 |
+
else:
|
| 466 |
+
attention = np.zeros_like(attention, dtype=np.float32)
|
| 467 |
|
| 468 |
return attention
|
| 469 |
|
|
|
|
| 473 |
|
| 474 |
def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
|
| 475 |
"""Overlay heatmap on frame."""
|
| 476 |
+
if heatmap is None or frame is None or frame.size == 0:
|
| 477 |
+
return frame
|
| 478 |
+
|
| 479 |
h, w = frame.shape[:2]
|
| 480 |
heatmap_resized = cv2.resize(heatmap, (w, h))
|
| 481 |
heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
|
| 482 |
heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
| 483 |
+
overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
|
| 484 |
return overlay
|
| 485 |
|
| 486 |
def cleanup(self):
|
| 487 |
if self.hook_handle is not None:
|
| 488 |
self.hook_handle.remove()
|
| 489 |
+
self.hook_handle = None
|
| 490 |
+
self.activations = None
|
| 491 |
+
|
| 492 |
|
| 493 |
class VideoReader:
|
| 494 |
"""
|
|
|
|
| 504 |
self._idx = 0
|
| 505 |
self._len = 0
|
| 506 |
self._shape = None
|
| 507 |
+
self._fps = 0
|
| 508 |
|
| 509 |
if self.path.lower().endswith(".mat") and MAT_SUPPORT:
|
| 510 |
self._open_mat(self.path)
|
|
|
|
| 517 |
raise RuntimeError("Cannot open video")
|
| 518 |
self._cap = cap
|
| 519 |
self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
|
| 520 |
+
self._fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
|
| 521 |
|
| 522 |
def _open_mat(self, path: str):
|
| 523 |
try:
|
|
|
|
| 529 |
break
|
| 530 |
else:
|
| 531 |
arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
|
| 532 |
+
|
| 533 |
if arr is None:
|
| 534 |
raise RuntimeError("No ndarray found in .mat")
|
| 535 |
|
|
|
|
| 537 |
# Normalize to (T,H,W,3)
|
| 538 |
if a.ndim == 4:
|
| 539 |
if a.shape[-1] == 3:
|
| 540 |
+
# Heuristic: if first dim is much smaller than spatial dims -> assume T
|
| 541 |
+
if a.shape[0] <= a.shape[1] and a.shape[0] <= a.shape[2]: # (T,H,W,3)
|
| 542 |
v = a
|
| 543 |
+
else: # (H,W,T,3) -> (T,H,W,3)
|
| 544 |
v = np.transpose(a, (2, 0, 1, 3))
|
| 545 |
else:
|
| 546 |
+
v = a[..., :1] # take first channel
|
| 547 |
elif a.ndim == 3:
|
| 548 |
+
# (T,H,W) or (H,W,T)
|
| 549 |
+
if a.shape[0] <= a.shape[1] and a.shape[0] <= a.shape[2]: # (T,H,W)
|
| 550 |
v = a
|
| 551 |
+
else: # (H,W,T) -> (T,H,W)
|
| 552 |
v = np.transpose(a, (2, 0, 1))
|
| 553 |
v = v[..., None]
|
| 554 |
else:
|
|
|
|
| 561 |
self._mat = v
|
| 562 |
self._len = v.shape[0]
|
| 563 |
self._shape = v.shape[1:3]
|
| 564 |
+
self._fps = 0.0 # unknown; caller can override
|
| 565 |
except Exception as e:
|
| 566 |
raise RuntimeError(f"Failed to open .mat video: {e}")
|
| 567 |
|
|
|
|
| 578 |
|
| 579 |
def fps(self, fallback: int = 30) -> int:
|
| 580 |
if self._mat is not None:
|
| 581 |
+
return int(fallback)
|
| 582 |
+
if self._fps and self._fps > 0:
|
| 583 |
+
return int(self._fps)
|
| 584 |
f = self._cap.get(cv2.CAP_PROP_FPS)
|
| 585 |
+
return int(f) if f and f > 0 else int(fallback)
|
| 586 |
|
| 587 |
def length(self) -> int:
|
| 588 |
return self._len
|
|
|
|
| 591 |
if self._cap is not None:
|
| 592 |
self._cap.release()
|
| 593 |
|
| 594 |
+
|
| 595 |
def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
    x, y, w, h = face
    # forehead
    ...
    ff = frame[y:y + h, x:x + w]
    return {"forehead": fh, "cheeks": ck, "face": ff}

+
def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
    if patch is None or patch.size == 0:
        return -1e9
    g = patch[..., 1].astype(np.float32) / 255.0  # green channel
    g = cv2.resize(g, (64, 64)).mean(axis=1)  # crude spatial pooling
    g = g - g.mean()
    try:
+       b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
        y = signal.filtfilt(b, a, g, method="gust")
    except Exception:
        y = g
    return float((y ** 2).mean())

+
def pick_auto_roi(face: Tuple[int, int, int, int],
+                 frame: np.ndarray,
+                 attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
+   """Simple ROI selection using signal quality + optional attention weighting."""
    cands = roi_candidates(face, frame)
    scores = {k: roi_quality_score(v) for k, v in cands.items()}
    ...
        ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
        ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
        scores['forehead'] += fh_attn * 0.2
+       scores['cheeks'] += ck_attn * 0.2
+       scores['face'] += ff_attn * 0.2
    except Exception:
        pass

    best = max(scores, key=scores.get)
    return cands[best], best

+
def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
    """
    Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
    ...
        uniq.append((v, g))
    return uniq

+
def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
    import inspect
    ...
        except Exception:
            continue

+   raise ImportError("Could not find PhysMamba model class")
+

def load_physmamba_model(ckpt_path: Path, device: torch.device,
                         model_file: str = "", model_class: str = "PhysMamba"):
    ...
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict, strict=False)
    except Exception:
+       # If loading fails, you still get an uninitialized model (for debugging)
        pass

    model.to(device).eval()
    ...
        with torch.no_grad():
            _ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
    except Exception:
+       # Shape sanity check failed, but we keep the model usable.
        pass

+   # For now: attention visualization disabled (extract_attention_map returns None)
    attention_viz = None

    return model, attention_viz

+
def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
    """
    Stable band-pass with edge-safety and parameter clipping.
    ...

    try:
        b, a = signal.butter(order, [lo, hi], btype="band")
        padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
        return signal.filtfilt(b, a, x, padlen=padlen)
    except Exception:
        return x

+
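# Illustrative sketch only (not called by the app): the elided body of bandpass_filter
# presumably normalizes the cut-offs by the Nyquist frequency and clips them into (0, 1)
# before calling signal.butter. The helper below shows that idea; names are assumptions.
def _sketch_band_edges(fs: int = 30, low: float = 0.7, high: float = 3.5) -> Tuple[float, float]:
    nyq = fs / 2.0                                     # Nyquist frequency
    lo = float(np.clip(low / nyq, 1e-4, 0.98))         # normalized lower edge
    hi = float(np.clip(high / nyq, lo + 1e-4, 0.999))  # normalized upper edge, kept above lo
    return lo, hi
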
def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
    """
    HR (BPM) via Welch PSD peak in [lo, hi] Hz.
    ...
    if x.size < int(fs * 4.0):  # need ~4s for a usable PSD
        return 0.0
    try:
        nper = int(min(max(64, fs * 2), min(512, x.size)))
        f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
        ...

        fpk = float(f_band[np.argmax(p_band)])
        bpm = fpk * 60.0
        return float(np.clip(bpm, 30.0, 220.0))
    except Exception:
        return 0.0

+
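# Illustrative sketch only (not called by the app): sanity-check hr_from_welch on a
# synthetic 84-BPM (1.4 Hz) pulse. The estimate is quantized by the Welch bin spacing,
# so expect a value within a few BPM of 84 rather than an exact match.
def _sketch_hr_from_welch_check() -> float:
    fs = 30
    t = np.arange(0, 20, 1.0 / fs)                                   # 20 s of samples at 30 Hz
    x = np.sin(2 * np.pi * 1.4 * t) + 0.3 * np.random.randn(t.size)  # pulse + broadband noise
    return hr_from_welch(bandpass_filter(x, fs=fs), fs=fs)
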
def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
    """
    HRV RMSSD from peaks; robust to short/flat segments.
    ...
    if x.size < int(fs * 5.0):
        return 0.0
    try:
        peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
        if len(peaks) < 3:
            return 0.0
        ...
    except Exception:
        return 0.0

+
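# Illustrative sketch only (not called by the app): the standard RMSSD definition the
# elided lines above presumably implement — inter-beat intervals in milliseconds, then
# the root mean square of their successive differences.
def _sketch_rmssd_from_peaks(peaks: np.ndarray, fs: int = 30) -> float:
    ibi_ms = np.diff(peaks) * (1000.0 / fs)   # inter-beat intervals (ms)
    if ibi_ms.size < 2:
        return 0.0
    return float(np.sqrt(np.mean(np.diff(ibi_ms) ** 2)))
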
def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
    """
    Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
    ...
    lo, hi = REST_HR_RANGE
    if hr < lo or hr > hi:
        dist = abs(hr - REST_HR_TARGET)
        alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
        hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET

    ...
    return y_filt, float(hr)

+
def draw_face_and_roi(frame_bgr: np.ndarray,
                      face_bbox: Optional[Tuple[int, int, int, int]],
                      roi_bbox: Optional[Tuple[int, int, int, int]],
    ...
    cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
    return vis

+
def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
                       roi_type: str,
                       frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
    ...
        return (0, 0, 0, 0)
    return (rx, ry, rx2 - rx, ry2 - ry)

+
def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
    """
    Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
    ...
    if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
        return np.zeros((128, 128, 3), dtype=np.uint8)

    img = chw.copy()
    vmin, vmax = float(img.min()), float(img.max())
    if vmax <= vmin + 1e-6:
        img = np.zeros_like(img)
    ...
    img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8)  # RGB->BGR
    return img

+
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
    if gt_len <= 1:
        return None
    ...
        return np.arange(gt_len, dtype=float) / float(gt_fs)
    return None  # will fall back to length-matching overlay

+
def plot_signals_with_gt(time_axis: np.ndarray,
                         raw_signal: np.ndarray,
                         post_signal: np.ndarray,
    ...
        t_new = _np.asarray(t_new, dtype=float).ravel()

        if x_t.size < 2 or y.size != x_t.size:
            if y.size == 0 or t_new.size == 0:
                return _np.zeros_like(t_new)
            idx = _np.linspace(0, y.size - 1, num=t_new.size)
            return _np.interp(_np.arange(t_new.size), idx, y)

        order = _np.argsort(x_t)
        x_t = x_t[order]
        y = y[order]
        mask = _np.concatenate(([True], _np.diff(x_t) > 0))
        x_t = x_t[mask]
        y = y[mask]
        t_clip = _np.clip(t_new, x_t[0], x_t[-1])
        return _np.interp(t_clip, x_t, y)

    ...
        n = int(min(len(x), len(y)))
        x = x[:n]; y = y[:n]
        max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
        lags = _np.arange(-max_lag, max_lag + 1)
        best_corr = -_np.inf
        best_lag = 0
        for L in lags:
            ...

    ...
        out = _np.empty_like(y)
        out[:] = _np.nan
        if shift > 0:
            out[shift:] = y[:-shift]
        else:
            out[:shift] = y[-shift:]
        return out

    ...
    raw = _np.asarray(raw_signal, dtype=float)
    post = _np.asarray(post_signal, dtype=float)

    if t.size == 0:
        t = _np.arange(post.size, dtype=float) / max(fs, 1)

    ...
        gt_t = _np.asarray(gt_time, dtype=float).ravel()
        gt_on_pred = _safe_interp(gt_t, gt, t)
    else:
+       gt_on_pred = _safe_interp(
+           _np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
+           gt, t
+       )

    pred_bp = _bandpass(post, fs)
    gt_bp = _bandpass(gt_on_pred, fs)

    lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
    gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)

    valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
    if valid.sum() >= 16:
        pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
    ...

    hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)

    _plt.figure(figsize=(13, 6), dpi=110)
    gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)

    ax1 = _plt.subplot(gs[0, 0])
    ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
    ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
    ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
    ax1.grid(True, alpha=0.3)

    ax2 = _plt.subplot(gs[0, 1])
    ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
    ax2.set_title("Predicted (Post-processed)")
    ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
    ax2.grid(True, alpha=0.3)

    ax3 = _plt.subplot(gs[1, :])
    ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
    ...
        gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
        ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)

    txt = [
        f"HR_pred: {hr_pred:.1f} BPM",
        f"HR_gt: {hr_gt:.1f} BPM",
    ...
    if frame is None or frame.size == 0:
        return None

+   # If cascade is missing, fail fast (prevents cryptic OpenCV errors)
+   if FACE_CASCADE is None:
+       return None
+
    try:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    except Exception:
        # If color conversion fails, assume already gray
+       if frame.ndim == 2:
+           gray = frame.copy()
+       else:
+           gray = cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)

    # Light preproc to improve Haar performance
    gray = cv2.equalizeHist(gray)

+   faces_all: List[Tuple[int, int, int, int]] = []
    # Try a couple of parameter combos to be more forgiving
    params = [
        dict(scaleFactor=1.05, minNeighbors=3),
    ...
    # Return the largest (by area)
    return max(faces_all, key=lambda f: f[2] * f[3])

+
def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
    """
    Crop ROI from the frame based on a face bbox and the selected roi_type.
    ...
        return None
    return roi

+
+def crop_roi_with_bbox(
+    face_bbox: Tuple[int, int, int, int],
+    roi_type: str,
+    frame: np.ndarray
+) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
+    """
+    Same as crop_roi, but also returns the ROI bbox (x, y, w, h) in frame coords.
+    """
    if face_bbox is None or frame is None or frame.size == 0:
        return None, None
    ...
    return roi, (rx, ry, rx2 - rx, ry2 - ry)

+
def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
    """
    PhysMamba-compatible normalization with DiffNormalized support.
    ...
    chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
    return chw

+
+def extract_attention_map(model, clip_tensor: torch.Tensor, attention_viz) -> Optional[np.ndarray]:
    """Attention visualization disabled - model architecture incompatible."""
    return None

+
+def create_attention_overlay(
+    frame: np.ndarray,
+    attention: Optional[np.ndarray],
+    attention_viz: Optional[SimpleActivationAttention] = None
+) -> np.ndarray:
+    """Create attention heatmap overlay (currently passthrough)."""
    return frame

+
def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
    H, W = roi_bgr.shape[:2]
+   base_bvp = forward_bvp(
+       model,
+       torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
+       .unsqueeze(0).unsqueeze(2).to(DEVICE)  # fake T=1 path if needed
+   )
    base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)

    heat = np.zeros((H, W), np.float32)
    ...
        for x in range(0, W - patch + 1, stride):
            tmp = roi_bgr.copy()
            tmp[y:y+patch, x:x+patch] = 127  # occlude
+           bvp = forward_bvp(
+               model,
+               torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
+               .unsqueeze(0).unsqueeze(2).to(DEVICE)
+           )
            power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
            drop = max(0.0, base_power - power)
            heat[y:y+patch, x:x+patch] += drop
    heat -= heat.min()
+   if heat.max() > 1e-8:
+       heat /= heat.max()
    return heat

+
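# Illustrative sketch only (not called by the app): one way to visualize the normalized
# occlusion-saliency map returned above on top of the ROI it was computed from.
def _sketch_overlay_saliency(roi_bgr: np.ndarray, heat: np.ndarray) -> np.ndarray:
    heat_u8 = (np.clip(heat, 0.0, 1.0) * 255).astype(np.uint8)
    heat_color = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)   # HxW -> HxWx3 BGR heatmap
    return cv2.addWeighted(roi_bgr, 0.6, heat_color, 0.4, 0.0)  # alpha-blend over the ROI
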
def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
    """
    Try common 5D layouts:
    ...
            last_err = e
    raise last_err

+
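# Illustrative sketch only (not called by the app): the idea behind the layout fallback —
# some rPPG backbones expect (B, C, T, H, W), others (B, T, C, H, W); swapping dims 1 and 2
# covers both. The real helper above additionally remembers the last error before re-raising.
def _sketch_try_orders(model: nn.Module, clip: torch.Tensor) -> torch.Tensor:
    try:
        return model(clip)                           # assume (B, C, T, H, W) first
    except Exception:
        return model(clip.permute(0, 2, 1, 3, 4))    # retry as (B, T, C, H, W)
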
def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
    """
    Forward and extract a 1D time-like BVP vector with length T_clip.
    ...
        B, K = out.shape
        if B == 1:
            v = out[0]
+           return (v.numpy() if v.shape[0] == T_clip
+                   else np.resize(v.numpy(), T_clip).astype(np.float32))
        if B == T_clip:
            return out[:, 0].numpy()
        if K == T_clip:
            ...
    val = float(out.mean().item()) if out.numel() else 0.0
    return np.full(T_clip, val, dtype=np.float32)

+
def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
    """
    Classical rPPG from green-channel means when the model yields nothing.
    ...
    std = float(np.std(y)) + 1e-6
    return (y / std).astype(np.float32)

+
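# Illustrative sketch only (not called by the app): the classical green-channel fallback in
# its simplest form — detrend the per-frame green means, band-pass to the HR band, and
# z-normalize. The elided body above presumably follows a similar recipe.
def _sketch_green_fallback(means: np.ndarray, fs: int = 30) -> np.ndarray:
    y = np.asarray(means, dtype=float)
    y = y - np.convolve(y, np.ones(fs) / fs, mode="same")   # remove slow illumination drift
    y = bandpass_filter(y, fs=fs)                           # keep ~0.7–3.5 Hz
    return (y / (np.std(y) + 1e-6)).astype(np.float32)
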
def _to_floats(s: str) -> List[float]:
    """
    Extract all real numbers from free-form text, including scientific notation.
+   Gracefully ignores comments, units, and non-numeric junk.
    """
    if not isinstance(s, str) or not s:
        return []

+   # Strip comments starting with #, //, or ;
    s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)

+   # Normalize common delimiters
    s = s.replace(",", " ").replace(";", " ")

+   toks = re.findall(
+       r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?",
+       s
+   )
    out: List[float] = []
    for t in toks:
        try:
            ...
        diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
        return (1.0 / float(np.median(diffs))) if diffs.size else 0.0

    def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
        if bvp is None or bvp.size == 0:
            return 0.0
        ...
        bp = bandpass_filter(bvp.astype(float), fs=fs_use)
        return hr_from_welch(bp, fs=fs_use)

+   # ================= UBFC-style TXT (3 lines) =================
+   if p.name.lower() == "ground_truth.txt" or (
+       ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2
+   ):
        try:
+           lines = [
+               ln.strip()
+               for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines()
+               if ln.strip()
+           ]
            ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
            hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
            t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
            ...
            # Fall through to generic handlers
            pass

+   # ================= Generic TXT =================
    if ext == ".txt":
        try:
            nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
            ...
        except Exception:
            return np.array([]), 0.0, 0.0

+   # ================= JSON =================
    if ext == ".json":
        try:
            data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))

            def _seek(obj, keys):
                for k in keys:
                    ...
                        return obj[k]
                return None

            bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
            if bvp is None:
                for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
                    if container_key in data:
                        ...
            ...
            else:
                bvp = np.array([], dtype=float)

            fs_hint = 0.0
            if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
                fs_hint = float(data["fs"])
            ...
        except Exception:
            return np.array([]), 0.0, 0.0

+   # ================= CSV =================
    if ext == ".csv":
        try:
            df = pd.read_csv(p)
            cols = {str(c).strip().lower(): c for c in df.columns}

            def _first_match(names):
                ...
        except Exception:
            return np.array([]), 0.0, 0.0

+   # ================= MAT =================
    if ext == ".mat":
        try:
            md = loadmat(str(p))
            arr = None
            for key in ("ppg", "bvp", "signal", "wave"):
                if key in md and isinstance(md[key], np.ndarray):
                    arr = md[key]
                    break
            if arr is None:
                for v in md.values():
                    if isinstance(v, np.ndarray) and v.ndim == 1:
                        arr = v
            ...

        try:
            bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
            fs_hint, hr = 0.0, 0.0
+
            sidecar = p.with_suffix(".json")
            if sidecar.exists():
                try:
                    ...
                    hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
                except Exception:
                    pass
+
            if hr == 0.0 and bvp.size:
                hr = _hr_from_bvp(bvp, fs_hint)
            return bvp, hr, fs_hint
    ...
    # Fallback (unsupported extension)
    return np.array([]), 0.0, 0.0

+
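# Illustrative sketch only (not called by the app): the UBFC-style ground_truth.txt handled
# above is assumed to hold three lines — PPG samples, HR values (BPM), timestamps (s) —
# and _to_floats tolerates units/comments on each line. Synthetic values for a quick test:
def _sketch_parse_ubfc_gt(tmp_dir: Path) -> Tuple[np.ndarray, float, float]:
    gt = tmp_dir / "ground_truth.txt"
    gt.write_text(
        "0.01 0.05 0.11 0.07 0.02\n"    # line 1: PPG waveform samples
        "71.8 72.2 72.0\n"              # line 2: HR values (BPM)
        "0.00 0.03 0.07 0.10 0.13\n"    # line 3: timestamps (s)
    )
    return parse_ground_truth_file(str(gt))
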
def scan_models() -> List[str]:
    if not MODEL_DIR.exists():
        return []
    ...

    return models

+
_GLOBAL_CONTROLS: Dict[str, Dict] = {}

+
def ensure_controls(control_id: str) -> Tuple[str, Dict]:
    # Use a stable default so Pause/Resume/Stop work for the current run
    if not control_id:
        ...
        }
    return control_id, _GLOBAL_CONTROLS[control_id]

+
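# Illustrative sketch only (not called by the app): the controls dict consumed below is used
# as event flags (controls['stop'].is_set(), .clear(), ...), so a minimal registration with
# threading.Event flags would look like this (field names beyond 'stop' are assumptions).
import threading

def _sketch_register_controls(control_id: str) -> Dict:
    return _GLOBAL_CONTROLS.setdefault(control_id, {
        "stop": threading.Event(),    # set -> abort the processing loop
        "pause": threading.Event(),   # set -> temporarily hold processing
    })
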
def process_video_file(
    video_path: str,
    gt_file: Optional[str],
    ...
    """
    Enhanced video processing with Grad-CAM attention visualization,
    plus per-frame illumination and motion logging.
+   Returns 14 outputs per yield, matching Gradio UI:
+       status, pred_hr, gt_hr, rmssd,
+       frame_path, attention_path, signal_path, raw_path, post_path, csv_path,
+       global_brightness, global_motion, roi_brightness, roi_motion
    """
    global _HR_SMOOTH
+   global FRAME_METRICS  # global frame metrics buffer
    _HR_SMOOTH = None
    FRAME_METRICS = []  # reset per run

    ...
    control_id, controls = ensure_controls(control_id)
    controls['stop'].clear()

+   # Helper for consistent error yields (14 outputs)
+   def _error_status(msg: str):
+       return (
+           msg,
+           None, None, None,  # HR, GT HR, RMSSD
+           None, None, None,  # frame, attention, signal
+           None, None,  # raw, post
+           None,  # csv
+           None, None, None, None  # brightness & motion
+       )
+
    if not model_name:
+       yield _error_status("ERROR: No model selected")
        return

    if isinstance(model_name, int):
        ...

    model_path = MODEL_DIR / model_name
    if not model_path.exists():
+       yield _error_status("ERROR: Model not found")
        return

    try:
        model, attention_viz = load_physmamba_model(model_path, DEVICE)
    except Exception as e:
+       yield _error_status(f"ERROR loading model: {str(e)}")
        return

    gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)

    if not video_path or not os.path.exists(video_path):
+       yield _error_status("ERROR: Video not found")
        return

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
+       yield _error_status("ERROR: Cannot open video")
        return

    fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
    ...
    last_rmssd = 0.0
    last_attention = None

+   # previous grayscale frames for motion
    prev_gray = None
    prev_roi_gray = None
    ...
    raw_path = tmpdir / "raw_signal.png"
    post_path = tmpdir / "post_signal.png"

+   # Initial status yield (14 outputs)
+   yield (
+       "Starting… reading video frames",
+       None,
+       f"{gt_hr:.1f}" if gt_hr > 0 else "--",
+       None,
+       None, None, None, None, None, None,
+       None, None, None, None
+   )

    while True:
        if controls['stop'].is_set():
            ...

        frame_idx += 1

+       # default per-frame metrics for this iteration
        global_brightness = None
        global_motion = None
        roi_brightness = None
        roi_motion = None

+       # global illumination & motion (full frame)
        try:
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        except Exception:
            ...

        ...
            roi = crop_roi(face, roi_type, frame)

            if roi is not None and roi.size > 0:
+               # ROI-level brightness & motion
                try:
                    roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
                except Exception:
                    ...

                ...
                    try:
                        raw = forward_bvp(model, clip_t)
                        if isinstance(raw, np.ndarray):
+                           raw = np.nan_to_num(
+                               raw, nan=0.0, posinf=0.0, neginf=0.0
+                           ).astype(np.float32, copy=False)
                            bvp_out = raw if raw.size > 0 else None
                        else:
                            bvp_out = None
                    ...
                        print(f"[infer] forward_bvp error: {e}")
                        bvp_out = None

+                   # Generate attention with Grad-CAM
                    try:
                        last_attention = extract_attention_map(model, clip_t, attention_viz)
                    except Exception as e:
                        ...
                        last_attention = None

                    if bvp_out is None or bvp_out.size == 0:
+                       gbuf = np.nan_to_num(
+                           np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0
+                       )
                        fb = _fallback_bvp_from_means(gbuf, fs=fps)
                        if isinstance(fb, np.ndarray) and fb.size > 0:
                            bvp_out = fb
                    ...
                        bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]

                        if len(bvp_stream) >= int(5 * fps):
+                           seg = np.asarray(
+                               bvp_stream[-int(10 * fps):], dtype=np.float32
+                           )
                            _, last_bpm = postprocess_bvp(seg, fs=fps)
                            last_rmssd = compute_rmssd(seg, fs=fps)
                        ...
                        last_infer = frame_idx

        else:
+           cv2.putText(
+               vis_frame,
+               "No face detected",
+               (20, 40),
+               cv2.FONT_HERSHEY_SIMPLEX,
+               0.7,
+               (30, 200, 255),
+               2,
+           )

+       # Log per-frame metrics into global FRAME_METRICS
        try:
            time_s = frame_idx / float(fps) if fps > 0 else float(frame_idx)
        except Exception:
            ...
        ...
            "roi_motion": float(roi_motion) if roi_motion is not None else None,
        })

+       # Pretty strings for UI
+       gb_str = f"{global_brightness:.2f}" if global_brightness is not None else None
+       gm_str = f"{global_motion:.2f}" if global_motion is not None else None
+       rb_str = f"{roi_brightness:.2f}" if roi_brightness is not None else None
+       rm_str = f"{roi_motion:.2f}" if roi_motion is not None else None
+
        if last_bpm > 0:
            color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
            cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
+           cv2.putText(
+               vis_frame,
+               f"HR: {last_bpm:.1f} BPM",
+               (20, 48),
+               cv2.FONT_HERSHEY_SIMPLEX,
+               0.9,
+               color,
+               2,
+           )

        vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
        ...
            elapsed = now - start_time
            status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"

+           # Periodic UI update: 14 outputs
            yield (
                status,
                f"{last_bpm:.1f}" if last_bpm > 0 else None,
                ...
                str(signal_path) if signal_path.exists() else None,
                str(raw_path) if raw_path.exists() else None,
                str(post_path) if post_path.exists() else None,
+               None,  # CSV placeholder (filled at end)
+               gb_str,  # global brightness
+               gm_str,  # global motion
+               rb_str,  # ROI brightness
+               rm_str  # ROI motion
            )

            next_display = now + (1.0 / DISPLAY_FPS)
    ...
    except Exception:
        pass

+   # Save per-frame illumination & motion metrics
    frame_metrics_path = None
    if FRAME_METRICS:
        try:
            ...
            print(f"[metrics] Failed to save frame metrics CSV: {e}")
            frame_metrics_path = None

+   # Decide which CSV to expose: metrics preferred, else BVP
+   final_csv = frame_metrics_path or csv_path
+
    elapsed = time.time() - start_time
    final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"

+   # Final yield: 14 outputs
    yield (
        final_status,
        f"{last_bpm:.1f}" if last_bpm > 0 else None,
        ...
        str(signal_path) if signal_path.exists() else None,
        str(raw_path) if raw_path.exists() else None,
        str(post_path) if post_path.exists() else None,
+       str(final_csv) if final_csv else None,
+       None,  # final global brightness (leave None or reuse gb_str)
+       None,  # final global motion
+       None,  # final ROI brightness
+       None  # final ROI motion
    )

+
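# Illustrative sketch only (not called by the app): process_video_file is a generator that
# yields 14-tuples; outside Gradio it can be drained like this. The argument values here are
# placeholders — the UI normally supplies them via run_processing further below.
def _sketch_drain_process_video(video: str, model: str):
    last = None
    for outputs in process_video_file(video, None, model, 30, 30, "auto", ""):
        last = outputs   # each item is the 14-tuple documented in the docstring above
    return last
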
def process_stream(
    input_source: str,
    video_path: Optional[str],
    ...
    if input_source == "Live Webcam":
        yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
    else:
+       yield from process_video_file(
+           video_path, gt_file, model_name, fps_input,
+           max_seconds, roi_type, control_id
+       )

def pause_processing(control_id: str) -> str:
    _, controls = ensure_controls(control_id)
...
    return "Stopped"

def reset_ui():
+   # 14 values in total, matching all outputs from run_processing / process_stream
+   return (
+       "Ready",  # status_text
+       None,  # hr_output
+       None,  # gt_hr_output
+       None,  # rmssd_output
+       None,  # frame_output
+       None,  # attention_output
+       None,  # signal_output
+       None,  # raw_signal_output
+       None,  # post_signal_output
+       None,  # csv_output
+       None,  # global_brightness_output
+       None,  # global_motion_output
+       None,  # roi_brightness_output
+       None  # roi_motion_output
+   )

def handle_folder_upload(files):
    if not files:
...
        with gr.Column():
            video_upload = gr.Video(label="Upload Video", sources=["upload"])
        with gr.Column():
+           gt_upload = gr.File(
+               label="Ground Truth (optional)",
+               file_types=[".txt", ".csv", ".json"]
+           )

    with gr.Row(visible=False) as folder_inputs:
        with gr.Column():
        ...
            )

    with gr.Row():
+       roi_dropdown = gr.Dropdown(
+           choices=["auto", "forehead", "cheeks", "face"],
+           value="auto",
+           label="ROI"
+       )

    control_state = gr.State(value="")
    placeholder_state = gr.State(value=None)
    ...
        hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
        gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
        rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
+
+   # NEW: illumination & motion front-end outputs
+   with gr.Row():
+       global_brightness_output = gr.Textbox(
+           label="Global Brightness", interactive=False
+       )
+       global_motion_output = gr.Textbox(
+           label="Global Motion", interactive=False
+       )
+       roi_brightness_output = gr.Textbox(
+           label="ROI Brightness", interactive=False
+       )
+       roi_motion_output = gr.Textbox(
+           label="ROI Motion", interactive=False
+       )

    with gr.Row():
        with gr.Column():
        ...
    ).then(
        reset_ui,
        inputs=None,
+       outputs=[
+           status_text,
+           hr_output,
+           gt_hr_output,
+           rmssd_output,
+           frame_output,
+           attention_output,
+           signal_output,
+           raw_signal_output,
+           post_signal_output,
+           csv_output,
+           global_brightness_output,
+           global_motion_output,
+           roi_brightness_output,
+           roi_motion_output,
+       ]
    )

+   def run_processing(
+       input_source,
+       video_upload,
+       gt_upload,
+       folder_video,
+       folder_gt,
+       model_name,
+       fps,
+       max_sec,
+       roi,
+       ctrl_id
+   ):
+       """Wrapper that resolves paths and streams from process_stream."""

        if isinstance(model_name, int):
            model_name = str(model_name)

        if not model_name:
+           # must return 14 outputs to match UI wiring
+           yield (
+               "ERROR: No model selected",
+               None, None, None,  # HR, GT HR, RMSSD
+               None, None, None,  # frame, attention, signal
+               None, None,  # raw, post
+               None,  # CSV
+               None, None, None, None  # brightness & motion fields
+           )
            return

        if input_source == "Video File":
        ...
        else:  # Live Webcam
            video_path, gt_file = None, None

+       # This yields 14-tuples from process_video_file / process_live_webcam
        yield from process_stream(
            input_source, video_path, gt_file,
            model_name, fps, max_sec, roi, ctrl_id
        )

    run_btn.click(
        fn=run_processing,
        inputs=[
        ...
            signal_output,
            raw_signal_output,
            post_signal_output,
+           csv_output,
+           global_brightness_output,
+           global_motion_output,
+           roi_brightness_output,
+           roi_motion_output,
        ]
    )