swetchareddytukkani committed on
Commit e8fb09f · verified · 1 Parent(s): 109d0cc

Update app.py

Files changed (1)
  1. app.py +431 -200
app.py CHANGED
@@ -24,7 +24,6 @@ import torch
24
  import torch.nn as nn
25
  import torch.nn.functional as F
26
 
27
-
28
  from scipy import signal
29
  from scipy.signal import find_peaks, welch, get_window
30
 
@@ -43,7 +42,8 @@ from matplotlib.gridspec import GridSpec
43
 
44
  import gradio as gr
45
 
46
- FRAME_METRICS = []
 
47
 
48
  class PhysMambaattention_viz:
49
  """Simplified Grad-CAM for PhysMamba."""
@@ -228,55 +228,86 @@ class PhysMambaattention_viz:
228
 
229
 
230
  def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
231
- """Apply DiffNormalized preprocessing from PhysMamba paper."""
 
 
 
 
 
 
 
 
 
 
 
232
  if len(frames) < 2:
233
- return np.zeros((len(frames), *frames[0].shape), dtype=np.float32)
234
-
 
235
  diff_frames = []
236
-
237
  for i in range(len(frames)):
238
  if i == 0:
239
  diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
240
  else:
241
  curr = frames[i].astype(np.float32)
242
- prev = frames[i-1].astype(np.float32)
243
- denominator = curr + prev + 1e-8
244
- diff = (curr - prev) / denominator
245
  diff_frames.append(diff)
246
-
247
- diff_array = np.stack(diff_frames)
248
- std = diff_array.std()
249
- if std > 0:
250
- diff_array = diff_array / std
251
-
252
  return diff_array
253
 
254
 
255
- def preprocess_for_physmamba(frames: List[np.ndarray],
256
- target_frames: int = 128,
257
- target_size: int = 128) -> torch.Tensor:
258
- """Complete preprocessing pipeline for PhysMamba model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  if len(frames) < target_frames:
260
  frames = frames + [frames[-1]] * (target_frames - len(frames))
261
  elif len(frames) > target_frames:
262
- indices = np.linspace(0, len(frames)-1, target_frames).astype(int)
263
- frames = [frames[i] for i in indices]
264
-
265
- frames_rgb = [f[..., ::-1].copy() for f in frames]
 
266
  frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
267
- frames_diff = apply_diff_normalized(frames_resized)
268
- frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2))
269
- frames_batched = np.expand_dims(frames_transposed, axis=0)
270
- tensor = torch.from_numpy(frames_batched.astype(np.float32))
271
-
272
- return tensor
273
 
274
- HERE = Path(__file__).parent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  MODEL_DIR = HERE / "final_model_release"
276
  LOG_DIR = HERE / "logs"
277
  ANALYSIS_DIR = HERE / "analysis"
278
- for d in [MODEL_DIR, LOG_DIR, ANALYSIS_DIR]:
279
- d.mkdir(exist_ok=True)
280
 
281
  DEVICE = (
282
  torch.device("cuda") if torch.cuda.is_available()
@@ -284,7 +315,9 @@ DEVICE = (
284
  else torch.device("cpu")
285
  )
286
 
287
- FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
 
 
288
 
289
  DEFAULT_SIZE = 128 # input H=W to model
290
  DEFAULT_T = 128 # clip length
@@ -302,22 +335,52 @@ MAX_JUMP_BPM = 8.0
302
  GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
303
  GT_EXTS = {".txt", ".csv", ".json"}
304
 
 
305
  def _as_path(maybe) -> Optional[str]:
306
- """Return a filesystem path from Gradio values (str, dict, tempfile objects)."""
 
 
 
 
 
 
 
 
 
307
  if maybe is None:
308
  return None
 
 
 
 
 
 
 
309
  if isinstance(maybe, str):
310
  return maybe
 
 
 
 
 
311
  if isinstance(maybe, dict):
312
- return maybe.get("name") or maybe.get("path")
313
- name = getattr(maybe, "name", None) # tempfile-like object
314
- if isinstance(name, str):
 
 
 
 
 
 
315
  return name
 
316
  try:
317
  return str(maybe)
318
  except Exception:
319
  return None
320
 
 
321
  def _import_from_file(py_path: Path):
322
  spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
323
  if not spec or not spec.loader:
@@ -326,27 +389,35 @@ def _import_from_file(py_path: Path):
326
  spec.loader.exec_module(mod)
327
  return mod
328
 
 
329
  def _looks_like_video(p: Path) -> bool:
 
 
 
 
330
  if p.suffix.lower() == ".mat":
331
  return True
332
  return p.suffix.lower() in VIDEO_EXTENSIONS
333
 
 
334
  class SimpleActivationAttention:
335
- """Lightweight attention visualization without gradients."""
336
 
337
  def __init__(self, model: nn.Module, device: torch.device):
338
  self.model = model
339
  self.device = device
340
- self.activations = None
341
- self.hook_handle = None
342
 
343
  def _activation_hook(self, module, input, output):
344
  """Capture activations during forward pass."""
345
- self.activations = output.detach()
 
 
 
346
 
347
  def register_hook(self):
348
- """Register hook on a suitable layer."""
349
- # Find the last convolutional layer before Mamba
350
  target = None
351
  target_name = None
352
 
@@ -354,7 +425,7 @@ class SimpleActivationAttention:
354
  if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
355
  target = module
356
  target_name = name
357
-
358
  if target is None:
359
  print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
360
  return
@@ -363,30 +434,36 @@ class SimpleActivationAttention:
363
  print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")
364
 
365
  def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
366
- """Generate attention map from activations (call after forward pass)."""
 
 
 
 
367
  try:
368
  if self.activations is None:
369
  return None
370
 
371
- # Process activations to create spatial attention
372
  act = self.activations
373
-
374
  # Handle different tensor shapes
375
  if act.dim() == 5: # [B, C, T, H, W]
376
- # Average over time and channels
377
- attention = act.mean(dim=[1, 2]) # -> [B, H, W]
378
  elif act.dim() == 4: # [B, C, H, W]
379
  attention = act.mean(dim=1) # -> [B, H, W]
380
  else:
381
  print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
382
  return None
383
 
384
- # Convert to numpy
385
- attention = attention.squeeze().cpu().numpy()
386
-
387
  # Normalize to [0, 1]
388
- if attention.max() > attention.min():
389
- attention = (attention - attention.min()) / (attention.max() - attention.min())
 
 
 
390
 
391
  return attention
392
 
@@ -396,16 +473,22 @@ class SimpleActivationAttention:
396
 
397
  def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
398
  """Overlay heatmap on frame."""
 
 
 
399
  h, w = frame.shape[:2]
400
  heatmap_resized = cv2.resize(heatmap, (w, h))
401
  heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
402
  heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
403
- overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
404
  return overlay
405
 
406
  def cleanup(self):
407
  if self.hook_handle is not None:
408
  self.hook_handle.remove()
 
 
 
409
 
410
  class VideoReader:
411
  """
@@ -421,6 +504,7 @@ class VideoReader:
421
  self._idx = 0
422
  self._len = 0
423
  self._shape = None
 
424
 
425
  if self.path.lower().endswith(".mat") and MAT_SUPPORT:
426
  self._open_mat(self.path)
@@ -433,6 +517,7 @@ class VideoReader:
433
  raise RuntimeError("Cannot open video")
434
  self._cap = cap
435
  self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
 
436
 
437
  def _open_mat(self, path: str):
438
  try:
@@ -444,6 +529,7 @@ class VideoReader:
444
  break
445
  else:
446
  arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
 
447
  if arr is None:
448
  raise RuntimeError("No ndarray found in .mat")
449
 
@@ -451,16 +537,18 @@ class VideoReader:
451
  # Normalize to (T,H,W,3)
452
  if a.ndim == 4:
453
  if a.shape[-1] == 3:
454
- if a.shape[0] < a.shape[2]: # (T,H,W,3) heuristic
 
455
  v = a
456
- else: # (H,W,T,3) -> (T,H,W,3)
457
  v = np.transpose(a, (2, 0, 1, 3))
458
  else:
459
- v = a[..., :1] # take first channel
460
  elif a.ndim == 3:
461
- if a.shape[0] < a.shape[2]: # (T,H,W)
 
462
  v = a
463
- else: # (H,W,T) -> (T,H,W)
464
  v = np.transpose(a, (2, 0, 1))
465
  v = v[..., None]
466
  else:
@@ -473,6 +561,7 @@ class VideoReader:
473
  self._mat = v
474
  self._len = v.shape[0]
475
  self._shape = v.shape[1:3]
 
476
  except Exception as e:
477
  raise RuntimeError(f"Failed to open .mat video: {e}")
478
 
@@ -489,9 +578,11 @@ class VideoReader:
489
 
490
  def fps(self, fallback: int = 30) -> int:
491
  if self._mat is not None:
492
- return fallback # .mat typically lacks FPS; caller can override
 
 
493
  f = self._cap.get(cv2.CAP_PROP_FPS)
494
- return int(f) if f and f > 0 else fallback
495
 
496
  def length(self) -> int:
497
  return self._len
@@ -500,6 +591,7 @@ class VideoReader:
500
  if self._cap is not None:
501
  self._cap.release()
502
 
 
503
  def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
504
  x, y, w, h = face
505
  # forehead
@@ -510,23 +602,25 @@ def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[s
510
  ff = frame[y:y + h, x:x + w]
511
  return {"forehead": fh, "cheeks": ck, "face": ff}
512
 
 
513
  def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
514
  if patch is None or patch.size == 0:
515
  return -1e9
516
  g = patch[..., 1].astype(np.float32) / 255.0 # green channel
517
  g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling
518
  g = g - g.mean()
519
- b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
520
  try:
 
521
  y = signal.filtfilt(b, a, g, method="gust")
522
  except Exception:
523
  y = g
524
  return float((y ** 2).mean())
525
 
 
526
  def pick_auto_roi(face: Tuple[int, int, int, int],
527
- frame: np.ndarray,
528
- attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
529
- """Simple ROI selection."""
530
  cands = roi_candidates(face, frame)
531
  scores = {k: roi_quality_score(v) for k, v in cands.items()}
532
 
@@ -539,14 +633,15 @@ def pick_auto_roi(face: Tuple[int, int, int, int],
539
  ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
540
  ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
541
  scores['forehead'] += fh_attn * 0.2
542
- scores['cheeks'] += ck_attn * 0.2
543
- scores['face'] += ff_attn * 0.2
544
  except Exception:
545
  pass
546
 
547
  best = max(scores, key=scores.get)
548
  return cands[best], best
549
 
 
550
  def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
551
  """
552
  Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
@@ -601,6 +696,7 @@ def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
601
  uniq.append((v, g))
602
  return uniq
603
 
 
604
  def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
605
  import inspect
606
 
@@ -639,7 +735,8 @@ def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: s
639
  except Exception:
640
  continue
641
 
642
- raise ImportError(f"Could not find PhysMamba model class")
 
643
 
644
  def load_physmamba_model(ckpt_path: Path, device: torch.device,
645
  model_file: str = "", model_class: str = "PhysMamba"):
@@ -680,6 +777,7 @@ def load_physmamba_model(ckpt_path: Path, device: torch.device,
680
  state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
681
  model.load_state_dict(state_dict, strict=False)
682
  except Exception:
 
683
  pass
684
 
685
  model.to(device).eval()
@@ -688,13 +786,15 @@ def load_physmamba_model(ckpt_path: Path, device: torch.device,
688
  with torch.no_grad():
689
  _ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
690
  except Exception:
 
691
  pass
692
 
693
- # Disable attention visualization since model forward pass is incompatible
694
  attention_viz = None
695
 
696
  return model, attention_viz
697
 
 
698
  def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
699
  """
700
  Stable band-pass with edge-safety and parameter clipping.
@@ -712,12 +812,12 @@ def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float =
712
 
713
  try:
714
  b, a = signal.butter(order, [lo, hi], btype="band")
715
- # padlen must be < len(x); reduce when short
716
  padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
717
  return signal.filtfilt(b, a, x, padlen=padlen)
718
  except Exception:
719
  return x
720
 
 
721
  def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
722
  """
723
  HR (BPM) via Welch PSD peak in [lo, hi] Hz.
@@ -726,7 +826,6 @@ def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5)
726
  if x.size < int(fs * 4.0): # need ~4s for a usable PSD
727
  return 0.0
728
  try:
729
- # nperseg tuned for short windows while avoiding tiny segments
730
  nper = int(min(max(64, fs * 2), min(512, x.size)))
731
  f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
732
 
@@ -741,11 +840,11 @@ def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5)
741
 
742
  fpk = float(f_band[np.argmax(p_band)])
743
  bpm = fpk * 60.0
744
- # clip to plausible range
745
  return float(np.clip(bpm, 30.0, 220.0))
746
  except Exception:
747
  return 0.0
748
 
 
749
  def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
750
  """
751
  HRV RMSSD from peaks; robust to short/flat segments.
@@ -754,7 +853,6 @@ def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
754
  if x.size < int(fs * 5.0):
755
  return 0.0
756
  try:
757
- # peak distance ~ 0.5s minimum (avoid double counting)
758
  peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
759
  if len(peaks) < 3:
760
  return 0.0
@@ -765,6 +863,7 @@ def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
765
  except Exception:
766
  return 0.0
767
 
 
768
  def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
769
  """
770
  Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
@@ -787,7 +886,6 @@ def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
787
  lo, hi = REST_HR_RANGE
788
  if hr < lo or hr > hi:
789
  dist = abs(hr - REST_HR_TARGET)
790
- # farther away -> stronger pull
791
  alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
792
  hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
793
 
@@ -802,6 +900,7 @@ def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
802
 
803
  return y_filt, float(hr)
804
 
 
805
  def draw_face_and_roi(frame_bgr: np.ndarray,
806
  face_bbox: Optional[Tuple[int, int, int, int]],
807
  roi_bbox: Optional[Tuple[int, int, int, int]],
@@ -820,6 +919,7 @@ def draw_face_and_roi(frame_bgr: np.ndarray,
820
  cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
821
  return vis
822
 
 
823
  def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
824
  roi_type: str,
825
  frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
@@ -845,6 +945,7 @@ def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
845
  return (0, 0, 0, 0)
846
  return (rx, ry, rx2 - rx, ry2 - ry)
847
 
 
848
  def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
849
  """
850
  Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
@@ -853,9 +954,7 @@ def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
853
  if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
854
  return np.zeros((128, 128, 3), dtype=np.uint8)
855
 
856
- # Undo channel-first & normalization to a viewable image
857
  img = chw.copy()
858
- # Re-normalize to 0..1 by min-max of the tensor to "show" contrast
859
  vmin, vmax = float(img.min()), float(img.max())
860
  if vmax <= vmin + 1e-6:
861
  img = np.zeros_like(img)
@@ -865,6 +964,7 @@ def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
865
  img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR
866
  return img
867
 
 
868
  def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
869
  if gt_len <= 1:
870
  return None
@@ -872,6 +972,7 @@ def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
872
  return np.arange(gt_len, dtype=float) / float(gt_fs)
873
  return None # will fall back to length-matching overlay
874
 
 
875
  def plot_signals_with_gt(time_axis: np.ndarray,
876
  raw_signal: np.ndarray,
877
  post_signal: np.ndarray,
@@ -916,20 +1017,17 @@ def plot_signals_with_gt(time_axis: np.ndarray,
916
  t_new = _np.asarray(t_new, dtype=float).ravel()
917
 
918
  if x_t.size < 2 or y.size != x_t.size:
919
- # Fallback: length-based resize to t_new length
920
  if y.size == 0 or t_new.size == 0:
921
  return _np.zeros_like(t_new)
922
  idx = _np.linspace(0, y.size - 1, num=t_new.size)
923
  return _np.interp(_np.arange(t_new.size), idx, y)
924
 
925
- # Enforce strictly increasing time (dedup if needed)
926
  order = _np.argsort(x_t)
927
  x_t = x_t[order]
928
  y = y[order]
929
  mask = _np.concatenate(([True], _np.diff(x_t) > 0))
930
  x_t = x_t[mask]
931
  y = y[mask]
932
- # Clip t_new to the valid domain to avoid edge extrapolation artifacts
933
  t_clip = _np.clip(t_new, x_t[0], x_t[-1])
934
  return _np.interp(t_clip, x_t, y)
935
 
@@ -941,9 +1039,7 @@ def plot_signals_with_gt(time_axis: np.ndarray,
941
  n = int(min(len(x), len(y)))
942
  x = x[:n]; y = y[:n]
943
  max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
944
- # valid lags: negative means GT should be shifted left (advance) relative to Pred
945
  lags = _np.arange(-max_lag, max_lag + 1)
946
- # compute correlation for each lag
947
  best_corr = -_np.inf
948
  best_lag = 0
949
  for L in lags:
@@ -975,10 +1071,8 @@ def plot_signals_with_gt(time_axis: np.ndarray,
975
  out = _np.empty_like(y)
976
  out[:] = _np.nan
977
  if shift > 0:
978
- # delay: move content right
979
  out[shift:] = y[:-shift]
980
  else:
981
- # advance: move content left
982
  out[:shift] = y[-shift:]
983
  return out
984
 
@@ -986,7 +1080,6 @@ def plot_signals_with_gt(time_axis: np.ndarray,
986
  raw = _np.asarray(raw_signal, dtype=float)
987
  post = _np.asarray(post_signal, dtype=float)
988
 
989
- # guard
990
  if t.size == 0:
991
  t = _np.arange(post.size, dtype=float) / max(fs, 1)
992
 
@@ -1003,21 +1096,17 @@ def plot_signals_with_gt(time_axis: np.ndarray,
1003
  gt_t = _np.asarray(gt_time, dtype=float).ravel()
1004
  gt_on_pred = _safe_interp(gt_t, gt, t)
1005
  else:
1006
- # No time vector: try length-based mapping to pred time grid
1007
- gt_on_pred = _safe_interp(_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
1008
- gt, t)
 
1009
 
1010
- # Band-limit both before correlation/HR
1011
  pred_bp = _bandpass(post, fs)
1012
  gt_bp = _bandpass(gt_on_pred, fs)
1013
 
1014
- # Estimate best lag (sec) of GT relative to Pred
1015
  lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
1016
-
1017
- # Apply lag to GT for visualization and correlation
1018
  gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
1019
 
1020
- # Compute Pearson r on overlapping valid samples
1021
  valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
1022
  if valid.sum() >= 16:
1023
  pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
@@ -1026,25 +1115,21 @@ def plot_signals_with_gt(time_axis: np.ndarray,
1026
 
1027
  hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
1028
 
1029
-
1030
  _plt.figure(figsize=(13, 6), dpi=110)
1031
  gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
1032
 
1033
- # (1) Raw Pred
1034
  ax1 = _plt.subplot(gs[0, 0])
1035
  ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
1036
  ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
1037
  ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
1038
  ax1.grid(True, alpha=0.3)
1039
 
1040
- # (2) Post Pred
1041
  ax2 = _plt.subplot(gs[0, 1])
1042
  ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
1043
  ax2.set_title("Predicted (Post-processed)")
1044
  ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
1045
  ax2.grid(True, alpha=0.3)
1046
 
1047
- # (3) Overlay Pred vs GT (z-scored) OR just post
1048
  ax3 = _plt.subplot(gs[1, :])
1049
  ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
1050
 
@@ -1053,7 +1138,6 @@ def plot_signals_with_gt(time_axis: np.ndarray,
1053
  gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
1054
  ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
1055
 
1056
- # metrics box
1057
  txt = [
1058
  f"HR_pred: {hr_pred:.1f} BPM",
1059
  f"HR_gt: {hr_gt:.1f} BPM",
@@ -1088,16 +1172,23 @@ def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
1088
  if frame is None or frame.size == 0:
1089
  return None
1090
 
 
 
 
 
1091
  try:
1092
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
1093
  except Exception:
1094
  # If color conversion fails, assume already gray
1095
- gray = frame.copy() if frame.ndim == 2 else cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
 
 
 
1096
 
1097
  # Light preproc to improve Haar performance
1098
  gray = cv2.equalizeHist(gray)
1099
 
1100
- faces_all = []
1101
  # Try a couple of parameter combos to be more forgiving
1102
  params = [
1103
  dict(scaleFactor=1.05, minNeighbors=3),
@@ -1118,6 +1209,7 @@ def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
1118
  # Return the largest (by area)
1119
  return max(faces_all, key=lambda f: f[2] * f[3])
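
A minimal per-frame usage sketch of detect_face together with crop_roi (defined just below); it assumes both helpers from this app.py are in scope and uses a placeholder image path:

```python
# Sketch: Haar-cascade face detection followed by an ROI crop, mirroring one
# iteration of the processing loop. "subject.png" is a hypothetical test image.
import cv2

frame = cv2.imread("subject.png")                              # BGR frame (H, W, 3) or None
face = detect_face(frame) if frame is not None else None       # (x, y, w, h) or None
roi = crop_roi(face, "forehead", frame) if face is not None else None
```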
1120
 
 
1121
  def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
1122
  """
1123
  Crop ROI from the frame based on a face bbox and the selected roi_type.
@@ -1151,9 +1243,15 @@ def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndar
1151
  return None
1152
  return roi
1153
 
1154
- def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
1155
- roi_type: str,
1156
- frame: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
 
 
 
 
 
 
1157
  if face_bbox is None or frame is None or frame.size == 0:
1158
  return None, None
1159
 
@@ -1180,6 +1278,7 @@ def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
1180
 
1181
  return roi, (rx, ry, rx2 - rx, ry2 - ry)
1182
 
 
1183
  def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
1184
  """
1185
  PhysMamba-compatible normalization with DiffNormalized support.
@@ -1204,21 +1303,28 @@ def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
1204
  chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
1205
  return chw
1206
 
1207
- def extract_attention_map(model, clip_tensor: torch.Tensor,
1208
- attention_viz) -> Optional[np.ndarray]:
1209
  """Attention visualization disabled - model architecture incompatible."""
1210
  return None
1211
 
1212
- def create_attention_overlay(frame: np.ndarray, attention: Optional[np.ndarray],
1213
- attention_viz: Optional[SimpleActivationAttention] = None) -> np.ndarray:
1214
- """Create attention heatmap overlay."""
1215
- # Attention disabled - return original frame
 
 
 
1216
  return frame
1217
 
 
1218
  def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
1219
  H, W = roi_bgr.shape[:2]
1220
- base_bvp = forward_bvp(model, torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
1221
- .unsqueeze(0).unsqueeze(2).to(DEVICE)) # fake T=1 path if needed
 
 
 
1222
  base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
1223
 
1224
  heat = np.zeros((H, W), np.float32)
@@ -1226,15 +1332,20 @@ def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
1226
  for x in range(0, W - patch + 1, stride):
1227
  tmp = roi_bgr.copy()
1228
  tmp[y:y+patch, x:x+patch] = 127 # occlude
1229
- bvp = forward_bvp(model, torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
1230
- .unsqueeze(0).unsqueeze(2).to(DEVICE))
 
 
 
1231
  power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
1232
  drop = max(0.0, base_power - power)
1233
  heat[y:y+patch, x:x+patch] += drop
1234
  heat -= heat.min()
1235
- if heat.max() > 1e-8: heat /= heat.max()
 
1236
  return heat
1237
 
 
1238
  def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
1239
  """
1240
  Try common 5D layouts:
@@ -1251,6 +1362,7 @@ def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
1251
  last_err = e
1252
  raise last_err
1253
 
 
1254
  def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
1255
  """
1256
  Forward and extract a 1D time-like BVP vector with length T_clip.
@@ -1297,7 +1409,8 @@ def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
1297
  B, K = out.shape
1298
  if B == 1:
1299
  v = out[0]
1300
- return (v.numpy() if v.shape[0] == T_clip else np.resize(v.numpy(), T_clip).astype(np.float32))
 
1301
  if B == T_clip:
1302
  return out[:, 0].numpy()
1303
  if K == T_clip:
@@ -1342,6 +1455,7 @@ def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
1342
  val = float(out.mean().item()) if out.numel() else 0.0
1343
  return np.full(T_clip, val, dtype=np.float32)
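
A shape-level sketch of the forward_bvp contract described in the docstring above; it assumes a loaded PhysMamba `model`, `DEVICE`, and forward_bvp from this app.py, and the all-zeros clip is only there to illustrate shapes:

```python
import torch

clip = torch.zeros(1, 3, 128, 128, 128, device=DEVICE)   # [B, C, T, H, W]
with torch.no_grad():
    bvp = forward_bvp(model, clip)                        # 1-D np.ndarray
print(bvp.shape)                                          # expected (128,) per the contract above
```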
1344
 
 
1345
  def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
1346
  """
1347
  Classical rPPG from green-channel means when the model yields nothing.
@@ -1366,19 +1480,25 @@ def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
1366
  std = float(np.std(y)) + 1e-6
1367
  return (y / std).astype(np.float32)
1368
 
 
1369
  def _to_floats(s: str) -> List[float]:
1370
  """
1371
  Extract all real numbers from free-form text, including scientific notation.
1372
- Gracefully ignores 'nan', 'inf', units, and comments.
1373
  """
1374
  if not isinstance(s, str) or not s:
1375
  return []
1376
 
 
1377
  s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
1378
 
 
1379
  s = s.replace(",", " ").replace(";", " ")
1380
 
1381
- toks = re.findall(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", s)
 
 
 
1382
  out: List[float] = []
1383
  for t in toks:
1384
  try:
@@ -1418,7 +1538,6 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1418
  diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
1419
  return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
1420
 
1421
-
1422
  def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
1423
  if bvp is None or bvp.size == 0:
1424
  return 0.0
@@ -1426,9 +1545,16 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1426
  bp = bandpass_filter(bvp.astype(float), fs=fs_use)
1427
  return hr_from_welch(bp, fs=fs_use)
1428
 
1429
- if p.name.lower() == "ground_truth.txt" or (ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2):
 
 
 
1430
  try:
1431
- lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines() if ln.strip()]
 
 
 
 
1432
  ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
1433
  hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
1434
  t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
@@ -1452,7 +1578,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1452
  # Fall through to generic handlers
1453
  pass
1454
 
1455
-
1456
  if ext == ".txt":
1457
  try:
1458
  nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
@@ -1462,12 +1588,10 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1462
  except Exception:
1463
  return np.array([]), 0.0, 0.0
1464
 
1465
-
1466
  if ext == ".json":
1467
  try:
1468
  data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
1469
- # Try several paths for BVP array
1470
- bvp = None
1471
 
1472
  def _seek(obj, keys):
1473
  for k in keys:
@@ -1475,9 +1599,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1475
  return obj[k]
1476
  return None
1477
 
1478
- # Direct top-level
1479
  bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
1480
- # Common nested containers
1481
  if bvp is None:
1482
  for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
1483
  if container_key in data:
@@ -1491,7 +1613,6 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1491
  else:
1492
  bvp = np.array([], dtype=float)
1493
 
1494
- # fs / hr (accept scalar or array)
1495
  fs_hint = 0.0
1496
  if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
1497
  fs_hint = float(data["fs"])
@@ -1507,11 +1628,10 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1507
  except Exception:
1508
  return np.array([]), 0.0, 0.0
1509
 
1510
-
1511
  if ext == ".csv":
1512
  try:
1513
  df = pd.read_csv(p)
1514
- # Normalize column names
1515
  cols = {str(c).strip().lower(): c for c in df.columns}
1516
 
1517
  def _first_match(names):
@@ -1537,18 +1657,16 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1537
  except Exception:
1538
  return np.array([]), 0.0, 0.0
1539
 
1540
-
1541
  if ext == ".mat":
1542
  try:
1543
  md = loadmat(str(p))
1544
- # look for most likely array
1545
  arr = None
1546
  for key in ("ppg", "bvp", "signal", "wave"):
1547
  if key in md and isinstance(md[key], np.ndarray):
1548
  arr = md[key]
1549
  break
1550
  if arr is None:
1551
- # fallback: first 1-D array
1552
  for v in md.values():
1553
  if isinstance(v, np.ndarray) and v.ndim == 1:
1554
  arr = v
@@ -1582,7 +1700,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1582
  try:
1583
  bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
1584
  fs_hint, hr = 0.0, 0.0
1585
- # optional sidecar JSON (same stem) with fs/hr
1586
  sidecar = p.with_suffix(".json")
1587
  if sidecar.exists():
1588
  try:
@@ -1594,6 +1712,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1594
  hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
1595
  except Exception:
1596
  pass
 
1597
  if hr == 0.0 and bvp.size:
1598
  hr = _hr_from_bvp(bvp, fs_hint)
1599
  return bvp, hr, fs_hint
@@ -1603,6 +1722,7 @@ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1603
  # Fallback (unsupported extension)
1604
  return np.array([]), 0.0, 0.0
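
As a concrete illustration of the three-line ground_truth.txt layout handled above (line 1 = PPG samples, line 2 = HR values, line 3 = timestamps), here is a small synthetic round-trip; the numbers are made up and parse_ground_truth_file from this app.py is assumed to be in scope:

```python
from pathlib import Path

gt = Path("ground_truth.txt")
gt.write_text(
    "0.01 0.05 0.11 0.08 0.02 -0.03\n"        # line 1: PPG wave samples
    "72.0 72.5 71.8\n"                        # line 2: HR readings (BPM)
    "0.000 0.033 0.067 0.100 0.133 0.167\n"   # line 3: timestamps (s)
)
bvp, hr, fs_hint = parse_ground_truth_file(str(gt))
print(bvp.size, hr, fs_hint)   # 6 PPG samples; HR taken from line 2; fs inferred from timestamps (~30 Hz)
```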
1605
 
 
1606
  def scan_models() -> List[str]:
1607
  if not MODEL_DIR.exists():
1608
  return []
@@ -1614,8 +1734,10 @@ def scan_models() -> List[str]:
1614
 
1615
  return models
1616
 
 
1617
  _GLOBAL_CONTROLS: Dict[str, Dict] = {}
1618
 
 
1619
  def ensure_controls(control_id: str) -> Tuple[str, Dict]:
1620
  # Use a stable default so Pause/Resume/Stop work for the current run
1621
  if not control_id:
@@ -1627,6 +1749,7 @@ def ensure_controls(control_id: str) -> Tuple[str, Dict]:
1627
  }
1628
  return control_id, _GLOBAL_CONTROLS[control_id]
1629
 
 
1630
  def process_video_file(
1631
  video_path: str,
1632
  gt_file: Optional[str],
@@ -1639,9 +1762,13 @@ def process_video_file(
1639
  """
1640
  Enhanced video processing with Grad-CAM attention visualization,
1641
  plus per-frame illumination and motion logging.
 
 
 
 
1642
  """
1643
  global _HR_SMOOTH
1644
- global FRAME_METRICS # NEW: use global frame metrics buffer
1645
  _HR_SMOOTH = None
1646
  FRAME_METRICS = [] # reset per run
1647
 
@@ -1655,8 +1782,19 @@ def process_video_file(
1655
  control_id, controls = ensure_controls(control_id)
1656
  controls['stop'].clear()
1657
 
 
 
 
 
 
 
 
 
 
 
 
1658
  if not model_name:
1659
- yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
1660
  return
1661
 
1662
  if isinstance(model_name, int):
@@ -1664,24 +1802,24 @@ def process_video_file(
1664
 
1665
  model_path = MODEL_DIR / model_name
1666
  if not model_path.exists():
1667
- yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
1668
  return
1669
 
1670
  try:
1671
  model, attention_viz = load_physmamba_model(model_path, DEVICE)
1672
  except Exception as e:
1673
- yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
1674
  return
1675
 
1676
  gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
1677
 
1678
  if not video_path or not os.path.exists(video_path):
1679
- yield ("ERROR: Video not found", None, None, None, None, None, None, None, None, None)
1680
  return
1681
 
1682
  cap = cv2.VideoCapture(video_path)
1683
  if not cap.isOpened():
1684
- yield ("ERROR: Cannot open video", None, None, None, None, None, None, None, None, None)
1685
  return
1686
 
1687
  fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
@@ -1698,7 +1836,7 @@ def process_video_file(
1698
  last_rmssd = 0.0
1699
  last_attention = None
1700
 
1701
- # NEW: previous grayscale frames for motion
1702
  prev_gray = None
1703
  prev_roi_gray = None
1704
 
@@ -1712,8 +1850,15 @@ def process_video_file(
1712
  raw_path = tmpdir / "raw_signal.png"
1713
  post_path = tmpdir / "post_signal.png"
1714
 
1715
- yield ("Starting… reading video frames", None, f"{gt_hr:.1f}" if gt_hr > 0 else "--",
1716
- None, None, None, None, None, None, None)
 
 
 
 
 
 
 
1717
 
1718
  while True:
1719
  if controls['stop'].is_set():
@@ -1730,13 +1875,13 @@ def process_video_file(
1730
 
1731
  frame_idx += 1
1732
 
1733
- # NEW: default per-frame metrics for this iteration
1734
  global_brightness = None
1735
  global_motion = None
1736
  roi_brightness = None
1737
  roi_motion = None
1738
 
1739
- # NEW: global illumination & motion (full frame)
1740
  try:
1741
  frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
1742
  except Exception:
@@ -1762,7 +1907,7 @@ def process_video_file(
1762
  roi = crop_roi(face, roi_type, frame)
1763
 
1764
  if roi is not None and roi.size > 0:
1765
- # NEW: ROI-level brightness & motion
1766
  try:
1767
  roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
1768
  except Exception:
@@ -1799,7 +1944,9 @@ def process_video_file(
1799
  try:
1800
  raw = forward_bvp(model, clip_t)
1801
  if isinstance(raw, np.ndarray):
1802
- raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
 
 
1803
  bvp_out = raw if raw.size > 0 else None
1804
  else:
1805
  bvp_out = None
@@ -1807,7 +1954,7 @@ def process_video_file(
1807
  print(f"[infer] forward_bvp error: {e}")
1808
  bvp_out = None
1809
 
1810
- # UPDATED: Generate attention with Grad-CAM
1811
  try:
1812
  last_attention = extract_attention_map(model, clip_t, attention_viz)
1813
  except Exception as e:
@@ -1815,7 +1962,9 @@ def process_video_file(
1815
  last_attention = None
1816
 
1817
  if bvp_out is None or bvp_out.size == 0:
1818
- gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
 
 
1819
  fb = _fallback_bvp_from_means(gbuf, fs=fps)
1820
  if isinstance(fb, np.ndarray) and fb.size > 0:
1821
  bvp_out = fb
@@ -1829,7 +1978,9 @@ def process_video_file(
1829
  bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
1830
 
1831
  if len(bvp_stream) >= int(5 * fps):
1832
- seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
 
 
1833
  _, last_bpm = postprocess_bvp(seg, fs=fps)
1834
  last_rmssd = compute_rmssd(seg, fs=fps)
1835
 
@@ -1841,10 +1992,17 @@ def process_video_file(
1841
  last_infer = frame_idx
1842
 
1843
  else:
1844
- cv2.putText(vis_frame, "No face detected", (20, 40),
1845
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
 
 
 
 
 
 
 
1846
 
1847
- # NEW: log per-frame metrics into global FRAME_METRICS
1848
  try:
1849
  time_s = frame_idx / float(fps) if fps > 0 else float(frame_idx)
1850
  except Exception:
@@ -1859,11 +2017,24 @@ def process_video_file(
1859
  "roi_motion": float(roi_motion) if roi_motion is not None else None,
1860
  })
1861
 
 
 
 
 
 
 
1862
  if last_bpm > 0:
1863
  color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
1864
  cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
1865
- cv2.putText(vis_frame, f"HR: {last_bpm:.1f} BPM", (20, 48),
1866
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
 
 
 
 
 
 
 
1867
 
1868
  vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
1869
 
@@ -1934,6 +2105,7 @@ def process_video_file(
1934
  elapsed = now - start_time
1935
  status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
1936
 
 
1937
  yield (
1938
  status,
1939
  f"{last_bpm:.1f}" if last_bpm > 0 else None,
@@ -1944,7 +2116,11 @@ def process_video_file(
1944
  str(signal_path) if signal_path.exists() else None,
1945
  str(raw_path) if raw_path.exists() else None,
1946
  str(post_path) if post_path.exists() else None,
1947
- None
 
 
 
 
1948
  )
1949
 
1950
  next_display = now + (1.0 / DISPLAY_FPS)
@@ -1984,7 +2160,7 @@ def process_video_file(
1984
  except Exception:
1985
  pass
1986
 
1987
- # NEW: save per-frame illumination & motion metrics
1988
  frame_metrics_path = None
1989
  if FRAME_METRICS:
1990
  try:
@@ -1995,9 +2171,13 @@ def process_video_file(
1995
  print(f"[metrics] Failed to save frame metrics CSV: {e}")
1996
  frame_metrics_path = None
1997
 
 
 
 
1998
  elapsed = time.time() - start_time
1999
  final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
2000
 
 
2001
  yield (
2002
  final_status,
2003
  f"{last_bpm:.1f}" if last_bpm > 0 else None,
@@ -2008,9 +2188,14 @@ def process_video_file(
2008
  str(signal_path) if signal_path.exists() else None,
2009
  str(raw_path) if raw_path.exists() else None,
2010
  str(post_path) if post_path.exists() else None,
2011
- str(csv_path) if csv_path else None
 
 
 
 
2012
  )
2013
 
 
2014
  def process_stream(
2015
  input_source: str,
2016
  video_path: Optional[str],
@@ -2024,8 +2209,10 @@ def process_stream(
2024
  if input_source == "Live Webcam":
2025
  yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
2026
  else:
2027
- yield from process_video_file(video_path, gt_file, model_name, fps_input,
2028
- max_seconds, roi_type, control_id)
 
 
2029
 
2030
  def pause_processing(control_id: str) -> str:
2031
  _, controls = ensure_controls(control_id)
@@ -2044,7 +2231,23 @@ def stop_processing(control_id: str) -> str:
2044
  return "Stopped"
2045
 
2046
  def reset_ui():
2047
- return ("Ready", None, None, None, None, None, None, None, None, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2048
 
2049
  def handle_folder_upload(files):
2050
  if not files:
@@ -2111,8 +2314,10 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
2111
  with gr.Column():
2112
  video_upload = gr.Video(label="Upload Video", sources=["upload"])
2113
  with gr.Column():
2114
- gt_upload = gr.File(label="Ground Truth (optional)",
2115
- file_types=[".txt", ".csv", ".json"])
 
 
2116
 
2117
  with gr.Row(visible=False) as folder_inputs:
2118
  with gr.Column():
@@ -2177,7 +2382,11 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
2177
  )
2178
 
2179
  with gr.Row():
2180
- roi_dropdown = gr.Dropdown(choices=["auto","forehead","cheeks","face"], value="auto", label="ROI")
 
 
 
 
2181
 
2182
  control_state = gr.State(value="")
2183
  placeholder_state = gr.State(value=None)
@@ -2194,6 +2403,21 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
2194
  hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
2195
  gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
2196
  rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2197
 
2198
  with gr.Row():
2199
  with gr.Column():
@@ -2232,20 +2456,51 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
2232
  ).then(
2233
  reset_ui,
2234
  inputs=None,
2235
- outputs=[status_text, hr_output, gt_hr_output, rmssd_output,
2236
- frame_output, attention_output, signal_output,
2237
- raw_signal_output, post_signal_output, csv_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
2238
  )
2239
 
2240
- def run_processing(input_source, video_upload, gt_upload, folder_video, folder_gt,
2241
- model_name, fps, max_sec, roi, ctrl_id):
2242
- """Fixed version that handles model_name type conversion."""
 
 
 
 
 
 
 
 
 
 
2243
 
2244
  if isinstance(model_name, int):
2245
  model_name = str(model_name)
2246
 
2247
  if not model_name:
2248
- yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
 
 
 
 
 
 
 
 
2249
  return
2250
 
2251
  if input_source == "Video File":
@@ -2257,12 +2512,12 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
2257
  else: # Live Webcam
2258
  video_path, gt_file = None, None
2259
 
 
2260
  yield from process_stream(
2261
  input_source, video_path, gt_file,
2262
  model_name, fps, max_sec, roi, ctrl_id
2263
  )
2264
 
2265
-
2266
  run_btn.click(
2267
  fn=run_processing,
2268
  inputs=[
@@ -2287,35 +2542,11 @@ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as
2287
  signal_output,
2288
  raw_signal_output,
2289
  post_signal_output,
2290
- csv_output
2291
- ]
2292
- )
2293
-
2294
-
2295
- run_btn.click(
2296
- fn=run_processing,
2297
- inputs=[
2298
- input_source,
2299
- video_upload,
2300
- folder_video,
2301
- folder_gt,
2302
- model_dropdown,
2303
- fps_slider,
2304
- max_seconds_slider,
2305
- roi_dropdown,
2306
- control_state
2307
- ],
2308
- outputs=[
2309
- status_text,
2310
- hr_output,
2311
- gt_hr_output,
2312
- rmssd_output,
2313
- frame_output,
2314
- attention_output,
2315
- signal_output,
2316
- raw_signal_output,
2317
- post_signal_output,
2318
- csv_output
2319
  ]
2320
  )
2321
 
 
24
  import torch.nn as nn
25
  import torch.nn.functional as F
26
 
 
27
  from scipy import signal
28
  from scipy.signal import find_peaks, welch, get_window
29
 
 
42
 
43
  import gradio as gr
44
 
45
+ # Global buffer for per-frame illumination/motion metrics
46
+ FRAME_METRICS: List[Dict] = []
47
 
48
  class PhysMambaattention_viz:
49
  """Simplified Grad-CAM for PhysMamba."""
 
228
 
229
 
230
  def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
231
+ """
232
+ Apply DiffNormalized preprocessing from the PhysMamba paper:
233
+
234
+ diff_t = (I_t - I_{t-1}) / (I_t + I_{t-1} + eps)
235
+ then global std-normalize.
236
+
237
+ frames: list of HxWx3 uint8 or float32 arrays (RGB or BGR, consistent).
238
+ Returns: (T, H, W, C) float32.
239
+ """
240
+ if not frames:
241
+ return np.zeros((0,), dtype=np.float32)
242
+
243
  if len(frames) < 2:
244
+ f0 = frames[0].astype(np.float32)
245
+ return np.stack([np.zeros_like(f0, dtype=np.float32)], axis=0)
246
+
247
  diff_frames = []
 
248
  for i in range(len(frames)):
249
  if i == 0:
250
  diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
251
  else:
252
  curr = frames[i].astype(np.float32)
253
+ prev = frames[i - 1].astype(np.float32)
254
+ denom = curr + prev + 1e-8
255
+ diff = (curr - prev) / denom
256
  diff_frames.append(diff)
257
+
258
+ diff_array = np.stack(diff_frames).astype(np.float32)
259
+ std = float(diff_array.std()) + 1e-8
260
+ diff_array /= std
 
 
261
  return diff_array
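
A quick shape check for the function above, on synthetic frames (assumes apply_diff_normalized from this app.py):

```python
import numpy as np

frames = [np.full((64, 64, 3), v, dtype=np.uint8) for v in (100, 110, 120)]
diff = apply_diff_normalized(frames)
print(diff.shape)   # (3, 64, 64, 3); frame 0 is an all-zeros placeholder
```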
262
 
263
 
264
+ def preprocess_for_physmamba(
265
+ frames: List[np.ndarray],
266
+ target_frames: int = 128,
267
+ target_size: int = 128
268
+ ) -> torch.Tensor:
269
+ """
270
+ Complete DiffNormalized preprocessing pipeline to produce
271
+ a PhysMamba-compatible clip tensor of shape [1, 3, T, H, W].
272
+
273
+ NOTE: This path is *not* used in the current live demo, which instead
274
+ uses normalize_frame() + forward_bvp(). Keep for future experiments.
275
+ """
276
+ if not frames:
277
+ # Dummy tensor; caller should guard length > 0
278
+ return torch.zeros(1, 3, target_frames, target_size, target_size, dtype=torch.float32)
279
+
280
+ # Temporal sampling / padding to target_frames
281
  if len(frames) < target_frames:
282
  frames = frames + [frames[-1]] * (target_frames - len(frames))
283
  elif len(frames) > target_frames:
284
+ idx = np.linspace(0, len(frames) - 1, target_frames).astype(int)
285
+ frames = [frames[i] for i in idx]
286
+
287
+ # Convert to RGB and resize
288
+ frames_rgb = [f[..., ::-1].copy() for f in frames] # BGR->RGB
289
  frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
 
 
 
 
 
 
290
 
291
+ # DiffNormalized
292
+ diff_array = apply_diff_normalized(frames_resized) # (T, H, W, C)
293
+
294
+ # To [B, C, T, H, W]
295
+ diff_array = np.transpose(diff_array, (3, 0, 1, 2)) # (C, T, H, W)
296
+ diff_array = np.expand_dims(diff_array, axis=0) # (1, C, T, H, W)
297
+
298
+ return torch.from_numpy(diff_array.astype(np.float32))
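
And the end-to-end clip shape the pipeline above produces, on synthetic BGR frames (assumes preprocess_for_physmamba from this app.py):

```python
import numpy as np

frames = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(32)]
clip = preprocess_for_physmamba(frames, target_frames=128, target_size=128)
print(clip.shape)   # torch.Size([1, 3, 128, 128, 128]); short clips are padded by repeating the last frame
```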
299
+
300
+
301
+ # ---------------------------------------------------------------------------
302
+ # Paths, device, constants
303
+ # ---------------------------------------------------------------------------
304
+
305
+ HERE = Path(__file__).resolve().parent
306
  MODEL_DIR = HERE / "final_model_release"
307
  LOG_DIR = HERE / "logs"
308
  ANALYSIS_DIR = HERE / "analysis"
309
+ for d in (MODEL_DIR, LOG_DIR, ANALYSIS_DIR):
310
+ d.mkdir(exist_ok=True, parents=True)
311
 
312
  DEVICE = (
313
  torch.device("cuda") if torch.cuda.is_available()
 
315
  else torch.device("cpu")
316
  )
317
 
318
+ FACE_CASCADE = cv2.CascadeClassifier(
319
+ cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
320
+ )
321
 
322
  DEFAULT_SIZE = 128 # input H=W to model
323
  DEFAULT_T = 128 # clip length
 
335
  GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
336
  GT_EXTS = {".txt", ".csv", ".json"}
337
 
338
+
339
  def _as_path(maybe) -> Optional[str]:
340
+ """
341
+ Return a filesystem path from Gradio values (str, dict, Path, tempfile objects, lists).
342
+
343
+ Handles:
344
+ - plain strings
345
+ - pathlib.Path
346
+ - Gradio dicts with keys: 'name', 'path', 'file'
347
+ - file-like objects with .name
348
+ - lists (takes first element)
349
+ """
350
  if maybe is None:
351
  return None
352
+
353
+ # Gradio can pass a list (e.g., multiple files / directory upload)
354
+ if isinstance(maybe, list):
355
+ if not maybe:
356
+ return None
357
+ return _as_path(maybe[0])
358
+
359
  if isinstance(maybe, str):
360
  return maybe
361
+
362
+ if isinstance(maybe, Path):
363
+ return str(maybe)
364
+
365
+ # Gradio v4 File/Video components often pass a dict
366
  if isinstance(maybe, dict):
367
+ for key in ("name", "path", "file"):
368
+ v = maybe.get(key)
369
+ if isinstance(v, str) and v:
370
+ return v
371
+ return None
372
+
373
+ # tempfile-like / UploadedFile objects
374
+ name = getattr(maybe, "name", None)
375
+ if isinstance(name, str) and name:
376
  return name
377
+
378
  try:
379
  return str(maybe)
380
  except Exception:
381
  return None
382
 
383
+
384
  def _import_from_file(py_path: Path):
385
  spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
386
  if not spec or not spec.loader:
 
389
  spec.loader.exec_module(mod)
390
  return mod
391
 
392
+
393
  def _looks_like_video(p: Path) -> bool:
394
+ """
395
+ Heuristic for 'video-like' files used in subject-folder discovery.
396
+ Treat .mat as video, plus common video extensions.
397
+ """
398
  if p.suffix.lower() == ".mat":
399
  return True
400
  return p.suffix.lower() in VIDEO_EXTENSIONS
401
 
402
+
403
  class SimpleActivationAttention:
404
+ """Lightweight attention visualization using forward activations (no gradients)."""
405
 
406
  def __init__(self, model: nn.Module, device: torch.device):
407
  self.model = model
408
  self.device = device
409
+ self.activations: Optional[torch.Tensor] = None
410
+ self.hook_handle: Optional[Any] = None
411
 
412
  def _activation_hook(self, module, input, output):
413
  """Capture activations during forward pass."""
414
+ try:
415
+ self.activations = output.detach()
416
+ except Exception:
417
+ self.activations = None
418
 
419
  def register_hook(self):
420
+ """Register hook on a suitable conv layer (last conv before Mamba if possible)."""
 
421
  target = None
422
  target_name = None
423
 
 
425
  if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
426
  target = module
427
  target_name = name
428
+
429
  if target is None:
430
  print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
431
  return
 
434
  print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")
435
 
436
  def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
437
+ """
438
+ Generate attention map from stored activations (call AFTER the forward pass).
439
+
440
+ Returns a 2D numpy array in [0,1] or None if unavailable.
441
+ """
442
  try:
443
  if self.activations is None:
444
  return None
445
 
 
446
  act = self.activations
447
+
448
  # Handle different tensor shapes
449
  if act.dim() == 5: # [B, C, T, H, W]
450
+ # Average over channels and time: -> [B, H, W]
451
+ attention = act.mean(dim=[1, 2])
452
  elif act.dim() == 4: # [B, C, H, W]
453
  attention = act.mean(dim=1) # -> [B, H, W]
454
  else:
455
  print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
456
  return None
457
 
458
+ # Take first batch
459
+ attention = attention[0].detach().cpu().numpy()
460
+
461
  # Normalize to [0, 1]
462
+ a_min, a_max = attention.min(), attention.max()
463
+ if a_max > a_min:
464
+ attention = (attention - a_min) / (a_max - a_min)
465
+ else:
466
+ attention = np.zeros_like(attention, dtype=np.float32)
467
 
468
  return attention
469
 
 
473
 
474
  def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
475
  """Overlay heatmap on frame."""
476
+ if heatmap is None or frame is None or frame.size == 0:
477
+ return frame
478
+
479
  h, w = frame.shape[:2]
480
  heatmap_resized = cv2.resize(heatmap, (w, h))
481
  heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
482
  heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
483
+ overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
484
  return overlay
485
 
486
  def cleanup(self):
487
  if self.hook_handle is not None:
488
  self.hook_handle.remove()
489
+ self.hook_handle = None
490
+ self.activations = None
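
A usage sketch for the class above, assuming a loaded `model` and `DEVICE` from this app.py; the zero clip only exercises the hook, map generation, and overlay plumbing:

```python
import numpy as np
import torch

viz = SimpleActivationAttention(model, DEVICE)
viz.register_hook()                                    # hooks the last suitable conv layer, if any

clip = torch.zeros(1, 3, 8, 128, 128, device=DEVICE)   # [B, C, T, H, W]
with torch.no_grad():
    _ = model(clip)                                     # forward pass fills viz.activations

heat = viz.generate(clip)                               # 2-D map in [0, 1], or None
frame = np.zeros((128, 128, 3), dtype=np.uint8)
if heat is not None:
    frame = viz.visualize(heat, frame, alpha=0.4)       # JET-colormap overlay on the frame
viz.cleanup()
```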
491
+
492
 
493
  class VideoReader:
494
  """
 
504
  self._idx = 0
505
  self._len = 0
506
  self._shape = None
507
+ self._fps = 0
508
 
509
  if self.path.lower().endswith(".mat") and MAT_SUPPORT:
510
  self._open_mat(self.path)
 
517
  raise RuntimeError("Cannot open video")
518
  self._cap = cap
519
  self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
520
+ self._fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
521
 
522
  def _open_mat(self, path: str):
523
  try:
 
529
  break
530
  else:
531
  arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
532
+
533
  if arr is None:
534
  raise RuntimeError("No ndarray found in .mat")
535
 
 
537
  # Normalize to (T,H,W,3)
538
  if a.ndim == 4:
539
  if a.shape[-1] == 3:
540
+ # Heuristic: if the first dim is no larger than either spatial dim -> assume it is T
541
+ if a.shape[0] <= a.shape[1] and a.shape[0] <= a.shape[2]: # (T,H,W,3)
542
  v = a
543
+ else: # (H,W,T,3) -> (T,H,W,3)
544
  v = np.transpose(a, (2, 0, 1, 3))
545
  else:
546
+ v = a[..., :1] # take first channel
547
  elif a.ndim == 3:
548
+ # (T,H,W) or (H,W,T)
549
+ if a.shape[0] <= a.shape[1] and a.shape[0] <= a.shape[2]: # (T,H,W)
550
  v = a
551
+ else: # (H,W,T) -> (T,H,W)
552
  v = np.transpose(a, (2, 0, 1))
553
  v = v[..., None]
554
  else:
 
561
  self._mat = v
562
  self._len = v.shape[0]
563
  self._shape = v.shape[1:3]
564
+ self._fps = 0.0 # unknown; caller can override
565
  except Exception as e:
566
  raise RuntimeError(f"Failed to open .mat video: {e}")
567
 
 
578
 
579
  def fps(self, fallback: int = 30) -> int:
580
  if self._mat is not None:
581
+ return int(fallback)
582
+ if self._fps and self._fps > 0:
583
+ return int(self._fps)
584
  f = self._cap.get(cv2.CAP_PROP_FPS)
585
+ return int(f) if f and f > 0 else int(fallback)
586
 
587
  def length(self) -> int:
588
  return self._len
 
591
  if self._cap is not None:
592
  self._cap.release()
593
 
594
+
595
  def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
596
  x, y, w, h = face
597
  # forehead
 
602
  ff = frame[y:y + h, x:x + w]
603
  return {"forehead": fh, "cheeks": ck, "face": ff}
604
 
605
+
606
  def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
607
  if patch is None or patch.size == 0:
608
  return -1e9
609
  g = patch[..., 1].astype(np.float32) / 255.0 # green channel
610
  g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling
611
  g = g - g.mean()
 
612
  try:
613
+ b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
614
  y = signal.filtfilt(b, a, g, method="gust")
615
  except Exception:
616
  y = g
617
  return float((y ** 2).mean())
618
 
619
+
620
  def pick_auto_roi(face: Tuple[int, int, int, int],
621
+ frame: np.ndarray,
622
+ attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
623
+ """Simple ROI selection using signal quality + optional attention weighting."""
624
  cands = roi_candidates(face, frame)
625
  scores = {k: roi_quality_score(v) for k, v in cands.items()}
626
 
 
633
  ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
634
  ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
635
  scores['forehead'] += fh_attn * 0.2
636
+ scores['cheeks'] += ck_attn * 0.2
637
+ scores['face'] += ff_attn * 0.2
638
  except Exception:
639
  pass
640
 
641
  best = max(scores, key=scores.get)
642
  return cands[best], best
643
 
644
+
645
  def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
646
  """
647
  Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
 
696
  uniq.append((v, g))
697
  return uniq
698
 
699
+
700
  def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
701
  import inspect
702
 
 
735
  except Exception:
736
  continue
737
 
738
+ raise ImportError("Could not find PhysMamba model class")
739
+
740
 
741
  def load_physmamba_model(ckpt_path: Path, device: torch.device,
742
  model_file: str = "", model_class: str = "PhysMamba"):
 
777
  state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
778
  model.load_state_dict(state_dict, strict=False)
779
  except Exception:
780
+ # If loading fails, you still get a randomly initialized model (for debugging)
781
  pass
782
 
783
  model.to(device).eval()
 
786
  with torch.no_grad():
787
  _ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
788
  except Exception:
789
+ # Shape sanity check failed, but we keep the model usable.
790
  pass
791
 
792
+ # For now: attention visualization disabled (extract_attention_map returns None)
793
  attention_viz = None
794
 
795
  return model, attention_viz
796
 
797
+
798
  def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
799
  """
800
  Stable band-pass with edge-safety and parameter clipping.
 
812
 
813
  try:
814
  b, a = signal.butter(order, [lo, hi], btype="band")
 
815
  padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
816
  return signal.filtfilt(b, a, x, padlen=padlen)
817
  except Exception:
818
  return x
819
 
820
+
821
  def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
822
  """
823
  HR (BPM) via Welch PSD peak in [lo, hi] Hz.
 
826
  if x.size < int(fs * 4.0): # need ~4s for a usable PSD
827
  return 0.0
828
  try:
 
829
  nper = int(min(max(64, fs * 2), min(512, x.size)))
830
  f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
831
 
 
840
 
841
  fpk = float(f_band[np.argmax(p_band)])
842
  bpm = fpk * 60.0
 
843
  return float(np.clip(bpm, 30.0, 220.0))
844
  except Exception:
845
  return 0.0
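
A quick numeric check for the two helpers above: a clean sinusoid at 0.9375 Hz (chosen to land exactly on a 64-point Welch bin at fs = 30) should be reported as roughly 0.9375 × 60 ≈ 56 BPM. Assumes bandpass_filter and hr_from_welch from this app.py:

```python
import numpy as np

fs = 30
t = np.arange(0, 20, 1.0 / fs)        # 20 s at 30 Hz
f0 = 0.9375                           # Hz; exactly 2 * fs / 64, i.e. on a Welch bin
x = np.sin(2 * np.pi * f0 * t)
hr = hr_from_welch(bandpass_filter(x, fs=fs), fs=fs)
print(hr)                             # ~56.2 BPM (= f0 * 60)
```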
846
 
847
+
848
  def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
849
  """
850
  HRV RMSSD from peaks; robust to short/flat segments.
 
853
  if x.size < int(fs * 5.0):
854
  return 0.0
855
  try:
 
856
  peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
857
  if len(peaks) < 3:
858
  return 0.0
 
863
  except Exception:
864
  return 0.0
865
 
866
+
867
  def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
868
  """
869
  Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
 
886
  lo, hi = REST_HR_RANGE
887
  if hr < lo or hr > hi:
888
  dist = abs(hr - REST_HR_TARGET)
 
889
  alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
890
  hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
891
 
 
900
 
901
  return y_filt, float(hr)
902
 
903
+
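# Hedged sketch of the "gentle pull" applied above: outside REST_HR_RANGE the
# estimate is blended toward REST_HR_TARGET, with a weight that grows with the
# distance but is capped at 0.65, so outliers are damped rather than discarded.
def _demo_rest_pull(hr: float) -> float:
    lo, hi = REST_HR_RANGE
    if lo <= hr <= hi:
        return hr
    dist = abs(hr - REST_HR_TARGET)
    alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
    return alpha * hr + (1.0 - alpha) * REST_HR_TARGET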
904
  def draw_face_and_roi(frame_bgr: np.ndarray,
905
  face_bbox: Optional[Tuple[int, int, int, int]],
906
  roi_bbox: Optional[Tuple[int, int, int, int]],
 
919
  cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
920
  return vis
921
 
922
+
923
  def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
924
  roi_type: str,
925
  frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
 
945
  return (0, 0, 0, 0)
946
  return (rx, ry, rx2 - rx, ry2 - ry)
947
 
948
+
949
  def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
950
  """
951
  Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
 
954
  if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
955
  return np.zeros((128, 128, 3), dtype=np.uint8)
956
 
 
957
  img = chw.copy()
 
958
  vmin, vmax = float(img.min()), float(img.max())
959
  if vmax <= vmin + 1e-6:
960
  img = np.zeros_like(img)
 
964
  img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR
965
  return img
966
 
967
+
968
  def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
969
  if gt_len <= 1:
970
  return None
 
972
  return np.arange(gt_len, dtype=float) / float(gt_fs)
973
  return None # will fall back to length-matching overlay
974
 
975
+
976
  def plot_signals_with_gt(time_axis: np.ndarray,
977
  raw_signal: np.ndarray,
978
  post_signal: np.ndarray,
 
1017
  t_new = _np.asarray(t_new, dtype=float).ravel()
1018
 
1019
  if x_t.size < 2 or y.size != x_t.size:
 
1020
  if y.size == 0 or t_new.size == 0:
1021
  return _np.zeros_like(t_new)
1022
  idx = _np.linspace(0, y.size - 1, num=t_new.size)
1023
  return _np.interp(_np.arange(t_new.size), idx, y)
1024
 
 
1025
  order = _np.argsort(x_t)
1026
  x_t = x_t[order]
1027
  y = y[order]
1028
  mask = _np.concatenate(([True], _np.diff(x_t) > 0))
1029
  x_t = x_t[mask]
1030
  y = y[mask]
 
1031
  t_clip = _np.clip(t_new, x_t[0], x_t[-1])
1032
  return _np.interp(t_clip, x_t, y)
1033
 
 
1039
  n = int(min(len(x), len(y)))
1040
  x = x[:n]; y = y[:n]
1041
  max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
 
1042
  lags = _np.arange(-max_lag, max_lag + 1)
 
1043
  best_corr = -_np.inf
1044
  best_lag = 0
1045
  for L in lags:
 
1071
  out = _np.empty_like(y)
1072
  out[:] = _np.nan
1073
  if shift > 0:
 
1074
  out[shift:] = y[:-shift]
1075
  else:
 
1076
  out[:shift] = y[-shift:]
1077
  return out
1078
 
 
1080
  raw = _np.asarray(raw_signal, dtype=float)
1081
  post = _np.asarray(post_signal, dtype=float)
1082
 
 
1083
  if t.size == 0:
1084
  t = _np.arange(post.size, dtype=float) / max(fs, 1)
1085
 
 
1096
  gt_t = _np.asarray(gt_time, dtype=float).ravel()
1097
  gt_on_pred = _safe_interp(gt_t, gt, t)
1098
  else:
1099
+ gt_on_pred = _safe_interp(
1100
+ _np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
1101
+ gt, t
1102
+ )
1103
 
 
1104
  pred_bp = _bandpass(post, fs)
1105
  gt_bp = _bandpass(gt_on_pred, fs)
1106
 
 
1107
  lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
 
 
1108
  gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
1109
 
 
1110
  valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
1111
  if valid.sum() >= 16:
1112
  pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
 
1115
 
1116
  hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
1117
 
 
1118
  _plt.figure(figsize=(13, 6), dpi=110)
1119
  gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
1120
 
 
1121
  ax1 = _plt.subplot(gs[0, 0])
1122
  ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
1123
  ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
1124
  ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
1125
  ax1.grid(True, alpha=0.3)
1126
 
 
1127
  ax2 = _plt.subplot(gs[0, 1])
1128
  ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
1129
  ax2.set_title("Predicted (Post-processed)")
1130
  ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
1131
  ax2.grid(True, alpha=0.3)
1132
 
 
1133
  ax3 = _plt.subplot(gs[1, :])
1134
  ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
1135
 
 
1138
  gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
1139
  ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
1140
 
 
1141
  txt = [
1142
  f"HR_pred: {hr_pred:.1f} BPM",
1143
  f"HR_gt: {hr_gt:.1f} BPM",
 
1172
  if frame is None or frame.size == 0:
1173
  return None
1174
 
1175
+ # If cascade is missing, fail fast (prevents cryptic OpenCV errors)
1176
+ if FACE_CASCADE is None:
1177
+ return None
1178
+
1179
  try:
1180
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
1181
  except Exception:
1182
  # If color conversion fails, assume already gray
1183
+ if frame.ndim == 2:
1184
+ gray = frame.copy()
1185
+ else:
1186
+ gray = cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
1187
 
1188
  # Light preproc to improve Haar performance
1189
  gray = cv2.equalizeHist(gray)
1190
 
1191
+ faces_all: List[Tuple[int, int, int, int]] = []
1192
  # Try a couple of parameter combos to be more forgiving
1193
  params = [
1194
  dict(scaleFactor=1.05, minNeighbors=3),
 
1209
  # Return the largest (by area)
1210
  return max(faces_all, key=lambda f: f[2] * f[3])
1211
 
1212
+
1213
  def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
1214
  """
1215
  Crop ROI from the frame based on a face bbox and the selected roi_type.
 
1243
  return None
1244
  return roi
1245
 
1246
+
1247
+ def crop_roi_with_bbox(
1248
+ face_bbox: Tuple[int, int, int, int],
1249
+ roi_type: str,
1250
+ frame: np.ndarray
1251
+ ) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
1252
+ """
1253
+ Same as crop_roi, but also returns the ROI bbox (x, y, w, h) in frame coords.
1254
+ """
1255
  if face_bbox is None or frame is None or frame.size == 0:
1256
  return None, None
1257
 
 
1278
 
1279
  return roi, (rx, ry, rx2 - rx, ry2 - ry)
1280
 
1281
+
1282
  def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
1283
  """
1284
  PhysMamba-compatible normalization with DiffNormalized support.
 
1303
  chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
1304
  return chw
1305
 
1306
+
1307
+ def extract_attention_map(model, clip_tensor: torch.Tensor, attention_viz) -> Optional[np.ndarray]:
1308
  """Attention visualization disabled - model architecture incompatible."""
1309
  return None
1310
 
1311
+
1312
+ def create_attention_overlay(
1313
+ frame: np.ndarray,
1314
+ attention: Optional[np.ndarray],
1315
+ attention_viz: Optional["SimpleActivationAttention"] = None  # quoted forward reference: avoids a NameError if the class is defined later or renamed
1316
+ ) -> np.ndarray:
1317
+ """Create attention heatmap overlay (currently passthrough)."""
1318
  return frame
1319
 
1320
+
1321
  def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
1322
  H, W = roi_bgr.shape[:2]
1323
+ base_bvp = forward_bvp(
1324
+ model,
1325
+ torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
1326
+ .unsqueeze(0).unsqueeze(2).to(DEVICE) # fake T=1 path if needed
1327
+ )
1328
  base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
1329
 
1330
  heat = np.zeros((H, W), np.float32)
 
1332
  for x in range(0, W - patch + 1, stride):
1333
  tmp = roi_bgr.copy()
1334
  tmp[y:y+patch, x:x+patch] = 127 # occlude
1335
+ bvp = forward_bvp(
1336
+ model,
1337
+ torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
1338
+ .unsqueeze(0).unsqueeze(2).to(DEVICE)
1339
+ )
1340
  power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
1341
  drop = max(0.0, base_power - power)
1342
  heat[y:y+patch, x:x+patch] += drop
1343
  heat -= heat.min()
1344
+ if heat.max() > 1e-8:
1345
+ heat /= heat.max()
1346
  return heat
1347
 
1348
+
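# Hedged visualisation sketch: render the occlusion-saliency map as a colour
# overlay on the ROI. Standard OpenCV calls; the 0.6/0.4 blend is arbitrary.
def _demo_saliency_overlay(roi_bgr: np.ndarray, heat: np.ndarray) -> np.ndarray:
    heat_u8 = (np.clip(heat, 0.0, 1.0) * 255).astype(np.uint8)
    heat_u8 = cv2.resize(heat_u8, (roi_bgr.shape[1], roi_bgr.shape[0]))
    cmap = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
    return cv2.addWeighted(roi_bgr, 0.6, cmap, 0.4, 0.0)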
1349
  def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
1350
  """
1351
  Try common 5D layouts:
 
1362
  last_err = e
1363
  raise last_err
1364
 
1365
+
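# Hedged sketch of the two 5D layouts referred to above: (B, C, T, H, W) and
# (B, T, C, H, W) describe the same clip; a single permute converts between them.
def _demo_layouts(clip_bcthw: torch.Tensor):
    assert clip_bcthw.ndim == 5                      # e.g. (1, 3, 128, 128, 128)
    clip_btchw = clip_bcthw.permute(0, 2, 1, 3, 4).contiguous()
    return clip_bcthw.shape, clip_btchw.shape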
1366
  def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
1367
  """
1368
  Forward and extract a 1D time-like BVP vector with length T_clip.
 
1409
  B, K = out.shape
1410
  if B == 1:
1411
  v = out[0]
1412
+ return (v.numpy() if v.shape[0] == T_clip
1413
+ else np.resize(v.numpy(), T_clip).astype(np.float32))
1414
  if B == T_clip:
1415
  return out[:, 0].numpy()
1416
  if K == T_clip:
 
1455
  val = float(out.mean().item()) if out.numel() else 0.0
1456
  return np.full(T_clip, val, dtype=np.float32)
1457
 
1458
+
1459
  def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
1460
  """
1461
  Classical rPPG from green-channel means when the model yields nothing.
 
1480
  std = float(np.std(y)) + 1e-6
1481
  return (y / std).astype(np.float32)
1482
 
1483
+
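# Hedged sketch: building the green-channel mean buffer that feeds the
# classical fallback when the model output is unusable (BGR index 1 = green).
def _demo_green_fallback(roi_frames_bgr: List[np.ndarray], fs: int = 30) -> np.ndarray:
    g_means = [float(f[..., 1].mean()) for f in roi_frames_bgr if f is not None and f.size]
    return _fallback_bvp_from_means(np.asarray(g_means, dtype=np.float32), fs=fs)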
1484
  def _to_floats(s: str) -> List[float]:
1485
  """
1486
  Extract all real numbers from free-form text, including scientific notation.
1487
+ Gracefully ignores comments, units, and non-numeric junk.
1488
  """
1489
  if not isinstance(s, str) or not s:
1490
  return []
1491
 
1492
+ # Strip comments starting with #, //, or ;
1493
  s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
1494
 
1495
+ # Normalize common delimiters
1496
  s = s.replace(",", " ").replace(";", " ")
1497
 
1498
+ toks = re.findall(
1499
+ r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?",
1500
+ s
1501
+ )
1502
  out: List[float] = []
1503
  for t in toks:
1504
  try:
 
1538
  diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
1539
  return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
1540
 
 
1541
  def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
1542
  if bvp is None or bvp.size == 0:
1543
  return 0.0
 
1545
  bp = bandpass_filter(bvp.astype(float), fs=fs_use)
1546
  return hr_from_welch(bp, fs=fs_use)
1547
 
1548
+ # ================= UBFC-style TXT (3 lines) =================
1549
+ if p.name.lower() == "ground_truth.txt" or (
1550
+ ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2
1551
+ ):
1552
  try:
1553
+ lines = [
1554
+ ln.strip()
1555
+ for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines()
1556
+ if ln.strip()
1557
+ ]
1558
  ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
1559
  hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
1560
  t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
 
1578
  # Fall through to generic handlers
1579
  pass
1580
 
1581
+ # ================= Generic TXT =================
1582
  if ext == ".txt":
1583
  try:
1584
  nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
 
1588
  except Exception:
1589
  return np.array([]), 0.0, 0.0
1590
 
1591
+ # ================= JSON =================
1592
  if ext == ".json":
1593
  try:
1594
  data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
 
 
1595
 
1596
  def _seek(obj, keys):
1597
  for k in keys:
 
1599
  return obj[k]
1600
  return None
1601
 
 
1602
  bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
 
1603
  if bvp is None:
1604
  for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
1605
  if container_key in data:
 
1613
  else:
1614
  bvp = np.array([], dtype=float)
1615
 
 
1616
  fs_hint = 0.0
1617
  if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
1618
  fs_hint = float(data["fs"])
 
1628
  except Exception:
1629
  return np.array([]), 0.0, 0.0
1630
 
1631
+ # ================= CSV =================
1632
  if ext == ".csv":
1633
  try:
1634
  df = pd.read_csv(p)
 
1635
  cols = {str(c).strip().lower(): c for c in df.columns}
1636
 
1637
  def _first_match(names):
 
1657
  except Exception:
1658
  return np.array([]), 0.0, 0.0
1659
 
1660
+ # ================= MAT =================
1661
  if ext == ".mat":
1662
  try:
1663
  md = loadmat(str(p))
 
1664
  arr = None
1665
  for key in ("ppg", "bvp", "signal", "wave"):
1666
  if key in md and isinstance(md[key], np.ndarray):
1667
  arr = md[key]
1668
  break
1669
  if arr is None:
 
1670
  for v in md.values():
1671
  if isinstance(v, np.ndarray) and v.ndim == 1:
1672
  arr = v
 
1700
  try:
1701
  bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
1702
  fs_hint, hr = 0.0, 0.0
1703
+
1704
  sidecar = p.with_suffix(".json")
1705
  if sidecar.exists():
1706
  try:
 
1712
  hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
1713
  except Exception:
1714
  pass
1715
+
1716
  if hr == 0.0 and bvp.size:
1717
  hr = _hr_from_bvp(bvp, fs_hint)
1718
  return bvp, hr, fs_hint
 
1722
  # Fallback (unsupported extension)
1723
  return np.array([]), 0.0, 0.0
1724
 
1725
+
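# Hedged sketch of the UBFC-style ground_truth.txt handled above: three
# whitespace-separated lines (PPG wave, per-sample HR, timestamps in seconds).
# The numbers below are dummy values for illustration only.
def _demo_write_ubfc_gt(path: Path) -> None:
    ppg = "0.01 0.05 0.12 0.08 -0.02"
    hr = "72.0 72.5 73.0 72.8 72.4"
    ts = "0.000 0.033 0.066 0.100 0.133"
    path.write_text("\n".join([ppg, hr, ts]), encoding="utf-8")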
1726
  def scan_models() -> List[str]:
1727
  if not MODEL_DIR.exists():
1728
  return []
 
1734
 
1735
  return models
1736
 
1737
+
1738
  _GLOBAL_CONTROLS: Dict[str, Dict] = {}
1739
 
1740
+
1741
  def ensure_controls(control_id: str) -> Tuple[str, Dict]:
1742
  # Use a stable default so Pause/Resume/Stop work for the current run
1743
  if not control_id:
 
1749
  }
1750
  return control_id, _GLOBAL_CONTROLS[control_id]
1751
 
1752
+
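# Hedged sketch of driving the per-run controls (illustration only). The
# 'stop' key is used by the processing loop below; the 'pause' key is assumed
# from the Pause/Resume handlers and may differ in the actual dict.
def _demo_controls(control_id: str = "") -> bool:
    cid, controls = ensure_controls(control_id)
    controls["pause"].set()      # processing loop would idle
    controls["pause"].clear()    # resume
    controls["stop"].set()       # ask the loop to exit
    return controls["stop"].is_set()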
1753
  def process_video_file(
1754
  video_path: str,
1755
  gt_file: Optional[str],
 
1762
  """
1763
  Enhanced video processing with Grad-CAM attention visualization,
1764
  plus per-frame illumination and motion logging.
1765
+ Returns 14 outputs per yield, matching Gradio UI:
1766
+ status, pred_hr, gt_hr, rmssd,
1767
+ frame_path, attention_path, signal_path, raw_path, post_path, csv_path,
1768
+ global_brightness, global_motion, roi_brightness, roi_motion
1769
  """
1770
  global _HR_SMOOTH
1771
+ global FRAME_METRICS # global frame metrics buffer
1772
  _HR_SMOOTH = None
1773
  FRAME_METRICS = [] # reset per run
1774
 
 
1782
  control_id, controls = ensure_controls(control_id)
1783
  controls['stop'].clear()
1784
 
1785
+ # Helper for consistent error yields (14 outputs)
1786
+ def _error_status(msg: str):
1787
+ return (
1788
+ msg,
1789
+ None, None, None, # HR, GT HR, RMSSD
1790
+ None, None, None, # frame, attention, signal
1791
+ None, None, # raw, post
1792
+ None, # csv
1793
+ None, None, None, None # brightness & motion
1794
+ )
1795
+
1796
  if not model_name:
1797
+ yield _error_status("ERROR: No model selected")
1798
  return
1799
 
1800
  if isinstance(model_name, int):
 
1802
 
1803
  model_path = MODEL_DIR / model_name
1804
  if not model_path.exists():
1805
+ yield _error_status("ERROR: Model not found")
1806
  return
1807
 
1808
  try:
1809
  model, attention_viz = load_physmamba_model(model_path, DEVICE)
1810
  except Exception as e:
1811
+ yield _error_status(f"ERROR loading model: {str(e)}")
1812
  return
1813
 
1814
  gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
1815
 
1816
  if not video_path or not os.path.exists(video_path):
1817
+ yield _error_status("ERROR: Video not found")
1818
  return
1819
 
1820
  cap = cv2.VideoCapture(video_path)
1821
  if not cap.isOpened():
1822
+ yield _error_status("ERROR: Cannot open video")
1823
  return
1824
 
1825
  fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
 
1836
  last_rmssd = 0.0
1837
  last_attention = None
1838
 
1839
+ # previous grayscale frames for motion
1840
  prev_gray = None
1841
  prev_roi_gray = None
1842
 
 
1850
  raw_path = tmpdir / "raw_signal.png"
1851
  post_path = tmpdir / "post_signal.png"
1852
 
1853
+ # Initial status yield (14 outputs)
1854
+ yield (
1855
+ "Starting… reading video frames",
1856
+ None,
1857
+ f"{gt_hr:.1f}" if gt_hr > 0 else "--",
1858
+ None,
1859
+ None, None, None, None, None, None,
1860
+ None, None, None, None
1861
+ )
1862
 
1863
  while True:
1864
  if controls['stop'].is_set():
 
1875
 
1876
  frame_idx += 1
1877
 
1878
+ # default per-frame metrics for this iteration
1879
  global_brightness = None
1880
  global_motion = None
1881
  roi_brightness = None
1882
  roi_motion = None
1883
 
1884
+ # global illumination & motion (full frame)
1885
  try:
1886
  frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
1887
  except Exception:
 
1907
  roi = crop_roi(face, roi_type, frame)
1908
 
1909
  if roi is not None and roi.size > 0:
1910
+ # ROI-level brightness & motion
1911
  try:
1912
  roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
1913
  except Exception:
 
1944
  try:
1945
  raw = forward_bvp(model, clip_t)
1946
  if isinstance(raw, np.ndarray):
1947
+ raw = np.nan_to_num(
1948
+ raw, nan=0.0, posinf=0.0, neginf=0.0
1949
+ ).astype(np.float32, copy=False)
1950
  bvp_out = raw if raw.size > 0 else None
1951
  else:
1952
  bvp_out = None
 
1954
  print(f"[infer] forward_bvp error: {e}")
1955
  bvp_out = None
1956
 
1957
+ # Generate attention with Grad-CAM
1958
  try:
1959
  last_attention = extract_attention_map(model, clip_t, attention_viz)
1960
  except Exception as e:
 
1962
  last_attention = None
1963
 
1964
  if bvp_out is None or bvp_out.size == 0:
1965
+ gbuf = np.nan_to_num(
1966
+ np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0
1967
+ )
1968
  fb = _fallback_bvp_from_means(gbuf, fs=fps)
1969
  if isinstance(fb, np.ndarray) and fb.size > 0:
1970
  bvp_out = fb
 
1978
  bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
1979
 
1980
  if len(bvp_stream) >= int(5 * fps):
1981
+ seg = np.asarray(
1982
+ bvp_stream[-int(10 * fps):], dtype=np.float32
1983
+ )
1984
  _, last_bpm = postprocess_bvp(seg, fs=fps)
1985
  last_rmssd = compute_rmssd(seg, fs=fps)
1986
 
 
1992
  last_infer = frame_idx
1993
 
1994
  else:
1995
+ cv2.putText(
1996
+ vis_frame,
1997
+ "No face detected",
1998
+ (20, 40),
1999
+ cv2.FONT_HERSHEY_SIMPLEX,
2000
+ 0.7,
2001
+ (30, 200, 255),
2002
+ 2,
2003
+ )
2004
 
2005
+ # Log per-frame metrics into global FRAME_METRICS
2006
  try:
2007
  time_s = frame_idx / float(fps) if fps > 0 else float(frame_idx)
2008
  except Exception:
 
2017
  "roi_motion": float(roi_motion) if roi_motion is not None else None,
2018
  })
2019
 
2020
+ # Pretty strings for UI
2021
+ gb_str = f"{global_brightness:.2f}" if global_brightness is not None else None
2022
+ gm_str = f"{global_motion:.2f}" if global_motion is not None else None
2023
+ rb_str = f"{roi_brightness:.2f}" if roi_brightness is not None else None
2024
+ rm_str = f"{roi_motion:.2f}" if roi_motion is not None else None
2025
+
2026
  if last_bpm > 0:
2027
  color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
2028
  cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
2029
+ cv2.putText(
2030
+ vis_frame,
2031
+ f"HR: {last_bpm:.1f} BPM",
2032
+ (20, 48),
2033
+ cv2.FONT_HERSHEY_SIMPLEX,
2034
+ 0.9,
2035
+ color,
2036
+ 2,
2037
+ )
2038
 
2039
  vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
2040
 
 
2105
  elapsed = now - start_time
2106
  status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
2107
 
2108
+ # Periodic UI update: 14 outputs
2109
  yield (
2110
  status,
2111
  f"{last_bpm:.1f}" if last_bpm > 0 else None,
 
2116
  str(signal_path) if signal_path.exists() else None,
2117
  str(raw_path) if raw_path.exists() else None,
2118
  str(post_path) if post_path.exists() else None,
2119
+ None, # CSV placeholder (filled at end)
2120
+ gb_str, # global brightness
2121
+ gm_str, # global motion
2122
+ rb_str, # ROI brightness
2123
+ rm_str # ROI motion
2124
  )
2125
 
2126
  next_display = now + (1.0 / DISPLAY_FPS)
 
2160
  except Exception:
2161
  pass
2162
 
2163
+ # Save per-frame illumination & motion metrics
2164
  frame_metrics_path = None
2165
  if FRAME_METRICS:
2166
  try:
 
2171
  print(f"[metrics] Failed to save frame metrics CSV: {e}")
2172
  frame_metrics_path = None
2173
 
2174
+ # Decide which CSV to expose: metrics preferred, else BVP
2175
+ final_csv = frame_metrics_path or csv_path
2176
+
2177
  elapsed = time.time() - start_time
2178
  final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
2179
 
2180
+ # Final yield: 14 outputs
2181
  yield (
2182
  final_status,
2183
  f"{last_bpm:.1f}" if last_bpm > 0 else None,
 
2188
  str(signal_path) if signal_path.exists() else None,
2189
  str(raw_path) if raw_path.exists() else None,
2190
  str(post_path) if post_path.exists() else None,
2191
+ str(final_csv) if final_csv else None,
2192
+ None, # final global brightness (leave None or reuse gb_str)
2193
+ None, # final global motion
2194
+ None, # final ROI brightness
2195
+ None # final ROI motion
2196
  )
2197
 
2198
+
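# Hedged sketch: process_video_file is a generator of 14-tuples
# (status, pred HR, GT HR, RMSSD, six artifact paths, four brightness/motion
# strings). Gradio streams these directly; a headless caller could drain it.
# Argument values below are illustrative only.
def _demo_consume(video_path: str, model_name: str):
    last = None
    for update in process_video_file(video_path, None, model_name, 30, 30.0, "auto", ""):
        last = update
    return last[0] if last else None   # final status string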
2199
  def process_stream(
2200
  input_source: str,
2201
  video_path: Optional[str],
 
2209
  if input_source == "Live Webcam":
2210
  yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
2211
  else:
2212
+ yield from process_video_file(
2213
+ video_path, gt_file, model_name, fps_input,
2214
+ max_seconds, roi_type, control_id
2215
+ )
2216
 
2217
  def pause_processing(control_id: str) -> str:
2218
  _, controls = ensure_controls(control_id)
 
2231
  return "Stopped"
2232
 
2233
  def reset_ui():
2234
+ # 14 values in total, matching all outputs from run_processing / process_stream
2235
+ return (
2236
+ "Ready", # status_text
2237
+ None, # hr_output
2238
+ None, # gt_hr_output
2239
+ None, # rmssd_output
2240
+ None, # frame_output
2241
+ None, # attention_output
2242
+ None, # signal_output
2243
+ None, # raw_signal_output
2244
+ None, # post_signal_output
2245
+ None, # csv_output
2246
+ None, # global_brightness_output
2247
+ None, # global_motion_output
2248
+ None, # roi_brightness_output
2249
+ None # roi_motion_output
2250
+ )
2251
 
2252
  def handle_folder_upload(files):
2253
  if not files:
 
2314
  with gr.Column():
2315
  video_upload = gr.Video(label="Upload Video", sources=["upload"])
2316
  with gr.Column():
2317
+ gt_upload = gr.File(
2318
+ label="Ground Truth (optional)",
2319
+ file_types=[".txt", ".csv", ".json"]
2320
+ )
2321
 
2322
  with gr.Row(visible=False) as folder_inputs:
2323
  with gr.Column():
 
2382
  )
2383
 
2384
  with gr.Row():
2385
+ roi_dropdown = gr.Dropdown(
2386
+ choices=["auto", "forehead", "cheeks", "face"],
2387
+ value="auto",
2388
+ label="ROI"
2389
+ )
2390
 
2391
  control_state = gr.State(value="")
2392
  placeholder_state = gr.State(value=None)
 
2403
  hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
2404
  gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
2405
  rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
2406
+
2407
+ # NEW: illumination & motion front-end outputs
2408
+ with gr.Row():
2409
+ global_brightness_output = gr.Textbox(
2410
+ label="Global Brightness", interactive=False
2411
+ )
2412
+ global_motion_output = gr.Textbox(
2413
+ label="Global Motion", interactive=False
2414
+ )
2415
+ roi_brightness_output = gr.Textbox(
2416
+ label="ROI Brightness", interactive=False
2417
+ )
2418
+ roi_motion_output = gr.Textbox(
2419
+ label="ROI Motion", interactive=False
2420
+ )
2421
 
2422
  with gr.Row():
2423
  with gr.Column():
 
2456
  ).then(
2457
  reset_ui,
2458
  inputs=None,
2459
+ outputs=[
2460
+ status_text,
2461
+ hr_output,
2462
+ gt_hr_output,
2463
+ rmssd_output,
2464
+ frame_output,
2465
+ attention_output,
2466
+ signal_output,
2467
+ raw_signal_output,
2468
+ post_signal_output,
2469
+ csv_output,
2470
+ global_brightness_output,
2471
+ global_motion_output,
2472
+ roi_brightness_output,
2473
+ roi_motion_output,
2474
+ ]
2475
  )
2476
 
2477
+ def run_processing(
2478
+ input_source,
2479
+ video_upload,
2480
+ gt_upload,
2481
+ folder_video,
2482
+ folder_gt,
2483
+ model_name,
2484
+ fps,
2485
+ max_sec,
2486
+ roi,
2487
+ ctrl_id
2488
+ ):
2489
+ """Wrapper that resolves paths and streams from process_stream."""
2490
 
2491
  if isinstance(model_name, int):
2492
  model_name = str(model_name)
2493
 
2494
  if not model_name:
2495
+ # must yield 14 outputs to match the UI wiring
2496
+ yield (
2497
+ "ERROR: No model selected",
2498
+ None, None, None, # HR, GT HR, RMSSD
2499
+ None, None, None, # frame, attention, signal
2500
+ None, None, # raw, post
2501
+ None, # CSV
2502
+ None, None, None, None # brightness & motion fields
2503
+ )
2504
  return
2505
 
2506
  if input_source == "Video File":
 
2512
  else: # Live Webcam
2513
  video_path, gt_file = None, None
2514
 
2515
+ # This yields 14-tuples from process_video_file / process_live_webcam
2516
  yield from process_stream(
2517
  input_source, video_path, gt_file,
2518
  model_name, fps, max_sec, roi, ctrl_id
2519
  )
2520
 
 
2521
  run_btn.click(
2522
  fn=run_processing,
2523
  inputs=[
 
2542
  signal_output,
2543
  raw_signal_output,
2544
  post_signal_output,
2545
+ csv_output,
2546
+ global_brightness_output,
2547
+ global_motion_output,
2548
+ roi_brightness_output,
2549
+ roi_motion_output,
2550
  ]
2551
  )
2552