swetchareddytukkani committed on
Commit 1c6711c · 1 Parent(s): 47fcc3f

Initial commit with PhysMamba rPPG application

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +16 -0
  2. app.py +2559 -0
  3. final_model_release/PURE_PhysMamba_DiffNormalized.pth +3 -0
  4. final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth +3 -0
  5. mamba_ssm/__init__.py +20 -0
  6. mamba_ssm/models/__init__.py +0 -0
  7. mamba_ssm/models/mixer_seq_simple.py +233 -0
  8. mamba_ssm/modules/__init__.py +0 -0
  9. mamba_ssm/modules/mamba_simple.py +418 -0
  10. mamba_ssm/ops/__init__.py +7 -0
  11. mamba_ssm/ops/selective_scan_interface.py +17 -0
  12. mamba_ssm/ops/triton/__init__.py +0 -0
  13. mamba_ssm/ops/triton/layernorm.py +636 -0
  14. mamba_ssm/ops/triton/selective_state_update.py +192 -0
  15. mamba_ssm/utils/__init__.py +0 -0
  16. mamba_ssm/utils/generation.py +377 -0
  17. mamba_ssm/utils/hf.py +23 -0
  18. neural_methods/__init__.py +0 -0
  19. neural_methods/loss/NegPearsonLoss.py +23 -0
  20. neural_methods/loss/PhysFormerLossComputer.py +120 -0
  21. neural_methods/loss/PhysNetNegPearsonLoss.py +43 -0
  22. neural_methods/loss/RythmFormerLossComputer.py +167 -0
  23. neural_methods/loss/__init__.py +0 -0
  24. neural_methods/model/BigSmall.py +177 -0
  25. neural_methods/model/DeepPhys.py +125 -0
  26. neural_methods/model/EfficientPhys.py +128 -0
  27. neural_methods/model/FactorizePhys/FSAM.py +530 -0
  28. neural_methods/model/FactorizePhys/FactorizePhys.py +251 -0
  29. neural_methods/model/FactorizePhys/FactorizePhysBig.py +251 -0
  30. neural_methods/model/FactorizePhys/__init__.py +0 -0
  31. neural_methods/model/FactorizePhys/test_FactorizePhys.py +286 -0
  32. neural_methods/model/FactorizePhys/test_FactorizePhysBig.py +292 -0
  33. neural_methods/model/PhysFormer.py +313 -0
  34. neural_methods/model/PhysMamba.py +246 -0
  35. neural_methods/model/PhysNet.py +124 -0
  36. neural_methods/model/RhythmFormer.py +418 -0
  37. neural_methods/model/TS_CAN.py +269 -0
  38. neural_methods/model/__init__.py +0 -0
  39. neural_methods/model/iBVPNet.py +194 -0
  40. neural_methods/trainer/BaseTrainer.py +108 -0
  41. neural_methods/trainer/BigSmallTrainer.py +484 -0
  42. neural_methods/trainer/BigSmallTrainer.py.backup +484 -0
  43. neural_methods/trainer/DeepPhysTrainer.py +209 -0
  44. neural_methods/trainer/DeepPhysTrainer.py.backup +209 -0
  45. neural_methods/trainer/EfficientPhysTrainer.py +228 -0
  46. neural_methods/trainer/EfficientPhysTrainer.py.backup +228 -0
  47. neural_methods/trainer/FactorizePhysTrainer.py +312 -0
  48. neural_methods/trainer/FactorizePhysTrainer.py.backup +312 -0
  49. neural_methods/trainer/PhysFormerTrainer.py +273 -0
  50. neural_methods/trainer/PhysFormerTrainer.py.backup +273 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ *.so
+ *.egg
+ *.egg-info/
+ dist/
+ build/
+ .pyenv/
+ .venv/
+ logs/
+ analysis/
+ *.log
+ .DS_Store
app.py ADDED
@@ -0,0 +1,2559 @@
1
+ import os
2
+ os.environ.setdefault("OPENCV_AVFOUNDATION_SKIP_AUTH", "1")
3
+ os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
4
+
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import re
9
+ import io
10
+ import json
11
+ import tempfile
12
+ import time
13
+ import threading
14
+ import uuid
15
+ import importlib.util
16
+ from pathlib import Path
17
+ from collections import deque
18
+ from typing import Optional, Tuple, Dict, List
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+ import cv2
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+
28
+ from scipy import signal
29
+ from scipy.signal import find_peaks, welch, get_window
30
+
31
+ # .mat support (SCAMPS)
32
+ try:
33
+ from scipy.io import loadmat
34
+ MAT_SUPPORT = True
35
+ except Exception:
36
+ MAT_SUPPORT = False
37
+
38
+ # Matplotlib headless backend
39
+ import matplotlib
40
+ matplotlib.use("Agg")
41
+ import matplotlib.pyplot as plt
42
+ from matplotlib.gridspec import GridSpec
43
+
44
+ import gradio as gr
45
+
46
+ class PhysMambaattention_viz:
47
+ """Simplified Grad-CAM for PhysMamba."""
48
+
49
+ def __init__(self, model: nn.Module, device: torch.device):
50
+ self.model = model
51
+ self.device = device
52
+ self.activations = None
53
+ self.gradients = None
54
+ self.hook_handles = []
55
+
56
+ def _get_target_layer(self):
57
+ """Auto-detect the best layer for visualization, excluding Mamba/SSM layers."""
58
+
59
+ # Strategy 1: Look for temporal difference convolution
60
+ for name, module in reversed(list(self.model.named_modules())):
61
+ if ('tdc' in name.lower() or 'temporal_diff' in name.lower()) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
62
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
63
+ print(f"[attention_viz] Selected TDC layer: {name} ({type(module).__name__})")
64
+ return name, module
65
+
66
+ # Strategy 2: Look for 3D convolutions (temporal-spatial)
67
+ for name, module in reversed(list(self.model.named_modules())):
68
+ if isinstance(module, nn.Conv3d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
69
+ print(f"[attention_viz] Selected Conv3d layer: {name}")
70
+ return name, module
71
+
72
+ # Strategy 3: Look for 2D convolutions (spatial)
73
+ for name, module in reversed(list(self.model.named_modules())):
74
+ if isinstance(module, nn.Conv2d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
75
+ print(f"[attention_viz] Selected Conv2d layer: {name}")
76
+ return name, module
77
+
78
+ # Strategy 4: Look for any convolution (last resort)
79
+ for name, module in reversed(list(self.model.named_modules())):
80
+ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
81
+ print(f"[attention_viz] Selected Conv layer (fallback): {name} ({type(module).__name__})")
82
+ return name, module
83
+
84
+ # Strategy 5: Print all available layers and pick the last non-Mamba one
85
+ print("\n[attention_viz] Available layers:")
86
+ suitable_layers = []
87
+ for name, module in self.model.named_modules():
88
+ layer_type = type(module).__name__
89
+ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
90
+ is_mamba = 'mamba' in name.lower() or 'ssm' in name.lower()
91
+ print(f" - {name}: {layer_type} {'(MAMBA - SKIP)' if is_mamba else ''}")
92
+ if not is_mamba:
93
+ suitable_layers.append((name, module))
94
+
95
+ if suitable_layers:
96
+ name, module = suitable_layers[-1]
97
+ print(f"\n[attention_viz] Selected last suitable layer: {name} ({type(module).__name__})")
98
+ return name, module
99
+
100
+ raise ValueError("No suitable layer found for Grad-CAM (all layers are Mamba/SSM)")
101
+
102
+ def _forward_hook(self, module, input, output):
103
+ self.activations = output.detach()
104
+
105
+ def _backward_hook(self, module, grad_input, grad_output):
106
+ self.gradients = grad_output[0].detach()
107
+
108
+ def register_hooks(self, target_layer_name: Optional[str] = None):
109
+ """Register hooks on target layer"""
110
+ self._remove_hooks()
111
+
112
+ if target_layer_name and target_layer_name.strip():
113
+ # Manual selection
114
+ try:
115
+ target_module = dict(self.model.named_modules())[target_layer_name.strip()]
116
+ print(f"Using manually specified layer: {target_layer_name}")
117
+ except KeyError:
118
+ print(f"⚠ Layer '{target_layer_name}' not found, falling back to auto-detection")
119
+ target_layer_name, target_module = self._get_target_layer()
120
+ else:
121
+ # Auto-detection
122
+ target_layer_name, target_module = self._get_target_layer()
123
+
124
+ fwd_handle = target_module.register_forward_hook(self._forward_hook)
125
+ bwd_handle = target_module.register_full_backward_hook(self._backward_hook)
126
+
127
+ self.hook_handles = [fwd_handle, bwd_handle]
128
+
129
+ def _remove_hooks(self):
130
+ for handle in self.hook_handles:
131
+ handle.remove()
132
+ self.hook_handles = []
133
+ self.activations = None
134
+ self.gradients = None
135
+
136
+ def generate(self, input_tensor: torch.Tensor) -> np.ndarray:
137
+ """Generate Grad-CAM heatmap with improved gradient handling."""
138
+ # Set model to eval but keep gradient computation
139
+ self.model.eval()
140
+
141
+ # Ensure tensor requires grad
142
+ input_tensor = input_tensor.requires_grad_(True)
143
+
144
+ # Forward pass
145
+ output = self.model(input_tensor)
146
+
147
+ # Handle different output types
148
+ if isinstance(output, dict):
149
+ # Try common keys
150
+ for key in ('pred', 'output', 'bvp', 'logits', 'out'):
151
+ if key in output and isinstance(output[key], torch.Tensor):
152
+ output = output[key]
153
+ break
154
+
155
+ # Get scalar for backward
156
+ if output.numel() == 1:
157
+ target = output
158
+ else:
159
+ # Use mean of output
160
+ target = output.mean()
161
+
162
+ print(f"[attention_viz] Output shape: {output.shape}, Target: {target.item():.4f}")
163
+
164
+ # Backward pass
165
+ self.model.zero_grad()
166
+ target.backward(retain_graph=False)
167
+
168
+ # Check if we got gradients
169
+ if self.activations is None:
170
+ print("⚠ No activations captured!")
171
+ return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
172
+
173
+ if self.gradients is None:
174
+ print("⚠ No gradients captured!")
175
+ return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
176
+
177
+ print(f"[attention_viz] Activations shape: {self.activations.shape}")
178
+ print(f"[attention_viz] Gradients shape: {self.gradients.shape}")
179
+
180
+ activations = self.activations
181
+ gradients = self.gradients
182
+
183
+ # Compute weights (global average pooling of gradients)
184
+ if gradients.dim() == 5: # [B, C, T, H, W]
185
+ weights = gradients.mean(dim=[2, 3, 4], keepdim=True)
186
+ elif gradients.dim() == 4: # [B, C, H, W]
187
+ weights = gradients.mean(dim=[2, 3], keepdim=True)
188
+ elif gradients.dim() == 3: # [B, C, T]
189
+ weights = gradients.mean(dim=2, keepdim=True).unsqueeze(-1).unsqueeze(-1)
190
+ else:
191
+ print(f"⚠ Unexpected gradient dimensions: {gradients.dim()}")
192
+ return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))
193
+
194
+ # Weighted combination
195
+ cam = (weights * activations).sum(dim=1, keepdim=True)
196
+
197
+ # If 5D, average over time
198
+ if cam.dim() == 5:
199
+ cam = cam.mean(dim=2)
200
+
201
+ # ReLU
202
+ cam = F.relu(cam)
203
+
204
+ # Convert to numpy
205
+ cam = cam.squeeze().cpu().detach().numpy()
206
+
207
+ # Normalize
208
+ if cam.max() > cam.min():
209
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
210
+
211
+ print(f"[attention_viz] Final heatmap shape: {cam.shape}, range: [{cam.min():.3f}, {cam.max():.3f}]")
212
+
213
+ return cam
214
+
215
+ def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
216
+ """Overlay heatmap on frame."""
217
+ h, w = frame.shape[:2]
218
+ heatmap_resized = cv2.resize(heatmap, (w, h))
219
+ heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
220
+ heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
221
+ overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
222
+ return overlay
223
+
224
+ def cleanup(self):
225
+ self._remove_hooks()
226
+
227
+
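The generate() method above follows the standard Grad-CAM recipe: gradients are pooled into per-channel weights, the weighted activations are summed over channels, passed through ReLU, and min-max normalized. A minimal self-contained sketch of that weighting step on dummy 5D tensors (shapes are illustrative only, not the model's actual feature-map sizes):

    import torch
    import torch.nn.functional as F

    # Dummy activations/gradients shaped like a Conv3d feature map: [B, C, T, H, W]
    acts = torch.randn(1, 8, 16, 32, 32)
    grads = torch.randn(1, 8, 16, 32, 32)

    # Per-channel weights = global average of the gradients over T, H, W
    weights = grads.mean(dim=[2, 3, 4], keepdim=True)            # [1, 8, 1, 1, 1]

    # Weighted channel sum, averaged over time, then ReLU
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True).mean(dim=2)).squeeze()

    # Min-max normalize to [0, 1] for overlaying as a heatmap
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    print(cam.shape)                                              # torch.Size([32, 32])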
228
+ def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
229
+ """Apply DiffNormalized preprocessing from PhysMamba paper."""
230
+ if len(frames) < 2:
231
+ return np.zeros((len(frames), *frames[0].shape), dtype=np.float32)
232
+
233
+ diff_frames = []
234
+
235
+ for i in range(len(frames)):
236
+ if i == 0:
237
+ diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
238
+ else:
239
+ curr = frames[i].astype(np.float32)
240
+ prev = frames[i-1].astype(np.float32)
241
+ denominator = curr + prev + 1e-8
242
+ diff = (curr - prev) / denominator
243
+ diff_frames.append(diff)
244
+
245
+ diff_array = np.stack(diff_frames)
246
+ std = diff_array.std()
247
+ if std > 0:
248
+ diff_array = diff_array / std
249
+
250
+ return diff_array
251
+
252
+
253
+ def preprocess_for_physmamba(frames: List[np.ndarray],
254
+ target_frames: int = 128,
255
+ target_size: int = 128) -> torch.Tensor:
256
+ """Complete preprocessing pipeline for PhysMamba model."""
257
+ if len(frames) < target_frames:
258
+ frames = frames + [frames[-1]] * (target_frames - len(frames))
259
+ elif len(frames) > target_frames:
260
+ indices = np.linspace(0, len(frames)-1, target_frames).astype(int)
261
+ frames = [frames[i] for i in indices]
262
+
263
+ frames_rgb = [f[..., ::-1].copy() for f in frames]
264
+ frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
265
+ frames_diff = apply_diff_normalized(frames_resized)
266
+ frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2))
267
+ frames_batched = np.expand_dims(frames_transposed, axis=0)
268
+ tensor = torch.from_numpy(frames_batched.astype(np.float32))
269
+
270
+ return tensor
271
+
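apply_diff_normalized() implements the DiffNormalized input used by PhysMamba: each frame becomes (curr - prev) / (curr + prev), and the whole clip is divided by its standard deviation. A short sketch of calling the full preprocessing pipeline on synthetic frames, assuming these helpers are in scope (e.g., imported from this app.py):

    import numpy as np

    # Synthetic stand-in for a list of BGR uint8 face crops at arbitrary resolution
    frames = [np.random.randint(0, 256, (72, 96, 3), dtype=np.uint8) for _ in range(90)]

    # Resamples to 128 frames, resizes to 128x128, applies DiffNormalized,
    # and returns a float tensor shaped [1, 3, 128, 128, 128] (B, C, T, H, W)
    clip = preprocess_for_physmamba(frames, target_frames=128, target_size=128)
    print(clip.shape)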
272
+ HERE = Path(__file__).parent
273
+ MODEL_DIR = HERE / "final_model_release"
274
+ LOG_DIR = HERE / "logs"
275
+ ANALYSIS_DIR = HERE / "analysis"
276
+ for d in [MODEL_DIR, LOG_DIR, ANALYSIS_DIR]:
277
+ d.mkdir(exist_ok=True)
278
+
279
+ DEVICE = (
280
+ torch.device("cuda") if torch.cuda.is_available()
281
+ else torch.device("mps") if torch.backends.mps.is_available()
282
+ else torch.device("cpu")
283
+ )
284
+
285
+ FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
286
+
287
+ DEFAULT_SIZE = 128 # input H=W to model
288
+ DEFAULT_T = 128 # clip length
289
+ DEFAULT_STRIDE = 8 # inference hop
290
+ DISPLAY_FPS = 10
291
+
292
+ VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']
293
+
294
+ _HR_SMOOTH = None
295
+ REST_HR_TARGET = 72.0
296
+ REST_HR_RANGE = (55.0, 95.0)
297
+ MAX_JUMP_BPM = 8.0
298
+
299
+ # Recognized GT files for subject folders
300
+ GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
301
+ GT_EXTS = {".txt", ".csv", ".json"}
302
+
303
+ def _as_path(maybe) -> Optional[str]:
304
+ """Return a filesystem path from Gradio values (str, dict, tempfile objects)."""
305
+ if maybe is None:
306
+ return None
307
+ if isinstance(maybe, str):
308
+ return maybe
309
+ if isinstance(maybe, dict):
310
+ return maybe.get("name") or maybe.get("path")
311
+ name = getattr(maybe, "name", None) # tempfile-like object
312
+ if isinstance(name, str):
313
+ return name
314
+ try:
315
+ return str(maybe)
316
+ except Exception:
317
+ return None
318
+
319
+ def _import_from_file(py_path: Path):
320
+ spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
321
+ if not spec or not spec.loader:
322
+ raise ImportError(f"Cannot import module from {py_path}")
323
+ mod = importlib.util.module_from_spec(spec)
324
+ spec.loader.exec_module(mod)
325
+ return mod
326
+
327
+ def _looks_like_video(p: Path) -> bool:
328
+ if p.suffix.lower() == ".mat":
329
+ return True
330
+ return p.suffix.lower() in VIDEO_EXTENSIONS
331
+
332
+ class SimpleActivationAttention:
333
+ """Lightweight attention visualization without gradients."""
334
+
335
+ def __init__(self, model: nn.Module, device: torch.device):
336
+ self.model = model
337
+ self.device = device
338
+ self.activations = None
339
+ self.hook_handle = None
340
+
341
+ def _activation_hook(self, module, input, output):
342
+ """Capture activations during forward pass."""
343
+ self.activations = output.detach()
344
+
345
+ def register_hook(self):
346
+ """Register hook on a suitable layer."""
347
+ # Find the last convolutional layer before Mamba
348
+ target = None
349
+ target_name = None
350
+
351
+ for name, module in self.model.named_modules():
352
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
353
+ target = module
354
+ target_name = name
355
+
356
+ if target is None:
357
+ print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
358
+ return
359
+
360
+ self.hook_handle = target.register_forward_hook(self._activation_hook)
361
+ print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")
362
+
363
+ def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
364
+ """Generate attention map from activations (call after forward pass)."""
365
+ try:
366
+ if self.activations is None:
367
+ return None
368
+
369
+ # Process activations to create spatial attention
370
+ act = self.activations
371
+
372
+ # Handle different tensor shapes
373
+ if act.dim() == 5: # [B, C, T, H, W]
374
+ # Average over time and channels
375
+ attention = act.mean(dim=[1, 2]) # -> [B, H, W]
376
+ elif act.dim() == 4: # [B, C, H, W]
377
+ attention = act.mean(dim=1) # -> [B, H, W]
378
+ else:
379
+ print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
380
+ return None
381
+
382
+ # Convert to numpy
383
+ attention = attention.squeeze().cpu().numpy()
384
+
385
+ # Normalize to [0, 1]
386
+ if attention.max() > attention.min():
387
+ attention = (attention - attention.min()) / (attention.max() - attention.min())
388
+
389
+ return attention
390
+
391
+ except Exception as e:
392
+ print(f"⚠ [attention_viz] Generation failed: {e}")
393
+ return None
394
+
395
+ def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
396
+ """Overlay heatmap on frame."""
397
+ h, w = frame.shape[:2]
398
+ heatmap_resized = cv2.resize(heatmap, (w, h))
399
+ heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
400
+ heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
401
+ overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0)
402
+ return overlay
403
+
404
+ def cleanup(self):
405
+ if self.hook_handle is not None:
406
+ self.hook_handle.remove()
407
+
408
+ class VideoReader:
409
+ """
410
+ Unified frame reader:
411
+ • Regular videos via cv2.VideoCapture
412
+ • .mat 'videos' (e.g., SCAMPS): expects array like (T,H,W,3) or (H,W,T,3) or (T,H,W)
413
+ Returns frames as BGR uint8.
414
+ """
415
+ def __init__(self, path: str):
416
+ self.path = str(path)
417
+ self._cap = None
418
+ self._mat = None
419
+ self._idx = 0
420
+ self._len = 0
421
+ self._shape = None
422
+
423
+ if self.path.lower().endswith(".mat") and MAT_SUPPORT:
424
+ self._open_mat(self.path)
425
+ else:
426
+ self._open_cv(self.path)
427
+
428
+ def _open_cv(self, path: str):
429
+ cap = cv2.VideoCapture(path)
430
+ if not cap.isOpened():
431
+ raise RuntimeError("Cannot open video")
432
+ self._cap = cap
433
+ self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
434
+
435
+ def _open_mat(self, path: str):
436
+ try:
437
+ md = loadmat(path)
438
+ # Common keys in SCAMPS-like dumps
439
+ for key in ("video", "frames", "vid", "data"):
440
+ if key in md and isinstance(md[key], np.ndarray):
441
+ arr = md[key]
442
+ break
443
+ else:
444
+ arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
445
+ if arr is None:
446
+ raise RuntimeError("No ndarray found in .mat")
447
+
448
+ a = np.asarray(arr)
449
+ # Normalize to (T,H,W,3)
450
+ if a.ndim == 4:
451
+ if a.shape[-1] == 3:
452
+ if a.shape[0] < a.shape[2]: # (T,H,W,3) heuristic
453
+ v = a
454
+ else: # (H,W,T,3) -> (T,H,W,3)
455
+ v = np.transpose(a, (2, 0, 1, 3))
456
+ else:
457
+ v = a[..., :1] # take first channel
458
+ elif a.ndim == 3:
459
+ if a.shape[0] < a.shape[2]: # (T,H,W)
460
+ v = a
461
+ else: # (H,W,T) -> (T,H,W)
462
+ v = np.transpose(a, (2, 0, 1))
463
+ v = v[..., None]
464
+ else:
465
+ raise RuntimeError(f"Unsupported .mat video shape: {a.shape}")
466
+
467
+ v = np.ascontiguousarray(v)
468
+ if v.shape[-1] == 1:
469
+ v = np.repeat(v, 3, axis=-1)
470
+ v = v.astype(np.uint8)
471
+ self._mat = v
472
+ self._len = v.shape[0]
473
+ self._shape = v.shape[1:3]
474
+ except Exception as e:
475
+ raise RuntimeError(f"Failed to open .mat video: {e}")
476
+
477
+ def read(self):
478
+ """Return (ret, frame[BGR]) like cv2.VideoCapture.read()."""
479
+ if self._mat is not None:
480
+ if self._idx >= self._len:
481
+ return False, None
482
+ frame = self._mat[self._idx]
483
+ self._idx += 1
484
+ return True, frame
485
+ else:
486
+ return self._cap.read()
487
+
488
+ def fps(self, fallback: int = 30) -> int:
489
+ if self._mat is not None:
490
+ return fallback # .mat typically lacks FPS; caller can override
491
+ f = self._cap.get(cv2.CAP_PROP_FPS)
492
+ return int(f) if f and f > 0 else fallback
493
+
494
+ def length(self) -> int:
495
+ return self._len
496
+
497
+ def release(self):
498
+ if self._cap is not None:
499
+ self._cap.release()
500
+
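A brief usage sketch for VideoReader; the file path is hypothetical, and the same loop covers SCAMPS-style .mat inputs when SciPy is available:

    reader = VideoReader("subject1/vid.avi")    # hypothetical path
    fs = reader.fps(fallback=30)                # .mat inputs fall back to 30 FPS
    frames = []
    while True:
        ok, frame_bgr = reader.read()
        if not ok:
            break
        frames.append(frame_bgr)
    reader.release()
    print(f"read {len(frames)} frames at ~{fs} FPS")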
501
+ def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
502
+ x, y, w, h = face
503
+ # forehead
504
+ fh = frame[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)]
505
+ # cheeks
506
+ ck = frame[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)]
507
+ # full face
508
+ ff = frame[y:y + h, x:x + w]
509
+ return {"forehead": fh, "cheeks": ck, "face": ff}
510
+
511
+ def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
512
+ if patch is None or patch.size == 0:
513
+ return -1e9
514
+ g = patch[..., 1].astype(np.float32) / 255.0 # green channel
515
+ g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling
516
+ g = g - g.mean()
517
+ b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
518
+ try:
519
+ y = signal.filtfilt(b, a, g, method="gust")
520
+ except Exception:
521
+ y = g
522
+ return float((y ** 2).mean())
523
+
524
+ def pick_auto_roi(face: Tuple[int, int, int, int],
525
+ frame: np.ndarray,
526
+ attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
527
+ """Simple ROI selection."""
528
+ cands = roi_candidates(face, frame)
529
+ scores = {k: roi_quality_score(v) for k, v in cands.items()}
530
+
531
+ if attn is not None and attn.size:
532
+ H, W = frame.shape[:2]
533
+ try:
534
+ attn_resized = cv2.resize(attn, (W, H))
535
+ x, y, w, h = face
536
+ fh_attn = attn_resized[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)].mean() if attn_resized.size > 0 else 0.0
537
+ ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
538
+ ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0
539
+ scores['forehead'] += fh_attn * 0.2
540
+ scores['cheeks'] += ck_attn * 0.2
541
+ scores['face'] += ff_attn * 0.2
542
+ except Exception:
543
+ pass
544
+
545
+ best = max(scores, key=scores.get)
546
+ return cands[best], best
547
+
548
+ def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
549
+ """
550
+ Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
551
+ Heuristics:
552
+ - Subject folder: any directory containing at least one video-like file (.mat or common video).
553
+ - If multiple videos, pick the largest by size.
554
+ - GT file: prefer known names in GT_FILENAMES, else any .txt/.csv/.json not named readme.
555
+ """
556
+ pairs: List[Tuple[str, Optional[str]]] = []
557
+ if not root_dir.exists():
558
+ return pairs
559
+
560
+ def pick_pair(folder: Path) -> Optional[Tuple[str, Optional[str]]]:
561
+ vids = [p for p in folder.rglob("*") if p.is_file() and _looks_like_video(p)]
562
+ if not vids:
563
+ return None
564
+ vids.sort(key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
565
+ video = vids[0]
566
+
567
+ gt: Optional[Path] = None
568
+ for p in folder.rglob("*"):
569
+ if p.is_file() and p.name.lower() in GT_FILENAMES:
570
+ gt = p
571
+ break
572
+ if gt is None:
573
+ cands = [
574
+ p for p in folder.rglob("*")
575
+ if p.is_file() and p.suffix.lower() in GT_EXTS and "readme" not in p.name.lower()
576
+ ]
577
+ if cands:
578
+ gt = cands[0]
579
+ return str(video), (str(gt) if gt else None)
580
+
581
+ subs = [d for d in root_dir.iterdir() if d.is_dir()]
582
+ if subs:
583
+ for sub in subs:
584
+ pair = pick_pair(sub)
585
+ if pair:
586
+ pairs.append(pair)
587
+ else:
588
+ pair = pick_pair(root_dir) # the root itself might be a single-subject folder
589
+ if pair:
590
+ pairs.append(pair)
591
+
592
+ # Deduplicate
593
+ seen = set()
594
+ uniq: List[Tuple[str, Optional[str]]] = []
595
+ for v, g in pairs:
596
+ key = (v, g or "")
597
+ if key not in seen:
598
+ seen.add(key)
599
+ uniq.append((v, g))
600
+ return uniq
601
+
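discover_subjects() expects a dataset root with one folder per subject, each containing a video (the largest video wins if there are several) and, optionally, a ground-truth file such as ground_truth.txt. A sketch of the expected layout and call; the directory name is illustrative:

    from pathlib import Path

    # dataset_root/
    #   subject1/  vid.avi, ground_truth.txt
    #   subject2/  vid.avi, ground_truth.txt
    pairs = discover_subjects(Path("dataset_root"))
    for video_path, gt_path in pairs:
        print(video_path, "->", gt_path or "no GT")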
602
+ def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
603
+ import inspect
604
+
605
+ if model_file:
606
+ model_path = (repo_root / model_file).resolve()
607
+ if model_path.exists():
608
+ try:
609
+ mod = _import_from_file(model_path)
610
+ if hasattr(mod, model_class):
611
+ return getattr(mod, model_class)
612
+ except Exception:
613
+ pass
614
+
615
+ search_dirs = [
616
+ repo_root / "neural_methods" / "model",
617
+ repo_root / "neural_methods",
618
+ repo_root
619
+ ]
620
+
621
+ name_pattern = re.compile(r"mamba", re.IGNORECASE)
622
+
623
+ for base_dir in search_dirs:
624
+ if not base_dir.exists():
625
+ continue
626
+
627
+ for py_file in base_dir.rglob("*.py"):
628
+ if "__pycache__" in str(py_file) or "mamba_ssm" in str(py_file):
629
+ continue
630
+
631
+ try:
632
+ mod = _import_from_file(py_file)
633
+ for name, obj in inspect.getmembers(mod):
634
+ if callable(obj) and name_pattern.search(name) and "ssm" not in name.lower():
635
+ if inspect.isclass(obj) and issubclass(obj, nn.Module):
636
+ return obj
637
+ except Exception:
638
+ continue
639
+
640
+ raise ImportError("Could not find PhysMamba model class")
641
+
642
+ def load_physmamba_model(ckpt_path: Path, device: torch.device,
643
+ model_file: str = "", model_class: str = "PhysMamba"):
644
+
645
+ repo_root = Path(".").resolve()
646
+ Builder = find_physmamba_builder(repo_root, model_file, model_class)
647
+
648
+ import inspect
649
+ ctor_trials = [
650
+ {},
651
+ {"d_model": 96},
652
+ {"dim": 96},
653
+ {"d_model": 96, "frames": 128, "img_size": 128, "in_chans": 3},
654
+ {"frames": 128, "img_size": 128, "in_chans": 3},
655
+ {"frame_depth": 3},
656
+ ]
657
+
658
+ model = None
659
+ for kwargs in ctor_trials:
660
+ try:
661
+ candidate = Builder(**kwargs)
662
+ if isinstance(candidate, nn.Module):
663
+ model = candidate
664
+ break
665
+ except Exception:
666
+ continue
667
+
668
+ if model is None:
669
+ raise RuntimeError("Could not construct PhysMamba model")
670
+
671
+ try:
672
+ checkpoint = torch.load(str(ckpt_path), map_location="cpu")
673
+ state_dict = checkpoint.get("state_dict", checkpoint)
674
+
675
+ try:
676
+ model.load_state_dict(state_dict, strict=False)
677
+ except Exception:
678
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
679
+ model.load_state_dict(state_dict, strict=False)
680
+ except Exception:
681
+ pass
682
+
683
+ model.to(device).eval()
684
+
685
+ try:
686
+ with torch.no_grad():
687
+ _ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
688
+ except Exception:
689
+ pass
690
+
691
+ # Disable attention visualization since model forward pass is incompatible
692
+ attention_viz = None
693
+
694
+ return model, attention_viz
695
+
696
+ def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
697
+ """
698
+ Stable band-pass with edge-safety and parameter clipping.
699
+ """
700
+ x = np.asarray(x, dtype=float)
701
+ n = int(fs * 2)
702
+ if x.size < max(n, 8): # need at least ~2s
703
+ return x
704
+
705
+ nyq = 0.5 * fs
706
+ lo = max(low / nyq, 1e-6)
707
+ hi = min(high / nyq, 0.999_999)
708
+ if not (0.0 < lo < hi < 1.0):
709
+ return x
710
+
711
+ try:
712
+ b, a = signal.butter(order, [lo, hi], btype="band")
713
+ # padlen must be < len(x); reduce when short
714
+ padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
715
+ return signal.filtfilt(b, a, x, padlen=padlen)
716
+ except Exception:
717
+ return x
718
+
719
+ def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
720
+ """
721
+ HR (BPM) via Welch PSD peak in [lo, hi] Hz.
722
+ """
723
+ x = np.asarray(x, dtype=float)
724
+ if x.size < int(fs * 4.0): # need ~4s for a usable PSD
725
+ return 0.0
726
+ try:
727
+ # nperseg tuned for short windows while avoiding tiny segments
728
+ nper = int(min(max(64, fs * 2), min(512, x.size)))
729
+ f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")
730
+
731
+ mask = (f >= lo) & (f <= hi)
732
+ if not np.any(mask):
733
+ return 0.0
734
+ f_band = f[mask]
735
+ p_band = pxx[mask]
736
+
737
+ if p_band.size == 0 or np.all(~np.isfinite(p_band)):
738
+ return 0.0
739
+
740
+ fpk = float(f_band[np.argmax(p_band)])
741
+ bpm = fpk * 60.0
742
+ # clip to plausible range
743
+ return float(np.clip(bpm, 30.0, 220.0))
744
+ except Exception:
745
+ return 0.0
746
+
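hr_from_welch() converts the dominant Welch-PSD frequency in the 0.7-3.5 Hz band to beats per minute (BPM = f_peak * 60). A self-contained sanity check on a synthetic 72 BPM pulse, using SciPy directly to mirror the logic above:

    import numpy as np
    from scipy.signal import welch

    fs = 30                                    # sampling rate (Hz)
    t = np.arange(0, 40, 1 / fs)               # 40 s of signal
    x = np.sin(2 * np.pi * 1.2 * t)            # 1.2 Hz pulse -> 72 BPM

    f, pxx = welch(x, fs=fs, nperseg=600)      # 0.05 Hz frequency resolution
    band = (f >= 0.7) & (f <= 3.5)
    hr_bpm = f[band][np.argmax(pxx[band])] * 60.0
    print(round(hr_bpm, 1))                    # ~72.0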
747
+ def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
748
+ """
749
+ HRV RMSSD from peaks; robust to short/flat segments.
750
+ """
751
+ x = np.asarray(x, dtype=float)
752
+ if x.size < int(fs * 5.0):
753
+ return 0.0
754
+ try:
755
+ # peak distance ~ 0.5s minimum (avoid double counting)
756
+ peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
757
+ if len(peaks) < 3:
758
+ return 0.0
759
+ rr = np.diff(peaks) / fs * 1000.0 # ms
760
+ if rr.size < 2:
761
+ return 0.0
762
+ return float(np.sqrt(np.mean(np.diff(rr) ** 2)))
763
+ except Exception:
764
+ return 0.0
765
+
766
+ def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
767
+ """
768
+ Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
769
+ Signature unchanged to avoid breaking callers.
770
+ """
771
+ global _HR_SMOOTH
772
+
773
+ y = np.asarray(pred, dtype=float)
774
+ if y.size == 0:
775
+ return y, 0.0
776
+
777
+ # 1) band-limit
778
+ y_filt = bandpass_filter(y, fs=fs, low=0.7, high=3.5, order=4)
779
+
780
+ # 2) HR estimate
781
+ hr = hr_from_welch(y_filt, fs=fs, lo=0.7, hi=3.5)
782
+
783
+ # 3) gentle attraction to resting band (if way off)
784
+ if hr > 0:
785
+ lo, hi = REST_HR_RANGE
786
+ if hr < lo or hr > hi:
787
+ dist = abs(hr - REST_HR_TARGET)
788
+ # farther away -> stronger pull
789
+ alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
790
+ hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET
791
+
792
+ # 4) temporal smoothing to limit frame-to-frame jumps
793
+ if hr > 0:
794
+ if _HR_SMOOTH is None:
795
+ _HR_SMOOTH = hr
796
+ else:
797
+ step = float(np.clip(hr - _HR_SMOOTH, -MAX_JUMP_BPM, MAX_JUMP_BPM))
798
+ _HR_SMOOTH = _HR_SMOOTH + 0.6 * step
799
+ hr = float(_HR_SMOOTH)
800
+
801
+ return y_filt, float(hr)
802
+
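The smoothing step in postprocess_bvp() limits frame-to-frame jumps: the change relative to the previous smoothed HR is clipped to +/- MAX_JUMP_BPM (8 BPM) and only 60% of that clipped step is applied. A tiny worked example of the update rule:

    prev_hr, new_hr = 70.0, 95.0
    step = max(-8.0, min(8.0, new_hr - prev_hr))   # clip to +/- MAX_JUMP_BPM
    smoothed = prev_hr + 0.6 * step                # 70 + 0.6 * 8 = 74.8 BPM
    print(smoothed)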
803
+ def draw_face_and_roi(frame_bgr: np.ndarray,
804
+ face_bbox: Optional[Tuple[int, int, int, int]],
805
+ roi_bbox: Optional[Tuple[int, int, int, int]],
806
+ label: str = "ROI") -> np.ndarray:
807
+ """
808
+ Draw face (green) and ROI (cyan) rectangles on a copy of the frame.
809
+ """
810
+ vis = frame_bgr.copy()
811
+ if face_bbox is not None:
812
+ x, y, w, h = face_bbox
813
+ cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 230, 0), 2)
814
+ cv2.putText(vis, "FACE", (x, max(20, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 230, 0), 2)
815
+ if roi_bbox is not None:
816
+ rx, ry, rw, rh = roi_bbox
817
+ cv2.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (255, 220, 0), 2)
818
+ cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
819
+ return vis
820
+
821
+ def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
822
+ roi_type: str,
823
+ frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
824
+ """
825
+ Compute the ROI rectangle (x,y,w,h) from a face bbox and your ROI rule.
826
+ Matches your crop_roi geometry.
827
+ """
828
+ x, y, w, h = face_bbox
829
+ H, W = frame_shape[:2]
830
+ if roi_type == "forehead":
831
+ rx = int(x + 0.25 * w); rw = int(0.5 * w)
832
+ ry = int(y + 0.10 * h); rh = int(0.20 * h)
833
+ elif roi_type == "cheeks":
834
+ rx = int(x + 0.15 * w); rw = int(0.70 * w)
835
+ ry = int(y + 0.55 * h); rh = int(0.30 * h)
836
+ else:
837
+ rx, ry, rw, rh = x, y, w, h
838
+
839
+ rx2 = min(W, rx + rw)
840
+ ry2 = min(H, ry + rh)
841
+ rx = max(0, rx); ry = max(0, ry)
842
+ if rx2 <= rx or ry2 <= ry:
843
+ return (0, 0, 0, 0)
844
+ return (rx, ry, rx2 - rx, ry2 - ry)
845
+
846
+ def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
847
+ """
848
+ Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
849
+ Assumes chw = (3, H, W) with zero-mean, unit-var normalization per-frame.
850
+ """
851
+ if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
852
+ return np.zeros((128, 128, 3), dtype=np.uint8)
853
+
854
+ # Undo channel-first & normalization to a viewable image
855
+ img = chw.copy()
856
+ # Re-normalize to 0..1 by min-max of the tensor to "show" contrast
857
+ vmin, vmax = float(img.min()), float(img.max())
858
+ if vmax <= vmin + 1e-6:
859
+ img = np.zeros_like(img)
860
+ else:
861
+ img = (img - vmin) / (vmax - vmin)
862
+
863
+ img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR
864
+ return img
865
+
866
+ def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
867
+ if gt_len <= 1:
868
+ return None
869
+ if gt_fs and gt_fs > 0:
870
+ return np.arange(gt_len, dtype=float) / float(gt_fs)
871
+ return None # will fall back to length-matching overlay
872
+
873
+ def plot_signals_with_gt(time_axis: np.ndarray,
874
+ raw_signal: np.ndarray,
875
+ post_signal: np.ndarray,
876
+ fs: int,
877
+ out_path: str,
878
+ gt_time: Optional[np.ndarray] = None,
879
+ gt_bvp: Optional[np.ndarray] = None,
880
+ title: str = "rPPG Signals (Pred vs GT)") -> str:
881
+ """
882
+ Save a 3-pane plot: (1) predicted raw, (2) predicted post, (3) overlay Pred vs GT (normalized).
883
+ If GT is provided, it is resampled to the prediction time grid and lag-aligned
884
+ (±5 s) using cross-correlation. The overlay includes Pearson r, lag, and HR stats.
885
+ """
886
+ import numpy as _np
887
+ import matplotlib.pyplot as _plt
888
+ from matplotlib.gridspec import GridSpec as _GridSpec
889
+
890
+ def z(x):
891
+ x = _np.asarray(x, dtype=float)
892
+ if x.size == 0:
893
+ return x
894
+ m = float(_np.nanmean(x))
895
+ s = float(_np.nanstd(x)) + 1e-8
896
+ return (x - m) / s
897
+
898
+ def _bandpass(x, fs_local, lo=0.7, hi=3.5, order=4):
899
+ try:
900
+ return bandpass_filter(_np.asarray(x, float), fs=fs_local, low=lo, high=hi, order=order)
901
+ except Exception:
902
+ return _np.asarray(x, float)
903
+
904
+ def _welch_hr(x, fs_local):
905
+ try:
906
+ return float(hr_from_welch(_np.asarray(x, float), fs=fs_local, lo=0.7, hi=3.5))
907
+ except Exception:
908
+ return 0.0
909
+
910
+ def _safe_interp(x_t, y, t_new):
911
+ """Monotonic-safe 1D interpolation with clipping to valid domain."""
912
+ x_t = _np.asarray(x_t, dtype=float).ravel()
913
+ y = _np.asarray(y, dtype=float).ravel()
914
+ t_new = _np.asarray(t_new, dtype=float).ravel()
915
+
916
+ if x_t.size < 2 or y.size != x_t.size:
917
+ # Fallback: length-based resize to t_new length
918
+ if y.size == 0 or t_new.size == 0:
919
+ return _np.zeros_like(t_new)
920
+ idx = _np.linspace(0, y.size - 1, num=t_new.size)
921
+ return _np.interp(_np.arange(t_new.size), idx, y)
922
+
923
+ # Enforce strictly increasing time (dedup if needed)
924
+ order = _np.argsort(x_t)
925
+ x_t = x_t[order]
926
+ y = y[order]
927
+ mask = _np.concatenate(([True], _np.diff(x_t) > 0))
928
+ x_t = x_t[mask]
929
+ y = y[mask]
930
+ # Clip t_new to the valid domain to avoid edge extrapolation artifacts
931
+ t_clip = _np.clip(t_new, x_t[0], x_t[-1])
932
+ return _np.interp(t_clip, x_t, y)
933
+
934
+ def _best_lag(pred, gt, fs_local, max_lag_s=5.0):
935
+ """Return lag (sec) that best aligns GT to Pred using cross-correlation on z-scored signals."""
936
+ x = z(pred); y = z(gt)
937
+ if x.size < 8 or y.size < 8:
938
+ return 0.0
939
+ n = int(min(len(x), len(y)))
940
+ x = x[:n]; y = y[:n]
941
+ max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
942
+ # valid lags: negative means GT should be shifted left (advance) relative to Pred
943
+ lags = _np.arange(-max_lag, max_lag + 1)
944
+ # compute correlation for each lag
945
+ best_corr = -_np.inf
946
+ best_lag = 0
947
+ for L in lags:
948
+ if L < 0:
949
+ xx = x[-L:n]
950
+ yy = y[0:n+L]
951
+ elif L > 0:
952
+ xx = x[0:n-L]
953
+ yy = y[L:n]
954
+ else:
955
+ xx = x
956
+ yy = y
957
+ if xx.size < 8 or yy.size < 8:
958
+ continue
959
+ c = _np.corrcoef(xx, yy)[0, 1]
960
+ if _np.isfinite(c) and c > best_corr:
961
+ best_corr = c
962
+ best_lag = L
963
+ return float(best_lag / float(fs_local))
964
+
965
+ def _apply_lag(y, lag_sec, fs_local):
966
+ """Shift y by lag_sec (positive => delay GT) using sample roll; edges set to NaN."""
967
+ y = _np.asarray(y, float)
968
+ if y.size == 0 or fs_local <= 0:
969
+ return y
970
+ shift = int(round(lag_sec * fs_local))
971
+ if shift == 0:
972
+ return y
973
+ out = _np.empty_like(y)
974
+ out[:] = _np.nan
975
+ if shift > 0:
976
+ # delay: move content right
977
+ out[shift:] = y[:-shift]
978
+ else:
979
+ # advance: move content left
980
+ out[:shift] = y[-shift:]
981
+ return out
982
+
983
+ t = _np.asarray(time_axis, dtype=float)
984
+ raw = _np.asarray(raw_signal, dtype=float)
985
+ post = _np.asarray(post_signal, dtype=float)
986
+
987
+ # guard
988
+ if t.size == 0:
989
+ t = _np.arange(post.size, dtype=float) / max(fs, 1)
990
+
991
+ have_gt = gt_bvp is not None and _np.asarray(gt_bvp).size > 0
992
+ gt_on_pred = None
993
+ lag_sec = 0.0
994
+ pearson_r = _np.nan
995
+ hr_pred = _welch_hr(_bandpass(post, fs), fs)
996
+ hr_gt = 0.0
997
+
998
+ if have_gt:
999
+ gt = _np.asarray(gt_bvp, dtype=float).ravel()
1000
+ if gt_time is not None and _np.asarray(gt_time).size == gt.size:
1001
+ gt_t = _np.asarray(gt_time, dtype=float).ravel()
1002
+ gt_on_pred = _safe_interp(gt_t, gt, t)
1003
+ else:
1004
+ # No time vector: try length-based mapping to pred time grid
1005
+ gt_on_pred = _safe_interp(_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
1006
+ gt, t)
1007
+
1008
+ # Band-limit both before correlation/HR
1009
+ pred_bp = _bandpass(post, fs)
1010
+ gt_bp = _bandpass(gt_on_pred, fs)
1011
+
1012
+ # Estimate best lag (sec) of GT relative to Pred
1013
+ lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)
1014
+
1015
+ # Apply lag to GT for visualization and correlation
1016
+ gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)
1017
+
1018
+ # Compute Pearson r on overlapping valid samples
1019
+ valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
1020
+ if valid.sum() >= 16:
1021
+ pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
1022
+ else:
1023
+ pearson_r = _np.nan
1024
+
1025
+ hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
1026
+
1027
+
1028
+ _plt.figure(figsize=(13, 6), dpi=110)
1029
+ gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
1030
+
1031
+ # (1) Raw Pred
1032
+ ax1 = _plt.subplot(gs[0, 0])
1033
+ ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
1034
+ ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
1035
+ ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
1036
+ ax1.grid(True, alpha=0.3)
1037
+
1038
+ # (2) Post Pred
1039
+ ax2 = _plt.subplot(gs[0, 1])
1040
+ ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
1041
+ ax2.set_title("Predicted (Post-processed)")
1042
+ ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
1043
+ ax2.grid(True, alpha=0.3)
1044
+
1045
+ # (3) Overlay Pred vs GT (z-scored) OR just post
1046
+ ax3 = _plt.subplot(gs[1, :])
1047
+ ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
1048
+
1049
+ if have_gt and gt_on_pred is not None:
1050
+ gt_bp = _bandpass(gt_on_pred, fs)
1051
+ gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
1052
+ ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
1053
+
1054
+ # metrics box
1055
+ txt = [
1056
+ f"HR_pred: {hr_pred:.1f} BPM",
1057
+ f"HR_gt: {hr_gt:.1f} BPM",
1058
+ f"Pearson r: {pearson_r:.3f}" if _np.isfinite(pearson_r) else "Pearson r: --",
1059
+ f"Lag: {lag_sec:+.2f} s"
1060
+ ]
1061
+ ax3.text(0.01, 0.98, "\n".join(txt), transform=ax3.transAxes,
1062
+ va="top", ha="left", fontsize=9,
1063
+ bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="0.8", alpha=0.9))
1064
+ ax3.set_title("Pred vs GT (z-score overlay, lag-aligned)")
1065
+ else:
1066
+ ax3.set_title("Pred vs GT (no GT provided)")
1067
+
1068
+ ax3.set_xlabel("Time (s)"); ax3.set_ylabel("z")
1069
+ ax3.grid(True, alpha=0.3)
1070
+ ax3.legend(loc="upper right")
1071
+
1072
+ _plt.suptitle(title, fontweight="bold")
1073
+ _plt.tight_layout(rect=[0, 0.02, 1, 0.97])
1074
+ _plt.savefig(out_path, bbox_inches="tight")
1075
+ _plt.close()
1076
+ return out_path
1077
+
1078
+ def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
1079
+ """
1080
+ Robust single-face detector with a few practical guards:
1081
+ - converts to gray safely
1082
+ - equalizes histogram (helps underexposure)
1083
+ - tries multiple scales / minNeighbors
1084
+ - returns the largest face bbox (x,y,w,h) or None
1085
+ """
1086
+ if frame is None or frame.size == 0:
1087
+ return None
1088
+
1089
+ try:
1090
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
1091
+ except Exception:
1092
+ # If color conversion fails, assume already gray
1093
+ gray = frame.copy() if frame.ndim == 2 else cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
1094
+
1095
+ # Light preproc to improve Haar performance
1096
+ gray = cv2.equalizeHist(gray)
1097
+
1098
+ faces_all = []
1099
+ # Try a couple of parameter combos to be more forgiving
1100
+ params = [
1101
+ dict(scaleFactor=1.05, minNeighbors=3),
1102
+ dict(scaleFactor=1.10, minNeighbors=4),
1103
+ dict(scaleFactor=1.20, minNeighbors=5),
1104
+ ]
1105
+ for p in params:
1106
+ try:
1107
+ faces = FACE_CASCADE.detectMultiScale(gray, **p)
1108
+ if faces is not None and len(faces) > 0:
1109
+ faces_all.extend([tuple(map(int, f)) for f in faces])
1110
+ except Exception:
1111
+ continue
1112
+
1113
+ if not faces_all:
1114
+ return None
1115
+
1116
+ # Return the largest (by area)
1117
+ return max(faces_all, key=lambda f: f[2] * f[3])
1118
+
1119
+ def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
1120
+ """
1121
+ Crop ROI from the frame based on a face bbox and the selected roi_type.
1122
+ Returns the cropped BGR ROI or None if invalid.
1123
+ """
1124
+ if face_bbox is None or frame is None or frame.size == 0:
1125
+ return None
1126
+
1127
+ x, y, w, h = map(int, face_bbox)
1128
+ H, W = frame.shape[:2]
1129
+
1130
+ if roi_type == "forehead":
1131
+ rx = int(x + 0.25 * w); rw = int(0.50 * w)
1132
+ ry = int(y + 0.10 * h); rh = int(0.20 * h)
1133
+ elif roi_type == "cheeks":
1134
+ rx = int(x + 0.15 * w); rw = int(0.70 * w)
1135
+ ry = int(y + 0.55 * h); rh = int(0.30 * h)
1136
+ else:
1137
+ rx, ry, rw, rh = x, y, w, h
1138
+
1139
+ # clamp in-bounds
1140
+ rx = max(0, rx); ry = max(0, ry)
1141
+ rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
1142
+
1143
+ if rx2 <= rx or ry2 <= ry:
1144
+ return None
1145
+
1146
+ roi = frame[ry:ry2, rx:rx2]
1147
+ # Avoid empty or 1-pixel slivers
1148
+ if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
1149
+ return None
1150
+ return roi
1151
+
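A minimal sketch tying detect_face() and crop_roi() together on a single frame; the image path is hypothetical and any BGR frame works:

    import cv2

    frame = cv2.imread("frame.png")                  # hypothetical input frame
    bbox = detect_face(frame)                        # (x, y, w, h) or None
    if bbox is not None:
        roi = crop_roi(bbox, "forehead", frame)      # "forehead", "cheeks", or "face"
        if roi is not None:
            print("ROI shape:", roi.shape)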
1152
+ def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
1153
+ roi_type: str,
1154
+ frame: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
1155
+ if face_bbox is None or frame is None or frame.size == 0:
1156
+ return None, None
1157
+
1158
+ x, y, w, h = map(int, face_bbox)
1159
+ H, W = frame.shape[:2]
1160
+
1161
+ if roi_type == "forehead":
1162
+ rx = int(x + 0.25 * w); rw = int(0.50 * w)
1163
+ ry = int(y + 0.10 * h); rh = int(0.20 * h)
1164
+ elif roi_type == "cheeks":
1165
+ rx = int(x + 0.15 * w); rw = int(0.70 * w)
1166
+ ry = int(y + 0.55 * h); rh = int(0.30 * h)
1167
+ else:
1168
+ rx, ry, rw, rh = x, y, w, h
1169
+
1170
+ rx = max(0, rx); ry = max(0, ry)
1171
+ rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
1172
+ if rx2 <= rx or ry2 <= ry:
1173
+ return None, None
1174
+
1175
+ roi = frame[ry:ry2, rx:rx2]
1176
+ if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
1177
+ return None, None
1178
+
1179
+ return roi, (rx, ry, rx2 - rx, ry2 - ry)
1180
+
1181
+ def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
1182
+ """
1183
+ PhysMamba-compatible normalization with DiffNormalized support.
1184
+ Returns (3, size, size).
1185
+ """
1186
+ if face_bgr is None or face_bgr.size == 0:
1187
+ return np.zeros((3, size, size), dtype=np.float32)
1188
+
1189
+ try:
1190
+ face = cv2.resize(face_bgr, (size, size), interpolation=cv2.INTER_AREA)
1191
+ except Exception:
1192
+ face = cv2.resize(face_bgr, (size, size))
1193
+
1194
+ face = face.astype(np.float32) / 255.0
1195
+
1196
+ # Per-frame standardization
1197
+ mean = face.mean(axis=(0, 1), keepdims=True)
1198
+ std = face.std(axis=(0, 1), keepdims=True) + 1e-6
1199
+ face = (face - mean) / std
1200
+
1201
+ # BGR -> RGB and HWC -> CHW
1202
+ chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
1203
+ return chw
1204
+
1205
+ def extract_attention_map(model, clip_tensor: torch.Tensor,
1206
+ attention_viz) -> Optional[np.ndarray]:
1207
+ """Attention visualization disabled - model architecture incompatible."""
1208
+ return None
1209
+
1210
+ def create_attention_overlay(frame: np.ndarray, attention: Optional[np.ndarray],
1211
+ attention_viz: Optional[SimpleActivationAttention] = None) -> np.ndarray:
1212
+ """Create attention heatmap overlay."""
1213
+ # Attention disabled - return original frame
1214
+ return frame
1215
+
1216
+ def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
1217
+ H, W = roi_bgr.shape[:2]
1218
+ base_bvp = forward_bvp(model, torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
1219
+ .unsqueeze(0).unsqueeze(2).to(DEVICE)) # fake T=1 path if needed
1220
+ base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
1221
+
1222
+ heat = np.zeros((H, W), np.float32)
1223
+ for y in range(0, H - patch + 1, stride):
1224
+ for x in range(0, W - patch + 1, stride):
1225
+ tmp = roi_bgr.copy()
1226
+ tmp[y:y+patch, x:x+patch] = 127 # occlude
1227
+ bvp = forward_bvp(model, torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
1228
+ .unsqueeze(0).unsqueeze(2).to(DEVICE))
1229
+ power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
1230
+ drop = max(0.0, base_power - power)
1231
+ heat[y:y+patch, x:x+patch] += drop
1232
+ heat -= heat.min()
1233
+ if heat.max() > 1e-8: heat /= heat.max()
1234
+ return heat
1235
+
1236
+ def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
1237
+ """
1238
+ Try common 5D layouts:
1239
+ [B, C, T, H, W] then [B, T, C, H, W].
1240
+ """
1241
+ last_err = None
1242
+ try:
1243
+ return model(clip_tensor)
1244
+ except Exception as e:
1245
+ last_err = e
1246
+ try:
1247
+ return model(clip_tensor.permute(0, 2, 1, 3, 4).contiguous())
1248
+ except Exception as e:
1249
+ last_err = e
1250
+ raise last_err
1251
+
1252
+ def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
1253
+ """
1254
+ Forward and extract a 1D time-like BVP vector with length T_clip.
1255
+ Tolerant to dict/list/tuple heads and odd shapes.
1256
+ """
1257
+ T_clip = int(clip_tensor.shape[2]) # intended time length for [B,C,T,H,W]
1258
+ with torch.no_grad():
1259
+ out = _call_model_try_orders(model, clip_tensor)
1260
+
1261
+ # unwrap common containers
1262
+ if isinstance(out, dict):
1263
+ for key in ("bvp", "ppg", "signal", "pred", "y", "out", "logits"):
1264
+ if key in out and isinstance(out[key], torch.Tensor):
1265
+ out = out[key]
1266
+ break
1267
+
1268
+ if isinstance(out, (list, tuple)):
1269
+ tensors = [t for t in out if isinstance(t, torch.Tensor)]
1270
+ if not tensors:
1271
+ return np.zeros(T_clip, dtype=np.float32)
1272
+
1273
+ def score(t: torch.Tensor):
1274
+ has_T = 1 if T_clip in t.shape else 0
1275
+ return (has_T, t.numel())
1276
+
1277
+ out = max(tensors, key=score)
1278
+
1279
+ if not isinstance(out, torch.Tensor):
1280
+ return np.zeros(T_clip, dtype=np.float32)
1281
+
1282
+ out = out.detach().cpu().float()
1283
+
1284
+ # ---- 1D
1285
+ if out.ndim == 1:
1286
+ v = out
1287
+ if v.shape[0] == T_clip:
1288
+ return v.numpy()
1289
+ if v.numel() == 1:
1290
+ return np.full(T_clip, float(v.item()), dtype=np.float32)
1291
+ return np.resize(v.numpy(), T_clip).astype(np.float32)
1292
+
1293
+ # ---- 2D
1294
+ if out.ndim == 2:
1295
+ B, K = out.shape
1296
+ if B == 1:
1297
+ v = out[0]
1298
+ return (v.numpy() if v.shape[0] == T_clip else np.resize(v.numpy(), T_clip).astype(np.float32))
1299
+ if B == T_clip:
1300
+ return out[:, 0].numpy()
1301
+ if K == T_clip:
1302
+ return out[0, :].numpy()
1303
+ return np.resize(out.flatten().numpy(), T_clip).astype(np.float32)
1304
+
1305
+ # ---- 3D
1306
+ if out.ndim == 3:
1307
+ B, D1, D2 = out.shape[0], out.shape[1], out.shape[2]
1308
+ if D1 == T_clip: # [B, T, C]
1309
+ return out[0, :, 0].numpy()
1310
+ if D2 == T_clip: # [B, C, T]
1311
+ return out[0, 0, :].numpy()
1312
+ v = out.mean(dim=tuple(range(1, out.ndim))).squeeze(0)
1313
+ return np.resize(v.numpy(), T_clip).astype(np.float32)
1314
+
1315
+ # ---- 4D
1316
+ if out.ndim == 4:
1317
+ B, A, H, W = out.shape
1318
+ if A == T_clip: # [B, T, H, W]
1319
+ return out[0].mean(dim=(1, 2)).numpy()
1320
+ v = out[0].mean(dim=(1, 2))
1321
+ return np.resize(v.numpy(), T_clip).astype(np.float32)
1322
+
1323
+ # ---- 5D+
1324
+ if out.ndim >= 5:
1325
+ shape = list(out.shape)
1326
+ try:
1327
+ t_idx = next(i for i, s in enumerate(shape) if (i != 0 and s == T_clip))
1328
+ except StopIteration:
1329
+ pooled = out[0].mean(dim=tuple(i for i in range(1, out.ndim) if i not in (-1,)))
1330
+ v = pooled.flatten()
1331
+ return np.resize(v.numpy(), T_clip).astype(np.float32)
1332
+
1333
+ axes = list(range(out.ndim))
1334
+ perm = [0, t_idx] + [i for i in axes[1:] if i != t_idx]
1335
+ o2 = out.permute(*perm) # [B, T, ...]
1336
+ pooled = o2.mean(dim=tuple(range(2, o2.ndim))) # -> [B, T]
1337
+ return pooled[0].numpy()
1338
+
1339
+ # fallback: constant vector with mean value
1340
+ val = float(out.mean().item()) if out.numel() else 0.0
1341
+ return np.full(T_clip, val, dtype=np.float32)
1342
+
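Putting the pieces together, a rough end-to-end sketch of a single inference pass, assuming the helpers above are in scope, frames is a list of BGR face crops collected at ~30 FPS, and one of the bundled checkpoints is used:

    # Auto-detects and constructs the PhysMamba class, then loads the checkpoint
    model, _ = load_physmamba_model(MODEL_DIR / "UBFC-rPPG_PhysMamba_DiffNormalized.pth", DEVICE)

    clip = preprocess_for_physmamba(frames, target_frames=DEFAULT_T, target_size=DEFAULT_SIZE)
    bvp_raw = forward_bvp(model, clip.to(DEVICE))          # 1D BVP, length DEFAULT_T
    bvp_filt, hr_bpm = postprocess_bvp(bvp_raw, fs=30)     # band-passed signal + smoothed HR
    print(f"Estimated HR: {hr_bpm:.1f} BPM")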
1343
+ def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
1344
+ """
1345
+ Classical rPPG from green-channel means when the model yields nothing.
1346
+ Detrend -> bandpass -> z-normalize.
1347
+ """
1348
+ if means is None:
1349
+ return np.array([], dtype=np.float32)
1350
+
1351
+ x = np.asarray(means, dtype=np.float32)
1352
+ if x.size == 0:
1353
+ return np.array([], dtype=np.float32)
1354
+
1355
+ x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
1356
+
1357
+ try:
1358
+ x = signal.detrend(x, type="linear")
1359
+ except Exception:
1360
+ pass
1361
+
1362
+ y = bandpass_filter(x, fs=fs, low=0.7, high=3.5, order=4)
1363
+
1364
+ std = float(np.std(y)) + 1e-6
1365
+ return (y / std).astype(np.float32)
1366
+
1367
+ def _to_floats(s: str) -> List[float]:
1368
+ """
1369
+ Extract all real numbers from free-form text, including scientific notation.
1370
+ Gracefully ignores 'nan', 'inf', units, and comments.
1371
+ """
1372
+ if not isinstance(s, str) or not s:
1373
+ return []
1374
+
1375
+ s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)
1376
+
1377
+ s = s.replace(",", " ").replace(";", " ")
1378
+
1379
+ toks = re.findall(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", s)
1380
+ out: List[float] = []
1381
+ for t in toks:
1382
+ try:
1383
+ v = float(t)
1384
+ if np.isfinite(v):
1385
+ out.append(v)
1386
+ except Exception:
1387
+ continue
1388
+ return out
1389
+
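A small illustration of the intended behaviour (not in the committed file): comments, units and non-finite tokens are dropped, while signs and scientific notation are kept:
_to_floats("hr: 72.5 bpm  # resting")   # -> [72.5]
_to_floats("-1.2e-3, nan, 4")           # -> [-0.0012, 4.0]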
1390
+ def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
1391
+ """
1392
+ Parse ground-truth files from:
1393
+ • UBFC/UBFC-rPPG style TXT: 3 lines => PPG, HR, timestep(s)
1394
+ • Generic TXT: one column (or free-form) numeric sequence
1395
+ • JSON: keys like 'ppg' / 'bvp' (optionally nested), 'hr', 'fs'
1396
+ • CSV: columns named BVP/PPG/Signal + optional HR + optional Time
1397
+ • MAT: array under keys ['ppg','bvp','signal','wave']; optionally 'fs'/'hr'
1398
+ • NPY: np.load() 1-D array; optionally sidecar .json with fs/hr (same stem)
1399
+
1400
+ Returns:
1401
+ bvp : np.ndarray (float) — may be empty
1402
+ hr : float (mean HR if available or estimated)
1403
+ fs_hint : float — sampling rate if derivable (0.0 if unknown)
1404
+ """
1405
+ if not gt_path or not os.path.exists(gt_path):
1406
+ return np.array([]), 0.0, 0.0
1407
+
1408
+ p = Path(gt_path)
1409
+ ext = p.suffix.lower()
1410
+
1411
+ def _fs_from_time_vector(tv: np.ndarray) -> float:
1412
+ tv = np.asarray(tv, dtype=float)
1413
+ if tv.ndim != 1 or tv.size < 2:
1414
+ return 0.0
1415
+ diffs = np.diff(tv)
1416
+ diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
1417
+ return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
1418
+
1419
+
1420
+ def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
1421
+ if bvp is None or bvp.size == 0:
1422
+ return 0.0
1423
+ fs_use = fs_hint if (fs_hint and fs_hint > 0) else 30.0
1424
+ bp = bandpass_filter(bvp.astype(float), fs=fs_use)
1425
+ return hr_from_welch(bp, fs=fs_use)
1426
+
1427
+ if p.name.lower() == "ground_truth.txt" or (ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2):
1428
+ try:
1429
+ lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines() if ln.strip()]
1430
+ ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
1431
+ hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
1432
+ t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
1433
+
1434
+ bvp = np.asarray(ppg_vals, dtype=float) if ppg_vals else np.array([], dtype=float)
1435
+
1436
+ hr = float(np.nanmean(hr_vals)) if hr_vals else 0.0
1437
+ fs_hint = 0.0
1438
+ if t_vals:
1439
+ if len(t_vals) == 1:
1440
+ dt = float(t_vals[0])
1441
+ if dt > 0:
1442
+ fs_hint = 1.0 / dt
1443
+ else:
1444
+ fs_hint = _fs_from_time_vector(np.asarray(t_vals, dtype=float))
1445
+
1446
+ if hr == 0.0 and bvp.size:
1447
+ hr = _hr_from_bvp(bvp, fs_hint)
1448
+ return bvp, hr, fs_hint
1449
+ except Exception:
1450
+ # Fall through to generic handlers
1451
+ pass
1452
+
1453
+
1454
+ if ext == ".txt":
1455
+ try:
1456
+ nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
1457
+ bvp = np.asarray(nums, dtype=float) if nums else np.array([], dtype=float)
1458
+ hr = _hr_from_bvp(bvp, fs_hint=0.0) if bvp.size else 0.0
1459
+ return bvp, hr, 0.0
1460
+ except Exception:
1461
+ return np.array([]), 0.0, 0.0
1462
+
1463
+
1464
+ if ext == ".json":
1465
+ try:
1466
+ data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
1467
+ # Try several paths for BVP array
1468
+ bvp = None
1469
+
1470
+ def _seek(obj, keys):
1471
+ for k in keys:
1472
+ if isinstance(obj, dict) and k in obj:
1473
+ return obj[k]
1474
+ return None
1475
+
1476
+ # Direct top-level
1477
+ bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
1478
+ # Common nested containers
1479
+ if bvp is None:
1480
+ for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
1481
+ if container_key in data:
1482
+ cand = _seek(data[container_key], ("ppg", "bvp", "signal", "wave"))
1483
+ if cand is not None:
1484
+ bvp = cand
1485
+ break
1486
+
1487
+ if bvp is not None:
1488
+ bvp = np.asarray(bvp, dtype=float).ravel()
1489
+ else:
1490
+ bvp = np.array([], dtype=float)
1491
+
1492
+ # fs / hr (accept scalar or array)
1493
+ fs_hint = 0.0
1494
+ if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
1495
+ fs_hint = float(data["fs"])
1496
+
1497
+ hr = 0.0
1498
+ if "hr" in data:
1499
+ v = data["hr"]
1500
+ hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
1501
+
1502
+ if hr == 0.0 and bvp.size:
1503
+ hr = _hr_from_bvp(bvp, fs_hint)
1504
+ return bvp, hr, fs_hint
1505
+ except Exception:
1506
+ return np.array([]), 0.0, 0.0
1507
+
1508
+
1509
+ if ext == ".csv":
1510
+ try:
1511
+ df = pd.read_csv(p)
1512
+ # Normalize column names
1513
+ cols = {str(c).strip().lower(): c for c in df.columns}
1514
+
1515
+ def _first_match(names):
1516
+ for nm in names:
1517
+ if nm in cols:
1518
+ return cols[nm]
1519
+ return None
1520
+
1521
+ bvp_col = _first_match(["bvp", "ppg", "wave", "signal", "bvp_signal", "ppg_signal"])
1522
+ hr_col = _first_match(["hr", "heart_rate", "hr_bpm", "bpm"])
1523
+ t_col = _first_match(["time", "t", "timestamp", "sec", "seconds", "time_s"])
1524
+
1525
+ bvp = np.asarray(df[bvp_col].values, dtype=float) if bvp_col else np.array([], dtype=float)
1526
+
1527
+ fs_hint = 0.0
1528
+ if t_col is not None and len(df[t_col].values) >= 2:
1529
+ fs_hint = _fs_from_time_vector(np.asarray(df[t_col].values, dtype=float))
1530
+
1531
+ hr = float(np.nanmean(df[hr_col].values)) if (hr_col and df[hr_col].notna().any()) else 0.0
1532
+ if hr == 0.0 and bvp.size:
1533
+ hr = _hr_from_bvp(bvp, fs_hint)
1534
+ return bvp, hr, fs_hint
1535
+ except Exception:
1536
+ return np.array([]), 0.0, 0.0
1537
+
1538
+
1539
+ if ext == ".mat":
1540
+ try:
1541
+ md = loadmat(str(p))
1542
+ # look for most likely array
1543
+ arr = None
1544
+ for key in ("ppg", "bvp", "signal", "wave"):
1545
+ if key in md and isinstance(md[key], np.ndarray):
1546
+ arr = md[key]
1547
+ break
1548
+ if arr is None:
1549
+ # fallback: first 1-D array
1550
+ for v in md.values():
1551
+ if isinstance(v, np.ndarray) and v.ndim == 1:
1552
+ arr = v
1553
+ break
1554
+ bvp = np.asarray(arr, dtype=float).ravel() if arr is not None else np.array([], dtype=float)
1555
+
1556
+ fs_hint = 0.0
1557
+ for k in ("fs", "Fs", "sampling_rate", "sr"):
1558
+ if k in md:
1559
+ try:
1560
+ fs_hint = float(np.ravel(md[k])[0])
1561
+ break
1562
+ except Exception:
1563
+ pass
1564
+
1565
+ hr = 0.0
1566
+ if "hr" in md:
1567
+ try:
1568
+ hr = float(np.nanmean(np.ravel(md["hr"])))
1569
+ except Exception:
1570
+ hr = 0.0
1571
+
1572
+ if hr == 0.0 and bvp.size:
1573
+ hr = _hr_from_bvp(bvp, fs_hint)
1574
+ return bvp, hr, fs_hint
1575
+ except Exception:
1576
+ return np.array([]), 0.0, 0.0
1577
+
1578
+ # ================= NPY =================
1579
+ if ext == ".npy":
1580
+ try:
1581
+ bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
1582
+ fs_hint, hr = 0.0, 0.0
1583
+ # optional sidecar JSON (same stem) with fs/hr
1584
+ sidecar = p.with_suffix(".json")
1585
+ if sidecar.exists():
1586
+ try:
1587
+ meta = json.loads(sidecar.read_text(encoding="utf-8", errors="ignore"))
1588
+ if isinstance(meta.get("fs", None), (int, float)) and meta["fs"] > 0:
1589
+ fs_hint = float(meta["fs"])
1590
+ if "hr" in meta:
1591
+ v = meta["hr"]
1592
+ hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
1593
+ except Exception:
1594
+ pass
1595
+ if hr == 0.0 and bvp.size:
1596
+ hr = _hr_from_bvp(bvp, fs_hint)
1597
+ return bvp, hr, fs_hint
1598
+ except Exception:
1599
+ return np.array([]), 0.0, 0.0
1600
+
1601
+ # Fallback (unsupported extension)
1602
+ return np.array([]), 0.0, 0.0
1603
+
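Hedged usage sketch (not in the committed file): a UBFC-style ground_truth.txt carries the PPG trace on line 1, per-frame HR on line 2 and timestamps on line 3, so the parser can return both the waveform and a sampling-rate hint. The path below is hypothetical.
gt_bvp, gt_hr, gt_fs = parse_ground_truth_file("subject1/ground_truth.txt")
if gt_bvp.size:
    print(f"GT samples={gt_bvp.size}  mean HR={gt_hr:.1f} BPM  fs~{gt_fs:.1f} Hz")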
1604
+ def scan_models() -> List[str]:
1605
+ if not MODEL_DIR.exists():
1606
+ return []
1607
+
1608
+ models = []
1609
+ for f in sorted(MODEL_DIR.iterdir()):
1610
+ if f.suffix.lower() == '.pth':
1611
+ models.append(f.name)
1612
+
1613
+ return models
1614
+
1615
+ _GLOBAL_CONTROLS: Dict[str, Dict] = {}
1616
+
1617
+ def ensure_controls(control_id: str) -> Tuple[str, Dict]:
1618
+ # Use a stable default so Pause/Resume/Stop work for the current run
1619
+ if not control_id:
1620
+ control_id = "default-session"
1621
+ if control_id not in _GLOBAL_CONTROLS:
1622
+ _GLOBAL_CONTROLS[control_id] = {
1623
+ 'pause': threading.Event(),
1624
+ 'stop': threading.Event()
1625
+ }
1626
+ return control_id, _GLOBAL_CONTROLS[control_id]
1627
+
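Illustration (an assumption about intended use, not in the committed file): the processing loop and the Gradio button callbacks share the same Event objects, so a flag set in a callback is seen by the loop on its next iteration.
cid, controls = ensure_controls("")   # empty id falls back to "default-session"
controls['pause'].set()               # what pause_processing() does
controls['pause'].clear()             # what resume_processing() does
controls['stop'].set()                # what stop_processing() does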
1628
+ def process_video_file(
1629
+ video_path: str,
1630
+ gt_file: Optional[str],
1631
+ model_name: str,
1632
+ fps_input: int,
1633
+ max_seconds: int,
1634
+ roi_type: str,
1635
+ control_id: str
1636
+ ):
1637
+ """
1638
+ Enhanced video processing with Grad-CAM attention visualization.
1639
+ """
1640
+ global _HR_SMOOTH
1641
+ _HR_SMOOTH = None
1642
+
1643
+ def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
1644
+ if gt_len <= 1:
1645
+ return None
1646
+ if gt_fs and gt_fs > 0:
1647
+ return np.arange(gt_len, dtype=float) / float(gt_fs)
1648
+ return None
1649
+
1650
+ control_id, controls = ensure_controls(control_id)
1651
+ controls['stop'].clear()
1652
+
1653
+ if not model_name:
1654
+ yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
1655
+ return
1656
+
1657
+ if isinstance(model_name, int):
1658
+ model_name = str(model_name)
1659
+
1660
+ model_path = MODEL_DIR / model_name
1661
+ if not model_path.exists():
1662
+ yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
1663
+ return
1664
+
1665
+ try:
1666
+ model, attention_viz = load_physmamba_model(model_path, DEVICE)
1667
+ except Exception as e:
1668
+ yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
1669
+ return
1670
+
1671
+
1672
+ gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
1673
+
1674
+ if not video_path or not os.path.exists(video_path):
1675
+ yield ("ERROR: Video not found", None, None, None, None, None, None, None, None, None)
1676
+ return
1677
+
1678
+ cap = cv2.VideoCapture(video_path)
1679
+ if not cap.isOpened():
1680
+ yield ("ERROR: Cannot open video", None, None, None, None, None, None, None, None, None)
1681
+ return
1682
+
1683
+ fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
1684
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
1685
+ max_frames = int(max_seconds * fps) if max_seconds and max_seconds > 0 else total_frames
1686
+ MAX_SIGNAL_LENGTH = max_frames if max_frames > 0 else fps * 600
1687
+
1688
+ frames_chw = deque(maxlen=DEFAULT_T)
1689
+ raw_g_means = deque(maxlen=DEFAULT_T)
1690
+ bvp_stream: List[float] = []
1691
+ frame_idx = 0
1692
+ last_infer = -1
1693
+ last_bpm = 0.0
1694
+ last_rmssd = 0.0
1695
+ last_attention = None
1696
+
1697
+ start_time = time.time()
1698
+ next_display = time.time()
1699
+
1700
+ tmpdir = Path(tempfile.gettempdir())
1701
+ frame_path = tmpdir / "frame.jpg"
1702
+ attention_path = tmpdir / "attention.jpg"
1703
+ signal_path = tmpdir / "signal.png"
1704
+ raw_path = tmpdir / "raw_signal.png"
1705
+ post_path = tmpdir / "post_signal.png"
1706
+
1707
+ yield ("Starting… reading video frames", None, f"{gt_hr:.1f}" if gt_hr > 0 else "--",
1708
+ None, None, None, None, None, None, None)
1709
+
1710
+ while True:
1711
+ if controls['stop'].is_set():
1712
+ break
1713
+
1714
+ while controls['pause'].is_set():
1715
+ time.sleep(0.2)
1716
+ if controls['stop'].is_set():
1717
+ break
1718
+
1719
+ ret, frame = cap.read()
1720
+ if not ret or (max_frames > 0 and frame_idx >= max_frames):
1721
+ break
1722
+
1723
+ frame_idx += 1
1724
+
1725
+ face = detect_face(frame)
1726
+ vis_frame = frame.copy()
1727
+
1728
+ if face is not None:
1729
+ x, y, w, h = face
1730
+ cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
1731
+
1732
+ if roi_type == "auto":
1733
+ roi, _ = pick_auto_roi(face, frame, attn=last_attention)
1734
+ else:
1735
+ roi = crop_roi(face, roi_type, frame)
1736
+
1737
+ if roi is not None and roi.size > 0:
1738
+ try:
1739
+ g_mean = float(roi[..., 1].astype(np.float32).mean())
1740
+ raw_g_means.append(g_mean)
1741
+ except Exception:
1742
+ pass
1743
+
1744
+ face_norm = normalize_frame(roi, DEFAULT_SIZE)
1745
+ frames_chw.append(face_norm)
1746
+
1747
+ if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
1748
+ try:
1749
+ clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
1750
+ except Exception as e:
1751
+ print(f"[infer] clip stack failed: {e}")
1752
+ clip = None
1753
+
1754
+ bvp_out = None
1755
+ if clip is not None:
1756
+ clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
1757
+
1758
+ try:
1759
+ raw = forward_bvp(model, clip_t)
1760
+ if isinstance(raw, np.ndarray):
1761
+ raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
1762
+ bvp_out = raw if raw.size > 0 else None
1763
+ else:
1764
+ bvp_out = None
1765
+ except Exception as e:
1766
+ print(f"[infer] forward_bvp error: {e}")
1767
+ bvp_out = None
1768
+
1769
+ # UPDATED: Generate attention with Grad-CAM
1770
+ try:
1771
+ last_attention = extract_attention_map(model, clip_t, attention_viz)
1772
+ except Exception as e:
1773
+ print(f"⚠ Attention generation error: {e}")
1774
+ last_attention = None
1775
+
1776
+ if bvp_out is None or bvp_out.size == 0:
1777
+ gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
1778
+ fb = _fallback_bvp_from_means(gbuf, fs=fps)
1779
+ if isinstance(fb, np.ndarray) and fb.size > 0:
1780
+ bvp_out = fb
1781
+ else:
1782
+ print("[infer] fallback produced empty output")
1783
+
1784
+ if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
1785
+ tail = min(DEFAULT_STRIDE, bvp_out.size)
1786
+ bvp_stream.extend(bvp_out[-tail:].tolist())
1787
+ if len(bvp_stream) > MAX_SIGNAL_LENGTH:
1788
+ bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
1789
+
1790
+ if len(bvp_stream) >= int(5 * fps):
1791
+ seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
1792
+ _, last_bpm = postprocess_bvp(seg, fs=fps)
1793
+ last_rmssd = compute_rmssd(seg, fs=fps)
1794
+
1795
+ if frame_idx % (DEFAULT_STRIDE * 2) == 0:
1796
+ print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
1797
+ else:
1798
+ print("[infer] no usable bvp_out after fallback")
1799
+
1800
+ last_infer = frame_idx
1801
+
1802
+ else:
1803
+ cv2.putText(vis_frame, "No face detected", (20, 40),
1804
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
1805
+
1806
+ if last_bpm > 0:
1807
+ color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
1808
+ cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
1809
+ cv2.putText(vis_frame, f"HR: {last_bpm:.1f} BPM", (20, 48),
1810
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
1811
+
1812
+ vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
1813
+
1814
+ cv2.imwrite(str(frame_path), vis_frame)
1815
+ cv2.imwrite(str(attention_path), vis_attention)
1816
+
1817
+ now = time.time()
1818
+ if now >= next_display:
1819
+ if len(bvp_stream) >= max(10, int(1 * fps)):
1820
+ try:
1821
+ signal_array = np.array(bvp_stream, dtype=float)
1822
+ time_axis = np.arange(len(signal_array)) / fps
1823
+
1824
+ raw_signal = bandpass_filter(signal_array, fs=fps)
1825
+ post_signal, _ = postprocess_bvp(signal_array, fs=fps)
1826
+
1827
+ raw_vis = raw_signal - np.mean(raw_signal)
1828
+ post_vis = post_signal - np.mean(post_signal)
1829
+
1830
+ plt.figure(figsize=(10, 4), dpi=100)
1831
+ plt.plot(time_axis, raw_vis, linewidth=1.6)
1832
+ plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
1833
+ plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
1834
+ plt.grid(True, alpha=0.3)
1835
+ plt.tight_layout(); plt.savefig(str(raw_path), dpi=100, bbox_inches='tight'); plt.close()
1836
+
1837
+ plt.figure(figsize=(10, 4), dpi=100)
1838
+ plt.plot(time_axis, post_vis, linewidth=1.6)
1839
+ plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
1840
+ plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
1841
+ plt.grid(True, alpha=0.3)
1842
+ plt.tight_layout(); plt.savefig(str(post_path), dpi=100, bbox_inches='tight'); plt.close()
1843
+
1844
+ if gt_bvp is not None and gt_bvp.size > 0:
1845
+ try:
1846
+ gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
1847
+ plot_signals_with_gt(
1848
+ time_axis=time_axis,
1849
+ raw_signal=raw_signal,
1850
+ post_signal=post_signal,
1851
+ fs=fps,
1852
+ out_path=str(signal_path),
1853
+ gt_time=gt_time,
1854
+ gt_bvp=gt_bvp,
1855
+ title=f"Pred vs GT — HR: {last_bpm:.1f} BPM"
1856
+ )
1857
+ except Exception:
1858
+ fig = plt.figure(figsize=(12, 5), dpi=100)
1859
+ gs = GridSpec(1, 2, figure=fig, wspace=0.3)
1860
+ ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
1861
+ ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
1862
+ ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
1863
+ ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
1864
+ plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
1865
+ plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
1866
+ else:
1867
+ fig = plt.figure(figsize=(12, 5), dpi=100)
1868
+ gs = GridSpec(1, 2, figure=fig, wspace=0.3)
1869
+ ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
1870
+ ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
1871
+ ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
1872
+ ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
1873
+ plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
1874
+ plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
1875
+ except Exception:
1876
+ pass
1877
+
1878
+ elapsed = now - start_time
1879
+ status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
1880
+
1881
+ yield (
1882
+ status,
1883
+ f"{last_bpm:.1f}" if last_bpm > 0 else None,
1884
+ f"{gt_hr:.1f}" if gt_hr > 0 else "--",
1885
+ f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
1886
+ str(frame_path),
1887
+ str(attention_path),
1888
+ str(signal_path) if signal_path.exists() else None,
1889
+ str(raw_path) if raw_path.exists() else None,
1890
+ str(post_path) if post_path.exists() else None,
1891
+ None
1892
+ )
1893
+
1894
+ next_display = now + (1.0 / DISPLAY_FPS)
1895
+
1896
+ cap.release()
1897
+
1898
+ # Cleanup Grad-CAM
1899
+ if attention_viz: attention_viz.cleanup()
1900
+
1901
+ csv_path = None
1902
+ if bvp_stream:
1903
+ csv_path = Path(tempfile.gettempdir()) / "bvp_output.csv"
1904
+ time_array = np.arange(len(bvp_stream)) / fps
1905
+ signal_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
1906
+ try:
1907
+ pd.DataFrame({'time_s': time_array, 'bvp': signal_final}).to_csv(csv_path, index=False)
1908
+ except Exception:
1909
+ csv_path = None
1910
+
1911
+ try:
1912
+ final_overlay = Path(tempfile.gettempdir()) / "signal_final_overlay.png"
1913
+ if gt_bvp is not None and gt_bvp.size > 0:
1914
+ gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
1915
+ plot_signals_with_gt(
1916
+ time_axis=time_array,
1917
+ raw_signal=bandpass_filter(np.array(bvp_stream, dtype=float), fs=fps),
1918
+ post_signal=signal_final,
1919
+ fs=fps,
1920
+ out_path=str(final_overlay),
1921
+ gt_time=gt_time,
1922
+ gt_bvp=gt_bvp,
1923
+ title=f"Final Pred vs GT — HR: {last_bpm:.1f} BPM"
1924
+ )
1925
+ if final_overlay.exists():
1926
+ signal_path = final_overlay
1927
+ except Exception:
1928
+ pass
1929
+
1930
+ elapsed = time.time() - start_time
1931
+ final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
1932
+
1933
+ yield (
1934
+ final_status,
1935
+ f"{last_bpm:.1f}" if last_bpm > 0 else None,
1936
+ f"{gt_hr:.1f}" if gt_hr > 0 else "--",
1937
+ f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
1938
+ str(frame_path),
1939
+ str(attention_path),
1940
+ str(signal_path) if signal_path.exists() else None,
1941
+ str(raw_path) if raw_path.exists() else None,
1942
+ str(post_path) if post_path.exists() else None,
1943
+ str(csv_path) if csv_path else None
1944
+ )
1945
+
1946
+ def process_live_webcam(
1947
+ model_name: str,
1948
+ fps_input: int,
1949
+ roi_type: str,
1950
+ control_id: str
1951
+ ):
1952
+ """Stream live webcam with Grad-CAM attention visualization."""
1953
+ global _HR_SMOOTH
1954
+ _HR_SMOOTH = None
1955
+
1956
+ def _perf_heartbeat(frame_idx, t0, bvp_len, frames_chw_len, fps):
1957
+ if frame_idx == 1:
1958
+ print(f"[run] device={DEVICE} target_fps={fps}")
1959
+ if frame_idx % 60 == 0:
1960
+ elapsed = time.time() - t0
1961
+ cur_fps = frame_idx / max(elapsed, 1e-6)
1962
+ print(f"[perf] frames={frame_idx} ({cur_fps:.1f} FPS) "
1963
+ f"clip={frames_chw_len}/{DEFAULT_T} bvp_len={bvp_len}")
1964
+
1965
+ control_id, controls = ensure_controls(control_id)
1966
+ controls['stop'].clear()
1967
+
1968
+ if not model_name:
1969
+ yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
1970
+ return
1971
+
1972
+ model_path = MODEL_DIR / model_name
1973
+ if not model_path.exists():
1974
+ yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
1975
+ return
1976
+
1977
+ try:
1978
+ model, attention_viz = load_physmamba_model(model_path, DEVICE)
1979
+ except Exception as e:
1980
+ yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
1981
+ return
1982
+
1983
+
1984
+ cap = None
1985
+ for camera_id in [0, 1]:
1986
+ for backend in [cv2.CAP_AVFOUNDATION, cv2.CAP_ANY]:
1987
+ try:
1988
+ test_cap = cv2.VideoCapture(camera_id, backend)
1989
+ if test_cap.isOpened():
1990
+ test_cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
1991
+ test_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
1992
+ ret, frame = test_cap.read()
1993
+ if ret and frame is not None:
1994
+ cap = test_cap
1995
+ break
1996
+ test_cap.release()
1997
+ except Exception:
1998
+ pass
1999
+ if cap is not None:
2000
+ break
2001
+
2002
+ if cap is None:
2003
+ yield ("ERROR: Cannot access webcam", None, None, None, None, None, None, None, None, None)
2004
+ return
2005
+
2006
+ fps = int(fps_input) if fps_input else 30
2007
+
2008
+ frames_chw = deque(maxlen=DEFAULT_T)
2009
+ raw_g_means = deque(maxlen=DEFAULT_T)
2010
+ bvp_stream: List[float] = []
2011
+ frame_idx = 0
2012
+ last_infer = -1
2013
+ last_bpm = 0.0
2014
+ last_rmssd = 0.0
2015
+ last_attention = None
2016
+
2017
+ t0 = time.time()
2018
+ next_display = time.time()
2019
+ DISPLAY_INTERVAL = 0.5
2020
+
2021
+ frame_path = Path(tempfile.gettempdir()) / "live_frame.jpg"
2022
+ attention_path = Path(tempfile.gettempdir()) / "live_attention.jpg"
2023
+ signal_path = Path(tempfile.gettempdir()) / "live_signal.png"
2024
+ raw_path = Path(tempfile.gettempdir()) / "live_raw.png"
2025
+ post_path = Path(tempfile.gettempdir()) / "live_post.png"
2026
+
2027
+ MAX_SIGNAL_LENGTH = fps * 60
2028
+
2029
+ yield ("Starting… waiting for frames", None, "--", None,
2030
+ None, None, None, None, None, None)
2031
+
2032
+ while True:
2033
+ if controls['stop'].is_set():
2034
+ break
2035
+
2036
+ while controls['pause'].is_set():
2037
+ time.sleep(0.2)
2038
+ if controls['stop'].is_set():
2039
+ break
2040
+
2041
+ ret, frame = cap.read()
2042
+ if not ret:
2043
+ time.sleep(0.05)
2044
+ continue
2045
+
2046
+ frame_idx += 1
2047
+ _perf_heartbeat(frame_idx, t0, len(bvp_stream), len(frames_chw), fps)
2048
+
2049
+ face = detect_face(frame)
2050
+ vis_frame = frame.copy()
2051
+
2052
+ if face is not None:
2053
+ x, y, w, h = face
2054
+ cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 3)
2055
+ cv2.putText(vis_frame, "FACE", (x, max(20, y - 10)),
2056
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
2057
+
2058
+ if roi_type == "auto":
2059
+ roi, _ = pick_auto_roi(face, frame, attn=last_attention)
2060
+ else:
2061
+ roi = crop_roi(face, roi_type, frame)
2062
+
2063
+ if roi is not None and roi.size > 0:
2064
+ try:
2065
+ g_mean = float(roi[..., 1].astype(np.float32).mean())
2066
+ raw_g_means.append(g_mean)
2067
+ except Exception:
2068
+ pass
2069
+
2070
+ chw = normalize_frame(roi, DEFAULT_SIZE)
2071
+ frames_chw.append(chw)
2072
+
2073
+ if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
2074
+ try:
2075
+ clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
2076
+ except Exception as e:
2077
+ print(f"[infer] clip stack failed: {e}")
2078
+ clip = None
2079
+
2080
+ bvp_out = None
2081
+ if clip is not None:
2082
+ clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
2083
+ try:
2084
+ raw = forward_bvp(model, clip_t)
2085
+ if isinstance(raw, np.ndarray):
2086
+ raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
2087
+ bvp_out = raw if raw.size > 0 else None
2088
+ else:
2089
+ bvp_out = None
2090
+ except Exception as e:
2091
+ print(f"[infer] forward_bvp error: {e}")
2092
+ bvp_out = None
2093
+
2094
+ try:
2095
+ last_attention = extract_attention_map(model, clip_t, attention_viz)
2096
+ except Exception as e:
2097
+ print(f" Attention generation error: {e}")
2098
+ last_attention = None
2099
+
2100
+ if bvp_out is None or bvp_out.size == 0:
2101
+ gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
2102
+ fb = _fallback_bvp_from_means(gbuf, fs=fps)
2103
+ if isinstance(fb, np.ndarray) and fb.size > 0:
2104
+ bvp_out = fb
2105
+ else:
2106
+ print("[infer] fallback produced empty output")
2107
+
2108
+ if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
2109
+ tail = min(DEFAULT_STRIDE, bvp_out.size)
2110
+ bvp_stream.extend(bvp_out[-tail:].tolist())
2111
+ if len(bvp_stream) > MAX_SIGNAL_LENGTH:
2112
+ bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
2113
+
2114
+ if len(bvp_stream) >= int(5 * fps):
2115
+ seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
2116
+ _, last_bpm = postprocess_bvp(seg, fs=fps)
2117
+ last_rmssd = compute_rmssd(seg, fs=fps)
2118
+
2119
+ if frame_idx % (DEFAULT_STRIDE * 2) == 0:
2120
+ print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
2121
+ else:
2122
+ print("[infer] no usable bvp_out after fallback")
2123
+
2124
+ last_infer = frame_idx
2125
+
2126
+ else:
2127
+ cv2.putText(vis_frame, "No face detected", (20, 40),
2128
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
2129
+
2130
+ cv2.putText(vis_frame, f"Fill: {len(frames_chw)}/{DEFAULT_T}", (20, 25),
2131
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
2132
+ cv2.putText(vis_frame, f"BVP: {len(bvp_stream)}", (20, 45),
2133
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
2134
+
2135
+ if last_bpm > 0:
2136
+ color = (50, 255, 50) if 55 <= last_bpm <= 100 else (50, 50, 255) if last_bpm > 100 else (255, 200, 50)
2137
+ overlay = vis_frame.copy()
2138
+ cv2.rectangle(overlay, (10, 10), (450, 100), (0, 0, 0), -1)
2139
+ vis_frame = cv2.addWeighted(vis_frame, 0.6, overlay, 0.4, 0)
2140
+ cv2.putText(vis_frame, f"HR: {last_bpm:.0f} BPM", (20, 80),
2141
+ cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 3)
2142
+ cv2.circle(vis_frame, (30, 30), 10, (0, 255, 0), -1)
2143
+ cv2.putText(vis_frame, "LIVE", (50, 38),
2144
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
2145
+ else:
2146
+ cv2.putText(vis_frame, "Collecting…", (20, 80),
2147
+ cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
2148
+
2149
+ vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
2150
+
2151
+ cv2.imwrite(str(frame_path), vis_frame)
2152
+ cv2.imwrite(str(attention_path), vis_attention)
2153
+
2154
+ now = time.time()
2155
+ if now >= next_display:
2156
+ if len(bvp_stream) >= max(10, int(1 * fps)):
2157
+ try:
2158
+ sig = np.array(bvp_stream, dtype=float)
2159
+ t_axis = np.arange(len(sig)) / fps
2160
+ raw_sig = bandpass_filter(sig, fs=fps)
2161
+ post_sig, _ = postprocess_bvp(sig, fs=fps)
2162
+
2163
+ rv = raw_sig - np.mean(raw_sig)
2164
+ pv = post_sig - np.mean(post_sig)
2165
+
2166
+ plt.figure(figsize=(10, 4), dpi=100)
2167
+ plt.plot(t_axis, rv, linewidth=2)
2168
+ plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
2169
+ plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
2170
+ plt.grid(True, alpha=0.3)
2171
+ plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
2172
+ plt.tight_layout(); plt.savefig(str(raw_path)); plt.close()
2173
+
2174
+ plt.figure(figsize=(10, 4), dpi=100)
2175
+ plt.plot(t_axis, pv, linewidth=2)
2176
+ plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
2177
+ plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
2178
+ plt.grid(True, alpha=0.3)
2179
+ plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
2180
+ plt.tight_layout(); plt.savefig(str(post_path)); plt.close()
2181
+
2182
+ fig = plt.figure(figsize=(12, 5), dpi=100)
2183
+ from matplotlib.gridspec import GridSpec
2184
+ gs = GridSpec(1, 2, figure=fig, wspace=0.3)
2185
+ ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(t_axis, rv, linewidth=2)
2186
+ ax1.set_title('Raw Signal'); ax1.grid(True, alpha=0.3)
2187
+ ax1.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
2188
+ ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(t_axis, pv, linewidth=2)
2189
+ ax2.set_title('Post-processed'); ax2.grid(True, alpha=0.3)
2190
+ ax2.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
2191
+ plt.suptitle(f'LIVE rPPG - HR: {last_bpm:.1f} BPM')
2192
+ plt.savefig(str(signal_path)); plt.close('all')
2193
+ except Exception:
2194
+ pass
2195
+
2196
+ elapsed = now - t0
2197
+ status = (f"Frame {frame_idx} | Fill {len(frames_chw)}/{DEFAULT_T} | "
2198
+ f"BVP {len(bvp_stream)} | HR {last_bpm:.1f} BPM | "
2199
+ f"Time {int(elapsed)}s")
2200
+
2201
+ yield (
2202
+ status,
2203
+ f"{last_bpm:.1f}" if last_bpm > 0 else None,
2204
+ "--",
2205
+ f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
2206
+ str(frame_path),
2207
+ str(attention_path),
2208
+ str(signal_path) if signal_path.exists() else None,
2209
+ str(raw_path) if raw_path.exists() else None,
2210
+ str(post_path) if post_path.exists() else None,
2211
+ None
2212
+ )
2213
+ next_display = now + DISPLAY_INTERVAL
2214
+
2215
+ time.sleep(0.01)
2216
+
2217
+ cap.release()
2218
+
2219
+ # Cleanup Grad-CAM
2220
+ if attention_viz: attention_viz.cleanup()
2221
+
2222
+ csv_path = None
2223
+ if bvp_stream:
2224
+ csv_path = Path(tempfile.gettempdir()) / "live_bvp.csv"
2225
+ t = np.arange(len(bvp_stream)) / fps
2226
+ sig_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
2227
+ pd.DataFrame({"time_s": t, "bvp": sig_final}).to_csv(csv_path, index=False)
2228
+
2229
+ elapsed = time.time() - t0
2230
+ final_status = f"Session ended | Frames {frame_idx} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
2231
+ yield (
2232
+ final_status,
2233
+ f"{last_bpm:.1f}" if last_bpm > 0 else None,
2234
+ "--",
2235
+ f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
2236
+ str(frame_path),
2237
+ str(attention_path),
2238
+ str(signal_path) if signal_path.exists() else None,
2239
+ str(raw_path) if raw_path.exists() else None,
2240
+ str(post_path) if post_path.exists() else None,
2241
+ str(csv_path) if csv_path else None
2242
+ )
2243
+
2244
+ def process_stream(
2245
+ input_source: str,
2246
+ video_path: Optional[str],
2247
+ gt_file: Optional[str],
2248
+ model_name: str,
2249
+ fps_input: int,
2250
+ max_seconds: int,
2251
+ roi_type: str,
2252
+ control_id: str
2253
+ ):
2254
+ if input_source == "Live Webcam":
2255
+ yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
2256
+ else:
2257
+ yield from process_video_file(video_path, gt_file, model_name, fps_input,
2258
+ max_seconds, roi_type, control_id)
2259
+
2260
+ def pause_processing(control_id: str) -> str:
2261
+ _, controls = ensure_controls(control_id)
2262
+ controls['pause'].set()
2263
+ return "Paused"
2264
+
2265
+ def resume_processing(control_id: str) -> str:
2266
+ _, controls = ensure_controls(control_id)
2267
+ controls['pause'].clear()
2268
+ return "Resumed"
2269
+
2270
+ def stop_processing(control_id: str) -> str:
2271
+ _, controls = ensure_controls(control_id)
2272
+ controls['stop'].set()
2273
+ controls['pause'].clear()
2274
+ return "Stopped"
2275
+
2276
+ def reset_ui():
2277
+ return ("Ready", None, None, None, None, None, None, None, None, None)
2278
+
2279
+ def handle_folder_upload(files):
2280
+ if not files:
2281
+ return None, None, "No files uploaded"
2282
+
2283
+ if not isinstance(files, list):
2284
+ files = [files]
2285
+
2286
+ video_path = None
2287
+ for file_obj in files:
2288
+ file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
2289
+ if file_path.suffix.lower() in VIDEO_EXTENSIONS:
2290
+ video_path = str(file_path)
2291
+ break
2292
+
2293
+ gt_path = None
2294
+ gt_patterns = ['gtdump.txt', 'ground_truth.txt', 'gt.txt']
2295
+
2296
+ for file_obj in files:
2297
+ file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
2298
+ if file_path.name.lower() in [p.lower() for p in gt_patterns]:
2299
+ gt_path = str(file_path)
2300
+ break
2301
+
2302
+ if not gt_path:
2303
+ for file_obj in files:
2304
+ file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
2305
+ if file_path.suffix.lower() in ['.txt', '.json', '.csv']:
2306
+ if 'readme' not in file_path.name.lower():
2307
+ gt_path = str(file_path)
2308
+ break
2309
+
2310
+ status = []
2311
+ if video_path:
2312
+ status.append(f"Video: {Path(video_path).name}")
2313
+ else:
2314
+ status.append("No video found")
2315
+
2316
+ if gt_path:
2317
+ status.append(f"GT: {Path(gt_path).name}")
2318
+ else:
2319
+ status.append("No ground truth")
2320
+
2321
+ return video_path, gt_path, " | ".join(status)
2322
+
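Illustrative outcome (not in the committed file): for an uploaded folder containing ["vid.avi", "ground_truth.txt", "readme.txt"], and assuming ".avi" is listed in VIDEO_EXTENSIONS, the handler would return
("/tmp/.../vid.avi", "/tmp/.../ground_truth.txt", "Video: vid.avi | GT: ground_truth.txt")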
2323
+ with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as demo:
2324
+ gr.Markdown("# rPPG Analysis Tool with Attention Visualization")
2325
+
2326
+ with gr.Row():
2327
+ input_source = gr.Radio(
2328
+ choices=["Video File", "Subject Folder", "Live Webcam"],
2329
+ value="Live Webcam",
2330
+ label="Input Source"
2331
+ )
2332
+
2333
+ with gr.Row():
2334
+ target_layer_input = gr.Textbox(
2335
+ label="Grad-CAM Target Layer (optional - leave empty for auto)",
2336
+ placeholder="e.g., backbone.conv3, encoder.layer2",
2337
+ value=""
2338
+ )
2339
+
2340
+ with gr.Row(visible=False) as video_inputs:
2341
+ with gr.Column():
2342
+ video_upload = gr.Video(label="Upload Video", sources=["upload"])
2343
+ with gr.Column():
2344
+ gt_upload = gr.File(label="Ground Truth (optional)",
2345
+ file_types=[".txt", ".csv", ".json"])
2346
+
2347
+ with gr.Row(visible=False) as folder_inputs:
2348
+ with gr.Column():
2349
+ folder_upload = gr.File(
2350
+ label="Upload Subject Folder",
2351
+ file_count="directory",
2352
+ file_types=None,
2353
+ type="filepath"
2354
+ )
2355
+ folder_status = gr.Textbox(label="Folder Status", interactive=False)
2356
+
2357
+ with gr.Row(visible=True) as webcam_inputs:
2358
+ gr.Markdown("### Live Webcam - Click Run to start")
2359
+
2360
+ def toggle_input_source(source):
2361
+ return [
2362
+ gr.update(visible=(source == "Video File")),
2363
+ gr.update(visible=(source == "Subject Folder")),
2364
+ gr.update(visible=(source == "Live Webcam"))
2365
+ ]
2366
+
2367
+ input_source.change(
2368
+ toggle_input_source,
2369
+ inputs=[input_source],
2370
+ outputs=[video_inputs, folder_inputs, webcam_inputs]
2371
+ )
2372
+
2373
+ folder_video = gr.State(None)
2374
+ folder_gt = gr.State(None)
2375
+
2376
+ folder_upload.upload(
2377
+ handle_folder_upload,
2378
+ inputs=[folder_upload],
2379
+ outputs=[folder_video, folder_gt, folder_status]
2380
+ )
2381
+
2382
+ with gr.Row():
2383
+ with gr.Column(scale=3):
2384
+ model_dropdown = gr.Dropdown(
2385
+ choices=scan_models(),
2386
+ value=scan_models()[0] if scan_models() else None,
2387
+ label="PhysMamba Model",
2388
+ interactive=True
2389
+ )
2390
+ with gr.Column(scale=1):
2391
+ refresh_models_btn = gr.Button("Refresh", variant="secondary")
2392
+
2393
+ refresh_models_btn.click(
2394
+ lambda: gr.update(choices=scan_models()),
2395
+ inputs=None,
2396
+ outputs=[model_dropdown]
2397
+ )
2398
+
2399
+ with gr.Row():
2400
+ fps_slider = gr.Slider(
2401
+ minimum=10, maximum=120, value=30, step=5,
2402
+ label="FPS"
2403
+ )
2404
+ max_seconds_slider = gr.Slider(
2405
+ minimum=10, maximum=600, value=180, step=10,
2406
+ label="Max Duration (s)"
2407
+ )
2408
+
2409
+ with gr.Row():
2410
+ roi_dropdown = gr.Dropdown(choices=["auto","forehead","cheeks","face"], value="auto", label="ROI")
2411
+
2412
+ control_state = gr.State(value="")
2413
+ placeholder_state = gr.State(value=None)
2414
+
2415
+ with gr.Row():
2416
+ run_btn = gr.Button("Run", variant="primary")
2417
+ pause_btn = gr.Button("Pause", variant="secondary")
2418
+ resume_btn = gr.Button("Resume", variant="secondary")
2419
+ stop_btn = gr.Button("Stop", variant="stop")
2420
+
2421
+ status_text = gr.Textbox(label="Status", lines=2, value="Ready")
2422
+
2423
+ with gr.Row():
2424
+ hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
2425
+ gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
2426
+ rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
2427
+
2428
+ with gr.Row():
2429
+ with gr.Column():
2430
+ frame_output = gr.Image(label="Video Feed", type="filepath")
2431
+ with gr.Column():
2432
+ attention_output = gr.Image(label="Attention Map", type="filepath")
2433
+
2434
+ with gr.Row():
2435
+ signal_output = gr.Image(label="Signal Comparison", type="filepath")
2436
+
2437
+ with gr.Row():
2438
+ with gr.Column():
2439
+ raw_signal_output = gr.Image(label="Raw Signal", type="filepath")
2440
+ with gr.Column():
2441
+ post_signal_output = gr.Image(label="Post-processed Signal", type="filepath")
2442
+
2443
+ with gr.Row():
2444
+ csv_output = gr.File(label="Download CSV")
2445
+
2446
+ pause_btn.click(
2447
+ pause_processing,
2448
+ inputs=[control_state],
2449
+ outputs=[status_text]
2450
+ )
2451
+
2452
+ resume_btn.click(
2453
+ resume_processing,
2454
+ inputs=[control_state],
2455
+ outputs=[status_text]
2456
+ )
2457
+
2458
+ stop_btn.click(
2459
+ stop_processing,
2460
+ inputs=[control_state],
2461
+ outputs=[status_text]
2462
+ ).then(
2463
+ reset_ui,
2464
+ inputs=None,
2465
+ outputs=[status_text, hr_output, gt_hr_output, rmssd_output,
2466
+ frame_output, attention_output, signal_output,
2467
+ raw_signal_output, post_signal_output, csv_output]
2468
+ )
2469
+
2470
+ def run_processing(input_source, video_upload, gt_upload, folder_video, folder_gt,
2471
+ model_name, fps, max_sec, roi, ctrl_id):
2472
+ """Fixed version that handles model_name type conversion."""
2473
+
2474
+ if isinstance(model_name, int):
2475
+ model_name = str(model_name)
2476
+
2477
+ if not model_name:
2478
+ yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
2479
+ return
2480
+
2481
+ if input_source == "Video File":
2482
+ video_path = _as_path(video_upload)
2483
+ gt_file = _as_path(gt_upload)
2484
+ elif input_source == "Subject Folder":
2485
+ video_path = _as_path(folder_video)
2486
+ gt_file = _as_path(folder_gt)
2487
+ else: # Live Webcam
2488
+ video_path, gt_file = None, None
2489
+
2490
+ yield from process_stream(
2491
+ input_source, video_path, gt_file,
2492
+ model_name, fps, max_sec, roi, ctrl_id
2493
+ )
2494
+
2495
+
2496
+ run_btn.click(
2497
+ fn=run_processing,
2498
+ inputs=[
2499
+ input_source,
2500
+ video_upload,
2501
+ gt_upload,
2502
+ folder_video,
2503
+ folder_gt,
2504
+ model_dropdown,
2505
+ fps_slider,
2506
+ max_seconds_slider,
2507
+ roi_dropdown,
2508
+ control_state
2509
+ ],
2510
+ outputs=[
2511
+ status_text,
2512
+ hr_output,
2513
+ gt_hr_output,
2514
+ rmssd_output,
2515
+ frame_output,
2516
+ attention_output,
2517
+ signal_output,
2518
+ raw_signal_output,
2519
+ post_signal_output,
2520
+ csv_output
2521
+ ]
2522
+ )
2523
+
2524
2551
+
2552
+ if __name__ == "__main__":
2553
+ demo.queue(max_size=10).launch(
2554
+ server_name="127.0.0.1",
2555
+ server_port=7861,
2556
+ share=False,
2557
+ show_error=True,
2558
+ inbrowser=True
2559
+ )
final_model_release/PURE_PhysMamba_DiffNormalized.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2da6b9bc7e8728c20743c8f359eb43eefd2a96665098a29302fed2b09bc6b7d7
3
+ size 3013798
final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd18f227ff18f5140f471dc78e61426d9344ecc43677b35ec16b3199f1dccd19
3
+ size 3013798
mamba_ssm/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch.nn.functional as F
2
+
3
+ import torch, torch.nn as nn, torch.nn.functional as F
4
+ print('[mamba_ssm shim] Loaded CPU shim')
5
+ class Mamba(nn.Module):
6
+ def __init__(self, d_model, d_state=16, d_conv=4, expand=2, conv_bias=True, **kwargs):
7
+ super().__init__()
8
+ h = d_model * expand
9
+ self.in_proj = nn.Linear(d_model, 2*h)
10
+ self.dw = nn.Conv1d(h, h, d_conv, padding=d_conv-1, groups=h, bias=conv_bias)
11
+ self.mix = nn.Conv1d(h, h, 1)
12
+ self.out = nn.Linear(h, d_model)
13
+ self.d_model = d_model
14
+ def forward(self, x):
15
+ B,L,C = x.shape
16
+ u,v = self.in_proj(x).chunk(2, dim=-1)
17
+ y = F.silu(u) * torch.sigmoid(v)
18
+ y = self.dw(y.transpose(1,2))[...,:L]
19
+ y = F.silu(self.mix(y)).transpose(1,2)
20
+ return self.out(y)
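A quick shape check for the CPU shim above (illustrative only, not part of the commit): the gated depthwise-conv approximation keeps the (batch, length, d_model) layout of the real Mamba block.
import torch
m = Mamba(d_model=64)            # the shim class defined above
x = torch.randn(2, 96, 64)       # (B, L, D)
assert m(x).shape == x.shape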
mamba_ssm/models/__init__.py ADDED
File without changes
mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ from collections import namedtuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
12
+ from mamba_ssm.utils.generation import GenerationMixin
13
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
14
+
15
+ try:
16
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
17
+ except ImportError:
18
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
19
+
20
+
21
+ def create_block(
22
+ d_model,
23
+ ssm_cfg=None,
24
+ norm_epsilon=1e-5,
25
+ rms_norm=False,
26
+ residual_in_fp32=False,
27
+ fused_add_norm=False,
28
+ layer_idx=None,
29
+ device=None,
30
+ dtype=None,
31
+ ):
32
+ if ssm_cfg is None:
33
+ ssm_cfg = {}
34
+ factory_kwargs = {"device": device, "dtype": dtype}
35
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
36
+ norm_cls = partial(
37
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
38
+ )
39
+ block = Block(
40
+ d_model,
41
+ mixer_cls,
42
+ norm_cls=norm_cls,
43
+ fused_add_norm=fused_add_norm,
44
+ residual_in_fp32=residual_in_fp32,
45
+ )
46
+ block.layer_idx = layer_idx
47
+ return block
48
+
49
+
50
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
51
+ def _init_weights(
52
+ module,
53
+ n_layer,
54
+ initializer_range=0.02, # Now only used for embedding layer.
55
+ rescale_prenorm_residual=True,
56
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
57
+ ):
58
+ if isinstance(module, nn.Linear):
59
+ if module.bias is not None:
60
+ if not getattr(module.bias, "_no_reinit", False):
61
+ nn.init.zeros_(module.bias)
62
+ elif isinstance(module, nn.Embedding):
63
+ nn.init.normal_(module.weight, std=initializer_range)
64
+
65
+ if rescale_prenorm_residual:
66
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
67
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
68
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
69
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
70
+ #
71
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
72
+ for name, p in module.named_parameters():
73
+ if name in ["out_proj.weight", "fc2.weight"]:
74
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
75
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
76
+ # We need to reinit p since this code could be called multiple times
77
+ # Having just p *= scale would repeatedly scale it down
78
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
79
+ with torch.no_grad():
80
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
81
+
82
+
83
+ class MixerModel(nn.Module):
84
+ def __init__(
85
+ self,
86
+ d_model: int,
87
+ n_layer: int,
88
+ vocab_size: int,
89
+ ssm_cfg=None,
90
+ norm_epsilon: float = 1e-5,
91
+ rms_norm: bool = False,
92
+ initializer_cfg=None,
93
+ fused_add_norm=False,
94
+ residual_in_fp32=False,
95
+ device=None,
96
+ dtype=None,
97
+ ) -> None:
98
+ factory_kwargs = {"device": device, "dtype": dtype}
99
+ super().__init__()
100
+ self.residual_in_fp32 = residual_in_fp32
101
+
102
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
103
+
104
+ # We change the order of residual and layer norm:
105
+ # Instead of LN -> Attn / MLP -> Add, we do:
106
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
107
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
108
+ # This is for performance reason: we can fuse add + layer_norm.
109
+ self.fused_add_norm = fused_add_norm
110
+ if self.fused_add_norm:
111
+ if layer_norm_fn is None or rms_norm_fn is None:
112
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
113
+
114
+ self.layers = nn.ModuleList(
115
+ [
116
+ create_block(
117
+ d_model,
118
+ ssm_cfg=ssm_cfg,
119
+ norm_epsilon=norm_epsilon,
120
+ rms_norm=rms_norm,
121
+ residual_in_fp32=residual_in_fp32,
122
+ fused_add_norm=fused_add_norm,
123
+ layer_idx=i,
124
+ **factory_kwargs,
125
+ )
126
+ for i in range(n_layer)
127
+ ]
128
+ )
129
+
130
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
131
+ d_model, eps=norm_epsilon, **factory_kwargs
132
+ )
133
+
134
+ self.apply(
135
+ partial(
136
+ _init_weights,
137
+ n_layer=n_layer,
138
+ **(initializer_cfg if initializer_cfg is not None else {}),
139
+ )
140
+ )
141
+
142
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
143
+ return {
144
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
145
+ for i, layer in enumerate(self.layers)
146
+ }
147
+
148
+ def forward(self, input_ids, inference_params=None):
149
+ hidden_states = self.embedding(input_ids)
150
+ residual = None
151
+ for layer in self.layers:
152
+ hidden_states, residual = layer(
153
+ hidden_states, residual, inference_params=inference_params
154
+ )
155
+ if not self.fused_add_norm:
156
+ residual = (hidden_states + residual) if residual is not None else hidden_states
157
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
158
+ else:
159
+ # Set prenorm=False here since we don't need the residual
160
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
161
+ hidden_states = fused_add_norm_fn(
162
+ hidden_states,
163
+ self.norm_f.weight,
164
+ self.norm_f.bias,
165
+ eps=self.norm_f.eps,
166
+ residual=residual,
167
+ prenorm=False,
168
+ residual_in_fp32=self.residual_in_fp32,
169
+ )
170
+ return hidden_states
171
+
172
+
173
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
174
+
175
+ def __init__(
176
+ self,
177
+ d_model: int,
178
+ n_layer: int,
179
+ vocab_size: int,
180
+ initializer_cfg=None,
181
+ pad_vocab_size_multiple: int = 1,
182
+ device=None,
183
+ dtype=None,
184
+ **backbone_kwargs,
185
+ ) -> None:
186
+ factory_kwargs = {"device": device, "dtype": dtype}
187
+ super().__init__()
188
+ if vocab_size % pad_vocab_size_multiple != 0:
189
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
190
+ self.backbone = MixerModel(
191
+ d_model=d_model,
192
+ n_layer=n_layer,
193
+ vocab_size=vocab_size,
194
+ initializer_cfg=initializer_cfg,
195
+ **backbone_kwargs,
196
+ **factory_kwargs,
197
+ )
198
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
199
+
200
+ # Initialize weights and apply final processing
201
+ self.apply(
202
+ partial(
203
+ _init_weights,
204
+ n_layer=n_layer,
205
+ **(initializer_cfg if initializer_cfg is not None else {}),
206
+ )
207
+ )
208
+ self.tie_weights()
209
+
210
+ def tie_weights(self):
211
+ self.lm_head.weight = self.backbone.embedding.weight
212
+
213
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
214
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
215
+
216
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
217
+ """
218
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
219
+ num_last_tokens: if > 0, only return the logits for the last n tokens
220
+ """
221
+ hidden_states = self.backbone(input_ids, inference_params=inference_params)
222
+ if num_last_tokens > 0:
223
+ hidden_states = hidden_states[:, -num_last_tokens:]
224
+ lm_logits = self.lm_head(hidden_states)
225
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
226
+ return CausalLMOutput(logits=lm_logits)
227
+
228
+ @classmethod
229
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
230
+ config = load_config_hf(pretrained_model_name)
231
+ model = cls(**config, device=device, dtype=dtype, **kwargs)
232
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
233
+ return model
mamba_ssm/modules/__init__.py ADDED
File without changes
mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,418 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ try:
14
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
15
+ except ImportError:
16
+ causal_conv1d_fn, causal_conv1d_update = None, None
17
+
18
+ try:
19
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
20
+ except ImportError:
21
+ selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None
22
+
23
+ try:
24
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
25
+ except ImportError:
26
+ selective_state_update = None
27
+
28
+ try:
29
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
30
+ except ImportError:
31
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
32
+
33
+
34
+ class Mamba(nn.Module):
35
+ def __init__(
36
+ self,
37
+ d_model,
38
+ d_state=16,
39
+ d_conv=4,
40
+ expand=2,
41
+ dt_rank="auto",
42
+ dt_min=0.001,
43
+ dt_max=0.1,
44
+ dt_init="random",
45
+ dt_scale=1.0,
46
+ dt_init_floor=1e-4,
47
+ conv_bias=True,
48
+ bias=False,
49
+ use_fast_path=True, # Fused kernel options
50
+ layer_idx=None,
51
+ device=None,
52
+ dtype=None,
53
+ bimamba=True,
54
+ ):
55
+ factory_kwargs = {"device": device, "dtype": dtype}
56
+ super().__init__()
57
+ self.d_model = d_model
58
+ self.d_state = d_state
59
+ self.d_conv = d_conv
60
+ self.expand = expand
61
+ self.d_inner = int(self.expand * self.d_model)
62
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
63
+ self.use_fast_path = use_fast_path
64
+ self.layer_idx = layer_idx
65
+ self.bimamba = bimamba
66
+
67
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
68
+
69
+ self.conv1d = nn.Conv1d(
70
+ in_channels=self.d_inner,
71
+ out_channels=self.d_inner,
72
+ bias=conv_bias,
73
+ kernel_size=d_conv,
74
+ groups=self.d_inner,
75
+ padding=d_conv - 1,
76
+ **factory_kwargs,
77
+ )
78
+
79
+ self.activation = "silu"
80
+ self.act = nn.SiLU()
81
+
82
+ self.x_proj = nn.Linear(
83
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
84
+ )
85
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
86
+
87
+ # Initialize special dt projection to preserve variance at initialization
88
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
89
+ if dt_init == "constant":
90
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
91
+ elif dt_init == "random":
92
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
97
+ dt = torch.exp(
98
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ ).clamp(min=dt_init_floor)
101
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
103
+ with torch.no_grad():
104
+ self.dt_proj.bias.copy_(inv_dt)
105
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
+ self.dt_proj.bias._no_reinit = True
107
+
108
+ # S4D real initialization
109
+ # NOTE: why plus 1?
110
+ A = repeat(
111
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
112
+ "n -> d n",
113
+ d=self.d_inner,
114
+ ).contiguous()
115
+ A_log = torch.log(A) # Keep A_log in fp32
116
+ self.A_log = nn.Parameter(A_log)
117
+ self.A_log._no_weight_decay = True
118
+
119
+ # D "skip" parameter
120
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
121
+ self.D._no_weight_decay = True
122
+
123
+ # bidirectional
124
+ # forked from https://github.com/hustvl/Vim
125
+ if self.bimamba:
126
+ A_b = repeat(
127
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
128
+ "n -> d n",
129
+ d=self.d_inner,
130
+ ).contiguous()
131
+ A_b_log = torch.log(A_b) # Keep A_b_log in fp32
132
+ self.A_b_log = nn.Parameter(A_b_log)
133
+ self.A_b_log._no_weight_decay = True
134
+
135
+ self.conv1d_b = nn.Conv1d(
136
+ in_channels=self.d_inner,
137
+ out_channels=self.d_inner,
138
+ bias=conv_bias,
139
+ kernel_size=d_conv,
140
+ groups=self.d_inner,
141
+ padding=d_conv - 1,
142
+ **factory_kwargs,
143
+ )
144
+
145
+ self.x_proj_b = nn.Linear(
146
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
147
+ )
148
+ self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
149
+
150
+ self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
151
+ self.D_b._no_weight_decay = True
152
+
153
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
154
+
155
+ def forward(self, hidden_states, inference_params=None, T=1):
156
+ """
157
+ hidden_states: (B, L, D)
158
+ Returns: same shape as hidden_states
159
+ """
160
+ batch, seqlen, dim = hidden_states.shape
161
+
162
+ conv_state, ssm_state = None, None
163
+ if inference_params is not None:
164
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
165
+ if inference_params.seqlen_offset > 0:
166
+ # The states are updated inplace
167
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
168
+ return out
169
+
170
+ # We do matmul and transpose BLH -> HBL at the same time
171
+ # NOTE: same as in_proj(hidden_states) but memory-efficient with the following operations
172
+ xz = rearrange(
173
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
174
+ "d (b l) -> b d l",
175
+ l=seqlen,
176
+ )
177
+ if self.in_proj.bias is not None:
178
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
179
+
180
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
181
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
182
+ if self.use_fast_path and inference_params is None: # Doesn't support outputting the states
183
+ if self.bimamba:
184
+ A_b = -torch.exp(self.A_b_log.float())
185
+ out = mamba_inner_fn_no_out_proj(
186
+ xz,
187
+ self.conv1d.weight,
188
+ self.conv1d.bias,
189
+ self.x_proj.weight,
190
+ self.dt_proj.weight,
191
+ A,
192
+ None, # input-dependent B
193
+ None, # input-dependent C
194
+ self.D.float(),
195
+ delta_bias=self.dt_proj.bias.float(),
196
+ delta_softplus=True,
197
+ )
198
+ out_b = mamba_inner_fn_no_out_proj(
199
+ xz.flip([-1]),
200
+ self.conv1d_b.weight,
201
+ self.conv1d_b.bias,
202
+ self.x_proj_b.weight,
203
+ self.dt_proj_b.weight,
204
+ A_b,
205
+ None,
206
+ None,
207
+ self.D_b.float(),
208
+ delta_bias=self.dt_proj_b.bias.float(),
209
+ delta_softplus=True,
210
+ )
211
+ out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
212
+ else:
213
+ out = mamba_inner_fn(
214
+ xz,
215
+ self.conv1d.weight,
216
+ self.conv1d.bias,
217
+ self.x_proj.weight,
218
+ self.dt_proj.weight,
219
+ self.out_proj.weight,
220
+ self.out_proj.bias,
221
+ A,
222
+ None, # input-dependent B
223
+ None, # input-dependent C
224
+ self.D.float(),
225
+ delta_bias=self.dt_proj.bias.float(),
226
+ delta_softplus=True,
227
+ )
228
+ else:
229
+ x, z = xz.chunk(2, dim=1)
230
+ # Compute short convolution
231
+ if conv_state is not None:
232
+ conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W)
233
+ if causal_conv1d_fn is None:
234
+ x = self.act(self.conv1d(x)[..., :seqlen])
235
+ else:
236
+ assert self.activation in ["silu", "swish"]
237
+ x = causal_conv1d_fn(
238
+ x,
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias,
241
+ self.activation,
242
+ )
243
+
244
+ # We're careful here about the layout, to avoid extra transposes.
245
+ # We want dt to have d as the slowest moving dimension
246
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
247
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
248
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
249
+ dt = self.dt_proj.weight @ dt.t()
250
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
251
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
252
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
253
+ assert self.activation in ["silu", "swish"]
254
+ y = selective_scan_fn(
255
+ x,
256
+ dt,
257
+ A,
258
+ B,
259
+ C,
260
+ self.D.float(),
261
+ z=z,
262
+ delta_bias=self.dt_proj.bias.float(),
263
+ delta_softplus=True,
264
+ return_last_state=ssm_state is not None,
265
+ )
266
+ if ssm_state is not None:
267
+ y, last_state = y
268
+ ssm_state.copy_(last_state)
269
+ y = rearrange(y, "b d l -> b l d")
270
+ out = self.out_proj(y)
271
+ return out
272
+
273
+ def step(self, hidden_states, conv_state, ssm_state):
274
+ dtype = hidden_states.dtype
275
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
276
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
277
+ x, z = xz.chunk(2, dim=-1) # (B D)
278
+
279
+ # Conv step
280
+ if causal_conv1d_update is None:
281
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
282
+ conv_state[:, :, -1] = x
283
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
284
+ if self.conv1d.bias is not None:
285
+ x = x + self.conv1d.bias
286
+ x = self.act(x).to(dtype=dtype)
287
+ else:
288
+ x = causal_conv1d_update(
289
+ x,
290
+ conv_state,
291
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
292
+ self.conv1d.bias,
293
+ self.activation,
294
+ )
295
+
296
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
297
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
298
+ # Don't add dt_bias here
299
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
300
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
301
+
302
+ # SSM step
303
+ if selective_state_update is None:
304
+ # Discretize A and B
305
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
306
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
307
+ dB = torch.einsum("bd,bn->bdn", dt, B)
308
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
309
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
310
+ y = y + self.D.to(dtype) * x
311
+ y = y * self.act(z) # (B D)
312
+ else:
313
+ y = selective_state_update(
314
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
315
+ )
316
+
317
+ out = self.out_proj(y)
318
+ return out.unsqueeze(1), conv_state, ssm_state
319
+
320
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
321
+ device = self.out_proj.weight.device
322
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
323
+ conv_state = torch.zeros(
324
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
325
+ )
326
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
327
+ # ssm_dtype = torch.float32
328
+ ssm_state = torch.zeros(
329
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
330
+ )
331
+ return conv_state, ssm_state
332
+
333
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
334
+ assert self.layer_idx is not None
335
+ if self.layer_idx not in inference_params.key_value_memory_dict:
336
+ batch_shape = (batch_size,)
337
+ conv_state = torch.zeros(
338
+ batch_size,
339
+ self.d_model * self.expand,
340
+ self.d_conv,
341
+ device=self.conv1d.weight.device,
342
+ dtype=self.conv1d.weight.dtype,
343
+ )
344
+ ssm_state = torch.zeros(
345
+ batch_size,
346
+ self.d_model * self.expand,
347
+ self.d_state,
348
+ device=self.dt_proj.weight.device,
349
+ dtype=self.dt_proj.weight.dtype,
350
+ # dtype=torch.float32,
351
+ )
352
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
353
+ else:
354
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
355
+ # TODO: What if batch size changes between generation, and we reuse the same states?
356
+ if initialize_states:
357
+ conv_state.zero_()
358
+ ssm_state.zero_()
359
+ return conv_state, ssm_state
360
+
361
+
362
+ class Block(nn.Module):
363
+ def __init__(
364
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
365
+ ):
366
+ """
367
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
368
+
369
+ This Block has a slightly different structure compared to a regular
370
+ prenorm Transformer block.
371
+ The standard block is: LN -> MHA/MLP -> Add.
372
+ [Ref: https://arxiv.org/abs/2002.04745]
373
+ Here we have: Add -> LN -> Mixer, returning both
374
+ the hidden_states (output of the mixer) and the residual.
375
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
376
+ The residual needs to be provided (except for the very first block).
377
+ """
378
+ super().__init__()
379
+ self.residual_in_fp32 = residual_in_fp32
380
+ self.fused_add_norm = fused_add_norm
381
+ self.mixer = mixer_cls(dim)
382
+ self.norm = norm_cls(dim)
383
+ if self.fused_add_norm:
384
+ assert RMSNorm is not None, "RMSNorm import fails"
385
+ assert isinstance(
386
+ self.norm, (nn.LayerNorm, RMSNorm)
387
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
388
+
389
+ def forward(
390
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
391
+ ):
392
+ r"""Pass the input through the encoder layer.
393
+
394
+ Args:
395
+ hidden_states: the sequence to the encoder layer (required).
396
+ residual: hidden_states = Mixer(LN(residual))
397
+ """
398
+ if not self.fused_add_norm:
399
+ residual = (hidden_states + residual) if residual is not None else hidden_states
400
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
401
+ if self.residual_in_fp32:
402
+ residual = residual.to(torch.float32)
403
+ else:
404
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
405
+ hidden_states, residual = fused_add_norm_fn(
406
+ hidden_states,
407
+ self.norm.weight,
408
+ self.norm.bias,
409
+ residual=residual,
410
+ prenorm=True,
411
+ residual_in_fp32=self.residual_in_fp32,
412
+ eps=self.norm.eps,
413
+ )
414
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
415
+ return hidden_states, residual
416
+
417
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
418
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
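
Not part of the file above, but as a reading aid: the fused `mamba_inner_fn` / `selective_scan_fn` calls hide the actual recurrence, so below is a minimal pure-PyTorch sketch of the discretized selective scan those kernels compute, following the shapes used in the `step` method (dt and x: (batch, d_inner, L); A: (d_inner, d_state); B, C: (batch, d_state, L)). The function name is illustrative only.

import torch
import torch.nn.functional as F

def selective_scan_sketch(x, dt, A, B, C, D, z=None):
    # Illustrative O(L) reference loop, not the fused kernel.
    # dt is assumed already positive (the real path applies softplus(dt + dt_bias)).
    batch, d_inner, L = x.shape
    state = x.new_zeros(batch, d_inner, A.shape[1])
    ys = []
    for t in range(L):
        dA = torch.exp(dt[:, :, t, None] * A)        # discretized state matrix
        dB = dt[:, :, t, None] * B[:, None, :, t]    # discretized, input-dependent B
        state = state * dA + dB * x[:, :, t, None]   # hidden-state update
        ys.append((state * C[:, None, :, t]).sum(-1) + D * x[:, :, t])
    y = torch.stack(ys, dim=-1)                      # (batch, d_inner, L)
    return y * F.silu(z) if z is not None else y     # SiLU gate, matching z above
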
mamba_ssm/ops/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+
2
+ import torch
3
+ def selective_scan(*args, **kwargs):
4
+ for a in args:
5
+ if isinstance(a, torch.Tensor):
6
+ return a
7
+ return torch.tensor(0.0)
mamba_ssm/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,17 @@
1
+
2
+ import torch
3
+ def selective_scan_fn(*args, **kwargs):
4
+ for a in args:
5
+ if isinstance(a, torch.Tensor):
6
+ return a
7
+ return torch.tensor(0.0)
8
+ def mamba_inner_fn(*args, **kwargs):
9
+ for a in args:
10
+ if isinstance(a, torch.Tensor):
11
+ return a
12
+ return torch.tensor(0.0)
13
+ def bimamba_inner_fn(*args, **kwargs):
14
+ for a in args:
15
+ if isinstance(a, torch.Tensor):
16
+ return a
17
+ return torch.tensor(0.0)
mamba_ssm/ops/triton/__init__.py ADDED
File without changes
mamba_ssm/ops/triton/layernorm.py ADDED
@@ -0,0 +1,636 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Implement residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.cuda.amp import custom_fwd, custom_bwd
14
+
15
+ import triton
16
+ import triton.language as tl
17
+
18
+
19
+ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
20
+ dtype = x.dtype
21
+ if upcast:
22
+ weight = weight.float()
23
+ bias = bias.float() if bias is not None else None
24
+ if upcast:
25
+ x = x.float()
26
+ residual = residual.float() if residual is not None else residual
27
+ if residual is not None:
28
+ x = (x + residual).to(x.dtype)
29
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
30
+ dtype
31
+ )
32
+ return out if not prenorm else (out, x)
33
+
34
+
35
+ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ weight = weight.float()
39
+ bias = bias.float() if bias is not None else None
40
+ if upcast:
41
+ x = x.float()
42
+ residual = residual.float() if residual is not None else residual
43
+ if residual is not None:
44
+ x = (x + residual).to(x.dtype)
45
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
46
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
47
+ out = out.to(dtype)
48
+ return out if not prenorm else (out, x)
49
+
50
+
51
+ @triton.autotune(
52
+ configs=[
53
+ triton.Config({}, num_warps=1),
54
+ triton.Config({}, num_warps=2),
55
+ triton.Config({}, num_warps=4),
56
+ triton.Config({}, num_warps=8),
57
+ triton.Config({}, num_warps=16),
58
+ triton.Config({}, num_warps=32),
59
+ ],
60
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
61
+ )
62
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
63
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
64
+ @triton.jit
65
+ def _layer_norm_fwd_1pass_kernel(
66
+ X, # pointer to the input
67
+ Y, # pointer to the output
68
+ W, # pointer to the weights
69
+ B, # pointer to the biases
70
+ RESIDUAL, # pointer to the residual
71
+ RESIDUAL_OUT, # pointer to the residual
72
+ Mean, # pointer to the mean
73
+ Rstd, # pointer to the 1/std
74
+ stride_x_row, # how much to increase the pointer when moving by 1 row
75
+ stride_y_row,
76
+ stride_res_row,
77
+ stride_res_out_row,
78
+ N, # number of columns in X
79
+ eps, # epsilon to avoid division by zero
80
+ IS_RMS_NORM: tl.constexpr,
81
+ BLOCK_N: tl.constexpr,
82
+ HAS_RESIDUAL: tl.constexpr,
83
+ STORE_RESIDUAL_OUT: tl.constexpr,
84
+ HAS_BIAS: tl.constexpr,
85
+ ):
86
+ # Map the program id to the row of X and Y it should compute.
87
+ row = tl.program_id(0)
88
+ X += row * stride_x_row
89
+ Y += row * stride_y_row
90
+ if HAS_RESIDUAL:
91
+ RESIDUAL += row * stride_res_row
92
+ if STORE_RESIDUAL_OUT:
93
+ RESIDUAL_OUT += row * stride_res_out_row
94
+ # Compute mean and variance
95
+ cols = tl.arange(0, BLOCK_N)
96
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
97
+ if HAS_RESIDUAL:
98
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
99
+ x += residual
100
+ if STORE_RESIDUAL_OUT:
101
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
102
+ if not IS_RMS_NORM:
103
+ mean = tl.sum(x, axis=0) / N
104
+ tl.store(Mean + row, mean)
105
+ xbar = tl.where(cols < N, x - mean, 0.0)
106
+ var = tl.sum(xbar * xbar, axis=0) / N
107
+ else:
108
+ xbar = tl.where(cols < N, x, 0.0)
109
+ var = tl.sum(xbar * xbar, axis=0) / N
110
+ rstd = 1 / tl.sqrt(var + eps)
111
+ tl.store(Rstd + row, rstd)
112
+ # Normalize and apply linear transformation
113
+ mask = cols < N
114
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
115
+ if HAS_BIAS:
116
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
117
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
118
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
119
+ # Write output
120
+ tl.store(Y + cols, y, mask=mask)
121
+
122
+
123
+ def _layer_norm_fwd(
124
+ x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
125
+ ):
126
+ if residual is not None:
127
+ residual_dtype = residual.dtype
128
+ M, N = x.shape
129
+ assert x.stride(-1) == 1
130
+ if residual is not None:
131
+ assert residual.stride(-1) == 1
132
+ assert residual.shape == (M, N)
133
+ assert weight.shape == (N,)
134
+ assert weight.stride(-1) == 1
135
+ if bias is not None:
136
+ assert bias.stride(-1) == 1
137
+ assert bias.shape == (N,)
138
+ # allocate output
139
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
140
+ assert y.stride(-1) == 1
141
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
142
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
143
+ assert residual_out.stride(-1) == 1
144
+ else:
145
+ residual_out = None
146
+ mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
147
+ rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
148
+ # Less than 64KB per feature: enqueue fused kernel
149
+ MAX_FUSED_SIZE = 65536 // x.element_size()
150
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
151
+ if N > BLOCK_N:
152
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
153
+ # heuristics for number of warps
154
+ with torch.cuda.device(x.device.index):
155
+ _layer_norm_fwd_1pass_kernel[(M,)](
156
+ x,
157
+ y,
158
+ weight,
159
+ bias,
160
+ residual,
161
+ residual_out,
162
+ mean,
163
+ rstd,
164
+ x.stride(0),
165
+ y.stride(0),
166
+ residual.stride(0) if residual is not None else 0,
167
+ residual_out.stride(0) if residual_out is not None else 0,
168
+ N,
169
+ eps,
170
+ is_rms_norm,
171
+ BLOCK_N,
172
+ residual is not None,
173
+ residual_out is not None,
174
+ bias is not None,
175
+ )
176
+ # residual_out is None if residual is None and residual_dtype == input_dtype
177
+ return y, mean, rstd, residual_out if residual_out is not None else x
178
+
179
+
180
+ @triton.autotune(
181
+ configs=[
182
+ triton.Config({}, num_warps=1),
183
+ triton.Config({}, num_warps=2),
184
+ triton.Config({}, num_warps=4),
185
+ triton.Config({}, num_warps=8),
186
+ triton.Config({}, num_warps=16),
187
+ triton.Config({}, num_warps=32),
188
+ ],
189
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
190
+ )
191
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
192
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
193
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
194
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
195
+ @triton.jit
196
+ def _layer_norm_bwd_kernel(
197
+ X, # pointer to the input
198
+ W, # pointer to the weights
199
+ B, # pointer to the biases
200
+ Y, # pointer to the output to be recomputed
201
+ DY, # pointer to the output gradient
202
+ DX, # pointer to the input gradient
203
+ DW, # pointer to the partial sum of weights gradient
204
+ DB, # pointer to the partial sum of biases gradient
205
+ DRESIDUAL,
206
+ DRESIDUAL_IN,
207
+ Mean, # pointer to the mean
208
+ Rstd, # pointer to the 1/std
209
+ stride_x_row, # how much to increase the pointer when moving by 1 row
210
+ stride_y_row,
211
+ stride_dy_row,
212
+ stride_dx_row,
213
+ stride_dres_row,
214
+ stride_dres_in_row,
215
+ M, # number of rows in X
216
+ N, # number of columns in X
217
+ eps, # epsilon to avoid division by zero
218
+ rows_per_program,
219
+ IS_RMS_NORM: tl.constexpr,
220
+ BLOCK_N: tl.constexpr,
221
+ HAS_DRESIDUAL: tl.constexpr,
222
+ STORE_DRESIDUAL: tl.constexpr,
223
+ HAS_BIAS: tl.constexpr,
224
+ RECOMPUTE_OUTPUT: tl.constexpr,
225
+ ):
226
+ # Map the program id to the elements of X, DX, and DY it should compute.
227
+ row_block_id = tl.program_id(0)
228
+ row_start = row_block_id * rows_per_program
229
+ cols = tl.arange(0, BLOCK_N)
230
+ mask = cols < N
231
+ X += row_start * stride_x_row
232
+ if HAS_DRESIDUAL:
233
+ DRESIDUAL += row_start * stride_dres_row
234
+ if STORE_DRESIDUAL:
235
+ DRESIDUAL_IN += row_start * stride_dres_in_row
236
+ DY += row_start * stride_dy_row
237
+ DX += row_start * stride_dx_row
238
+ if RECOMPUTE_OUTPUT:
239
+ Y += row_start * stride_y_row
240
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
241
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
242
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
243
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
244
+ if HAS_BIAS:
245
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
246
+ row_end = min((row_block_id + 1) * rows_per_program, M)
247
+ for row in range(row_start, row_end):
248
+ # Load data to SRAM
249
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
250
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
251
+ if not IS_RMS_NORM:
252
+ mean = tl.load(Mean + row)
253
+ rstd = tl.load(Rstd + row)
254
+ # Compute dx
255
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
256
+ xhat = tl.where(mask, xhat, 0.0)
257
+ if RECOMPUTE_OUTPUT:
258
+ y = xhat * w + b if HAS_BIAS else xhat * w
259
+ tl.store(Y + cols, y, mask=mask)
260
+ wdy = w * dy
261
+ dw += dy * xhat
262
+ if HAS_BIAS:
263
+ db += dy
264
+ if not IS_RMS_NORM:
265
+ c1 = tl.sum(xhat * wdy, axis=0) / N
266
+ c2 = tl.sum(wdy, axis=0) / N
267
+ dx = (wdy - (xhat * c1 + c2)) * rstd
268
+ else:
269
+ c1 = tl.sum(xhat * wdy, axis=0) / N
270
+ dx = (wdy - xhat * c1) * rstd
271
+ if HAS_DRESIDUAL:
272
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
273
+ dx += dres
274
+ # Write dx
275
+ if STORE_DRESIDUAL:
276
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
277
+ tl.store(DX + cols, dx, mask=mask)
278
+
279
+ X += stride_x_row
280
+ if HAS_DRESIDUAL:
281
+ DRESIDUAL += stride_dres_row
282
+ if STORE_DRESIDUAL:
283
+ DRESIDUAL_IN += stride_dres_in_row
284
+ if RECOMPUTE_OUTPUT:
285
+ Y += stride_y_row
286
+ DY += stride_dy_row
287
+ DX += stride_dx_row
288
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
289
+ if HAS_BIAS:
290
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
291
+
292
+
293
+ def _layer_norm_bwd(
294
+ dy,
295
+ x,
296
+ weight,
297
+ bias,
298
+ eps,
299
+ mean,
300
+ rstd,
301
+ dresidual=None,
302
+ has_residual=False,
303
+ is_rms_norm=False,
304
+ x_dtype=None,
305
+ recompute_output=False,
306
+ ):
307
+ M, N = x.shape
308
+ assert x.stride(-1) == 1
309
+ assert dy.stride(-1) == 1
310
+ assert dy.shape == (M, N)
311
+ if dresidual is not None:
312
+ assert dresidual.stride(-1) == 1
313
+ assert dresidual.shape == (M, N)
314
+ assert weight.shape == (N,)
315
+ assert weight.stride(-1) == 1
316
+ if bias is not None:
317
+ assert bias.stride(-1) == 1
318
+ assert bias.shape == (N,)
319
+ # allocate output
320
+ dx = (
321
+ torch.empty_like(x)
322
+ if x_dtype is None
323
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
324
+ )
325
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
326
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
327
+
328
+ # Less than 64KB per feature: enqueue fused kernel
329
+ MAX_FUSED_SIZE = 65536 // x.element_size()
330
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
331
+ if N > BLOCK_N:
332
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
333
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
334
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
335
+ _db = (
336
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
337
+ if bias is not None
338
+ else None
339
+ )
340
+ rows_per_program = math.ceil(M / sm_count)
341
+ grid = (sm_count,)
342
+ with torch.cuda.device(x.device.index):
343
+ _layer_norm_bwd_kernel[grid](
344
+ x,
345
+ weight,
346
+ bias,
347
+ y,
348
+ dy,
349
+ dx,
350
+ _dw,
351
+ _db,
352
+ dresidual,
353
+ dresidual_in,
354
+ mean,
355
+ rstd,
356
+ x.stride(0),
357
+ 0 if not recompute_output else y.stride(0),
358
+ dy.stride(0),
359
+ dx.stride(0),
360
+ dresidual.stride(0) if dresidual is not None else 0,
361
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
362
+ M,
363
+ N,
364
+ eps,
365
+ rows_per_program,
366
+ is_rms_norm,
367
+ BLOCK_N,
368
+ dresidual is not None,
369
+ dresidual_in is not None,
370
+ bias is not None,
371
+ )
372
+ dw = _dw.sum(0).to(weight.dtype)
373
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
374
+ # Don't need to compute dresidual_in separately in this case
375
+ if has_residual and dx.dtype == x.dtype:
376
+ dresidual_in = dx
377
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
378
+
379
+
380
+ class LayerNormFn(torch.autograd.Function):
381
+ @staticmethod
382
+ def forward(
383
+ ctx,
384
+ x,
385
+ weight,
386
+ bias,
387
+ residual=None,
388
+ eps=1e-6,
389
+ prenorm=False,
390
+ residual_in_fp32=False,
391
+ is_rms_norm=False,
392
+ ):
393
+ x_shape_og = x.shape
394
+ # reshape input data into 2D tensor
395
+ x = x.reshape(-1, x.shape[-1])
396
+ if x.stride(-1) != 1:
397
+ x = x.contiguous()
398
+ if residual is not None:
399
+ assert residual.shape == x_shape_og
400
+ residual = residual.reshape(-1, residual.shape[-1])
401
+ if residual.stride(-1) != 1:
402
+ residual = residual.contiguous()
403
+ weight = weight.contiguous()
404
+ if bias is not None:
405
+ bias = bias.contiguous()
406
+ residual_dtype = (
407
+ residual.dtype
408
+ if residual is not None
409
+ else (torch.float32 if residual_in_fp32 else None)
410
+ )
411
+ y, mean, rstd, residual_out = _layer_norm_fwd(
412
+ x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
413
+ )
414
+ ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
415
+ ctx.x_shape_og = x_shape_og
416
+ ctx.eps = eps
417
+ ctx.is_rms_norm = is_rms_norm
418
+ ctx.has_residual = residual is not None
419
+ ctx.prenorm = prenorm
420
+ ctx.x_dtype = x.dtype
421
+ y = y.reshape(x_shape_og)
422
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
423
+
424
+ @staticmethod
425
+ def backward(ctx, dy, *args):
426
+ x, weight, bias, mean, rstd = ctx.saved_tensors
427
+ dy = dy.reshape(-1, dy.shape[-1])
428
+ if dy.stride(-1) != 1:
429
+ dy = dy.contiguous()
430
+ assert dy.shape == x.shape
431
+ if ctx.prenorm:
432
+ dresidual = args[0]
433
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
434
+ if dresidual.stride(-1) != 1:
435
+ dresidual = dresidual.contiguous()
436
+ assert dresidual.shape == x.shape
437
+ else:
438
+ dresidual = None
439
+ dx, dw, db, dresidual_in = _layer_norm_bwd(
440
+ dy,
441
+ x,
442
+ weight,
443
+ bias,
444
+ ctx.eps,
445
+ mean,
446
+ rstd,
447
+ dresidual,
448
+ ctx.has_residual,
449
+ ctx.is_rms_norm,
450
+ x_dtype=ctx.x_dtype,
451
+ )
452
+ return (
453
+ dx.reshape(ctx.x_shape_og),
454
+ dw,
455
+ db,
456
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
457
+ None,
458
+ None,
459
+ None,
460
+ None,
461
+ )
462
+
463
+
464
+ def layer_norm_fn(
465
+ x,
466
+ weight,
467
+ bias,
468
+ residual=None,
469
+ eps=1e-6,
470
+ prenorm=False,
471
+ residual_in_fp32=False,
472
+ is_rms_norm=False,
473
+ ):
474
+ return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
475
+
476
+
477
+ def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
478
+ return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
479
+
480
+
481
+ class RMSNorm(torch.nn.Module):
482
+ def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
483
+ factory_kwargs = {"device": device, "dtype": dtype}
484
+ super().__init__()
485
+ self.eps = eps
486
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
487
+ self.register_parameter("bias", None)
488
+ self.reset_parameters()
489
+
490
+ def reset_parameters(self):
491
+ torch.nn.init.ones_(self.weight)
492
+
493
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
494
+ return rms_norm_fn(
495
+ x,
496
+ self.weight,
497
+ self.bias,
498
+ residual=residual,
499
+ eps=self.eps,
500
+ prenorm=prenorm,
501
+ residual_in_fp32=residual_in_fp32,
502
+ is_rms_norm=True,
503
+ )
504
+
505
+
506
+ class LayerNormLinearFn(torch.autograd.Function):
507
+ @staticmethod
508
+ @custom_fwd
509
+ def forward(
510
+ ctx,
511
+ x,
512
+ norm_weight,
513
+ norm_bias,
514
+ linear_weight,
515
+ linear_bias,
516
+ residual=None,
517
+ eps=1e-6,
518
+ prenorm=False,
519
+ residual_in_fp32=False,
520
+ is_rms_norm=False,
521
+ ):
522
+ x_shape_og = x.shape
523
+ # reshape input data into 2D tensor
524
+ x = x.reshape(-1, x.shape[-1])
525
+ if x.stride(-1) != 1:
526
+ x = x.contiguous()
527
+ if residual is not None:
528
+ assert residual.shape == x_shape_og
529
+ residual = residual.reshape(-1, residual.shape[-1])
530
+ if residual.stride(-1) != 1:
531
+ residual = residual.contiguous()
532
+ norm_weight = norm_weight.contiguous()
533
+ if norm_bias is not None:
534
+ norm_bias = norm_bias.contiguous()
535
+ residual_dtype = (
536
+ residual.dtype
537
+ if residual is not None
538
+ else (torch.float32 if residual_in_fp32 else None)
539
+ )
540
+ y, mean, rstd, residual_out = _layer_norm_fwd(
541
+ x,
542
+ norm_weight,
543
+ norm_bias,
544
+ eps,
545
+ residual,
546
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
547
+ residual_dtype=residual_dtype,
548
+ is_rms_norm=is_rms_norm,
549
+ )
550
+ y = y.reshape(x_shape_og)
551
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
552
+ linear_weight = linear_weight.to(dtype)
553
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
554
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
555
+ # We don't store y, will be recomputed in the backward pass to save memory
556
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
557
+ ctx.x_shape_og = x_shape_og
558
+ ctx.eps = eps
559
+ ctx.is_rms_norm = is_rms_norm
560
+ ctx.has_residual = residual is not None
561
+ ctx.prenorm = prenorm
562
+ ctx.x_dtype = x.dtype
563
+ ctx.linear_bias_is_none = linear_bias is None
564
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
565
+
566
+ @staticmethod
567
+ @custom_bwd
568
+ def backward(ctx, dout, *args):
569
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
570
+ dout = dout.reshape(-1, dout.shape[-1])
571
+ dy = F.linear(dout, linear_weight.t())
572
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
573
+ if dy.stride(-1) != 1:
574
+ dy = dy.contiguous()
575
+ assert dy.shape == x.shape
576
+ if ctx.prenorm:
577
+ dresidual = args[0]
578
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
579
+ if dresidual.stride(-1) != 1:
580
+ dresidual = dresidual.contiguous()
581
+ assert dresidual.shape == x.shape
582
+ else:
583
+ dresidual = None
584
+ dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
585
+ dy,
586
+ x,
587
+ norm_weight,
588
+ norm_bias,
589
+ ctx.eps,
590
+ mean,
591
+ rstd,
592
+ dresidual,
593
+ ctx.has_residual,
594
+ ctx.is_rms_norm,
595
+ x_dtype=ctx.x_dtype,
596
+ recompute_output=True,
597
+ )
598
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
599
+ return (
600
+ dx.reshape(ctx.x_shape_og),
601
+ dnorm_weight,
602
+ dnorm_bias,
603
+ dlinear_weight,
604
+ dlinear_bias,
605
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
606
+ None,
607
+ None,
608
+ None,
609
+ None,
610
+ )
611
+
612
+
613
+ def layer_norm_linear_fn(
614
+ x,
615
+ norm_weight,
616
+ norm_bias,
617
+ linear_weight,
618
+ linear_bias,
619
+ residual=None,
620
+ eps=1e-6,
621
+ prenorm=False,
622
+ residual_in_fp32=False,
623
+ is_rms_norm=False,
624
+ ):
625
+ return LayerNormLinearFn.apply(
626
+ x,
627
+ norm_weight,
628
+ norm_bias,
629
+ linear_weight,
630
+ linear_bias,
631
+ residual,
632
+ eps,
633
+ prenorm,
634
+ residual_in_fp32,
635
+ is_rms_norm,
636
+ )
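
To make the prenorm contract above concrete, here is a small CPU-only sketch (plain torch, no Triton) of what `layer_norm_fn(..., prenorm=True)` returns and how consecutive blocks thread the residual through it; the helper name is illustrative, not part of this module.

import torch
import torch.nn.functional as F

def add_norm_prenorm(x, residual, weight, bias, eps=1e-6):
    # residual_out = x + residual (or just x for the first block), then normalize.
    residual_out = x if residual is None else x + residual
    out = F.layer_norm(residual_out, residual_out.shape[-1:], weight, bias, eps)
    return out, residual_out      # same (out, residual) pair as layer_norm_fn with prenorm=True

w, b = torch.ones(96), torch.zeros(96)
h1, res1 = add_norm_prenorm(torch.randn(4, 96), None, w, b)   # first block: residual is None
h2, res2 = add_norm_prenorm(torch.randn(4, 96), res1, w, b)   # later blocks reuse the residual
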
mamba_ssm/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ """We want triton==2.1.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19
+ @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20
+ @triton.jit
21
+ def _selective_scan_update_kernel(
22
+ # Pointers to matrices
23
+ state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24
+ # Matrix dimensions
25
+ batch, dim, dstate,
26
+ # Strides
27
+ stride_state_batch, stride_state_dim, stride_state_dstate,
28
+ stride_x_batch, stride_x_dim,
29
+ stride_dt_batch, stride_dt_dim,
30
+ stride_dt_bias_dim,
31
+ stride_A_dim, stride_A_dstate,
32
+ stride_B_batch, stride_B_dstate,
33
+ stride_C_batch, stride_C_dstate,
34
+ stride_D_dim,
35
+ stride_z_batch, stride_z_dim,
36
+ stride_out_batch, stride_out_dim,
37
+ # Meta-parameters
38
+ DT_SOFTPLUS: tl.constexpr,
39
+ BLOCK_SIZE_M: tl.constexpr,
40
+ HAS_DT_BIAS: tl.constexpr,
41
+ HAS_D: tl.constexpr,
42
+ HAS_Z: tl.constexpr,
43
+ BLOCK_SIZE_DSTATE: tl.constexpr,
44
+ ):
45
+ pid_m = tl.program_id(axis=0)
46
+ pid_b = tl.program_id(axis=1)
47
+ state_ptr += pid_b * stride_state_batch
48
+ x_ptr += pid_b * stride_x_batch
49
+ dt_ptr += pid_b * stride_dt_batch
50
+ B_ptr += pid_b * stride_B_batch
51
+ C_ptr += pid_b * stride_C_batch
52
+ if HAS_Z:
53
+ z_ptr += pid_b * stride_z_batch
54
+ out_ptr += pid_b * stride_out_batch
55
+
56
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
57
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
58
+ state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
59
+ x_ptrs = x_ptr + offs_m * stride_x_dim
60
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
61
+ if HAS_DT_BIAS:
62
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
63
+ A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
64
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
65
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
66
+ if HAS_D:
67
+ D_ptrs = D_ptr + offs_m * stride_D_dim
68
+ if HAS_Z:
69
+ z_ptrs = z_ptr + offs_m * stride_z_dim
70
+ out_ptrs = out_ptr + offs_m * stride_out_dim
71
+
72
+ state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
73
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
74
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
75
+ if HAS_DT_BIAS:
76
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
77
+ if DT_SOFTPLUS:
78
+ dt = tl.log(1.0 + tl.exp(dt))
79
+ A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
80
+ dA = tl.exp(A * dt[:, None])
81
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
82
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
83
+ if HAS_D:
84
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85
+ if HAS_Z:
86
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
87
+
88
+ dB = B[None, :] * dt[:, None]
89
+ state = state * dA + dB * x[:, None]
90
+ tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
91
+ out = tl.sum(state * C[None, :], axis=1)
92
+ if HAS_D:
93
+ out += x * D
94
+ if HAS_Z:
95
+ out *= z * tl.sigmoid(z)
96
+ tl.store(out_ptrs, out, mask=offs_m < dim)
97
+
98
+
99
+ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
100
+ """
101
+ Argument:
102
+ state: (batch, dim, dstate)
103
+ x: (batch, dim)
104
+ dt: (batch, dim)
105
+ A: (dim, dstate)
106
+ B: (batch, dstate)
107
+ C: (batch, dstate)
108
+ D: (dim,)
109
+ z: (batch, dim)
110
+ dt_bias: (dim,)
111
+ Return:
112
+ out: (batch, dim)
113
+ """
114
+ batch, dim, dstate = state.shape
115
+ assert x.shape == (batch, dim)
116
+ assert dt.shape == x.shape
117
+ assert A.shape == (dim, dstate)
118
+ assert B.shape == (batch, dstate)
119
+ assert C.shape == B.shape
120
+ if D is not None:
121
+ assert D.shape == (dim,)
122
+ if z is not None:
123
+ assert z.shape == x.shape
124
+ if dt_bias is not None:
125
+ assert dt_bias.shape == (dim,)
126
+ out = torch.empty_like(x)
127
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
128
+ z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
129
+ # We don't want autotune since it will overwrite the state
130
+ # We instead tune by hand.
131
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
132
+ else ((16, 4) if dstate <= 32 else
133
+ ((8, 4) if dstate <= 64 else
134
+ ((4, 4) if dstate <= 128 else
135
+ ((4, 8))))))
136
+ with torch.cuda.device(x.device.index):
137
+ _selective_scan_update_kernel[grid](
138
+ state, x, dt, dt_bias, A, B, C, D, z, out,
139
+ batch, dim, dstate,
140
+ state.stride(0), state.stride(1), state.stride(2),
141
+ x.stride(0), x.stride(1),
142
+ dt.stride(0), dt.stride(1),
143
+ dt_bias.stride(0) if dt_bias is not None else 0,
144
+ A.stride(0), A.stride(1),
145
+ B.stride(0), B.stride(1),
146
+ C.stride(0), C.stride(1),
147
+ D.stride(0) if D is not None else 0,
148
+ z_strides[0], z_strides[1],
149
+ out.stride(0), out.stride(1),
150
+ dt_softplus,
151
+ BLOCK_SIZE_M,
152
+ num_warps=num_warps,
153
+ )
154
+ return out
155
+
156
+
157
+ def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
158
+ """
159
+ Argument:
160
+ state: (batch, dim, dstate)
161
+ x: (batch, dim)
162
+ dt: (batch, dim)
163
+ A: (dim, dstate)
164
+ B: (batch, dstate)
165
+ C: (batch, dstate)
166
+ D: (dim,)
167
+ z: (batch, dim)
168
+ dt_bias: (dim,)
169
+ Return:
170
+ out: (batch, dim)
171
+ """
172
+ batch, dim, dstate = state.shape
173
+ assert x.shape == (batch, dim)
174
+ assert dt.shape == x.shape
175
+ assert A.shape == (dim, dstate)
176
+ assert B.shape == (batch, dstate)
177
+ assert C.shape == B.shape
178
+ if D is not None:
179
+ assert D.shape == (dim,)
180
+ if z is not None:
181
+ assert z.shape == x.shape
182
+ if dt_bias is not None:
183
+ assert dt_bias.shape == (dim,)
184
+ dt = dt + dt_bias
185
+ dt = F.softplus(dt) if dt_softplus else dt
186
+ dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
187
+ dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
188
+ state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate)
189
+ out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
190
+ if D is not None:
191
+ out += (x * D).to(out.dtype)
192
+ return (out if z is None else out * F.silu(z)).to(x.dtype)
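
A shape-level usage sketch for the single-token update above, using the pure-PyTorch reference; this assumes Triton and einops are installed, since importing the module pulls them in.

import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, dim, dstate = 2, 64, 16
state = torch.zeros(batch, dim, dstate)
x, dt, z = (torch.randn(batch, dim) for _ in range(3))
A = -torch.rand(dim, dstate)                 # negative real A, matching the S4D-style init
B, C = torch.randn(batch, dstate), torch.randn(batch, dstate)
out = selective_state_update_ref(state, x, dt, A, B, C,
                                 D=torch.ones(dim), z=z,
                                 dt_bias=torch.zeros(dim), dt_softplus=True)
print(out.shape)                             # torch.Size([2, 64]); `state` is updated in place
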
mamba_ssm/utils/__init__.py ADDED
File without changes
mamba_ssm/utils/generation.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficiently calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
37
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
38
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
39
+ def modify_logits_for_top_k_filtering(logits, top_k):
40
+ """Set the logits for non-top-k values to -inf. Done in-place."""
41
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+
44
+
45
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
46
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
47
+ """Set the logits for non-top-p values to -inf. Done in-place."""
48
+ """Set the logits for none top-p values to -inf. Done in-place."""
49
+ if top_p <= 0.0 or top_p >= 1.0:
50
+ return
51
+ # First sort and calculate cumulative sum of probabilities.
52
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
53
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
54
+ # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
55
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
56
+ # scatter sorted tensors to original indexing
57
+ indices_to_remove = sorted_indices_to_remove.scatter(
58
+ 1, sorted_indices, sorted_indices_to_remove
59
+ )
60
+ logits.masked_fill_(indices_to_remove, float("-inf"))
61
+
62
+
63
+ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
64
+ """Sample from top-k logits.
65
+ Arguments:
66
+ logits: Tensor of shape (batch_size, vocab_size)
67
+ """
68
+ if top_k == 1: # Short-circuit for greedy decoding
69
+ return logits.argmax(dim=-1)
70
+ else:
71
+ if top_p > 0.0:
72
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
73
+ if top_k > 0:
74
+ top_k = min(top_k, logits.size(-1)) # Safety check
75
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
76
+ if temperature != 1.0:
77
+ logits_top /= temperature
78
+ modify_logits_for_top_p_filtering(logits_top, top_p)
79
+ return indices[
80
+ torch.arange(indices.shape[0], device=indices.device),
81
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
82
+ ]
83
+ else:
84
+ # Clone so that when we modify for top_p we don't change the original logits
85
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
86
+ modify_logits_for_top_p_filtering(logits_top, top_p)
87
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
88
+ dim=-1
89
+ )
90
+
91
+
92
+ @torch.inference_mode()
93
+ def decode(
94
+ input_ids,
95
+ model,
96
+ max_length,
97
+ top_k=1,
98
+ top_p=0.0,
99
+ temperature=1.0,
100
+ eos_token_id=None,
101
+ teacher_outputs=None,
102
+ vocab_size=None,
103
+ tensor_parallel=1,
104
+ cg=False,
105
+ enable_timing=False,
106
+ ):
107
+ """Decoding, either greedy or with top-k or top-p sampling.
108
+ If top-k = 0, don't limit the number of candidates (pure sampling).
109
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
110
+ then top-p.
111
+ We assume that all sequences in the same batch have the same length.
112
+
113
+ Arguments:
114
+ input_ids: (batch, seq_len)
115
+ max_length: int
116
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
117
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
118
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
119
+ sequences: (batch, max_length)
120
+ scores: tuples of (batch, vocab_size)
121
+ """
122
+ batch_size, seqlen_og = input_ids.shape
123
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
124
+ if cg:
125
+ if not hasattr(model, "_decoding_cache"):
126
+ model._decoding_cache = None
127
+ model._decoding_cache = update_graph_cache(
128
+ model,
129
+ model._decoding_cache,
130
+ batch_size,
131
+ seqlen_og,
132
+ max_length,
133
+ tensor_parallel=tensor_parallel,
134
+ )
135
+ inference_params = model._decoding_cache.inference_params
136
+ inference_params.reset(max_length, batch_size)
137
+ else:
138
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
139
+
140
+ def get_logits(input_ids, inference_params):
141
+ decoding = inference_params.seqlen_offset > 0
142
+ if decoding:
143
+ position_ids = torch.full(
144
+ (batch_size, 1),
145
+ inference_params.seqlen_offset,
146
+ dtype=torch.long,
147
+ device=input_ids.device,
148
+ )
149
+ else:
150
+ position_ids = None
151
+ if not cg or not decoding:
152
+ logits = model(
153
+ input_ids,
154
+ position_ids=position_ids,
155
+ inference_params=inference_params,
156
+ num_last_tokens=1,
157
+ ).logits.squeeze(dim=1)
158
+ else:
159
+ logits = model._decoding_cache.run(
160
+ input_ids, position_ids, inference_params.seqlen_offset
161
+ ).squeeze(dim=1)
162
+ return logits[..., :vocab_size] if vocab_size is not None else logits
163
+
164
+ def sample_tokens(logits, inference_params):
165
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
166
+ token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
167
+ else:
168
+ token = teacher_outputs[:, inference_params.seqlen_offset]
169
+ # return rearrange(token, "b -> b 1")
170
+ return token.unsqueeze(1)
171
+
172
+ def should_stop(current_token, inference_params):
173
+ if inference_params.seqlen_offset == 0:
174
+ return False
175
+ if eos_token_id is not None and (current_token == eos_token_id).all():
176
+ return True
177
+ if inference_params.seqlen_offset >= max_length - 1:
178
+ return True
179
+ return False
180
+
181
+ start = torch.cuda.Event(enable_timing=enable_timing)
182
+ end = torch.cuda.Event(enable_timing=enable_timing)
183
+
184
+ if enable_timing:
185
+ if tensor_parallel > 1:
186
+ torch.distributed.barrier()
187
+ start.record()
188
+ scores, sequences = [], [input_ids]
189
+ while not should_stop(sequences[-1], inference_params):
190
+ scores.append(get_logits(sequences[-1], inference_params))
191
+ inference_params.seqlen_offset += sequences[-1].shape[1]
192
+ sequences.append(sample_tokens(scores[-1], inference_params))
193
+ if enable_timing:
194
+ end.record()
195
+ if tensor_parallel > 1:
196
+ torch.distributed.barrier()
197
+ torch.cuda.synchronize()
198
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
199
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
200
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
201
+
202
+
203
+ class GenerationMixin:
204
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
205
+ raise NotImplementedError
206
+
207
+ def generate(
208
+ self,
209
+ input_ids,
210
+ max_length,
211
+ top_k=1,
212
+ top_p=0.0,
213
+ temperature=1.0,
214
+ return_dict_in_generate=False,
215
+ output_scores=False,
216
+ **kwargs,
217
+ ):
218
+ output = decode(
219
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
220
+ )
221
+ if not output_scores:
222
+ output.scores = None
223
+ return output if return_dict_in_generate else output.sequences
224
+
225
+
226
+ def allocate_inference_cache(
227
+ max_batch_size,
228
+ max_seqlen,
229
+ nheads,
230
+ headdim,
231
+ layers: Union[int, Sequence],
232
+ device,
233
+ dtype=torch.float16,
234
+ ):
235
+ assert dtype in [torch.float16, torch.bfloat16, torch.float32]
236
+ kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
237
+ if isinstance(layers, int):
238
+ layers = range(layers)
239
+ return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
240
+
241
+
242
+ @dataclass
243
+ class DecodingCGCache:
244
+ max_batch_size: int = 0
245
+ max_seqlen: int = 0
246
+ device = None
247
+ dtype = None
248
+ callables: dict = field(default_factory=dict)
249
+ mempool = None
250
+ inference_params: Optional[InferenceParams] = None
251
+ run: Optional[Callable] = None
252
+
253
+
254
+ @torch.inference_mode()
255
+ def update_graph_cache(
256
+ model,
257
+ cache,
258
+ batch_size,
259
+ seqlen_og,
260
+ max_seqlen,
261
+ decoding_seqlens=(1,),
262
+ tensor_parallel=1,
263
+ dtype=None,
264
+ n_warmups=2,
265
+ ):
266
+ if cache is None:
267
+ cache = DecodingCGCache()
268
+ param_example = next(iter(model.parameters()))
269
+ device = param_example.device
270
+ if dtype is None:
271
+ dtype = param_example.dtype
272
+ if (
273
+ (device, dtype) != (cache.device, cache.dtype)
274
+ or batch_size > cache.max_batch_size
275
+ or max_seqlen > cache.max_seqlen
276
+ ): # Invalidate the cache
277
+ cache.callables = {}
278
+ cache.mempool = None
279
+ cache.inference_params = None
280
+ gc.collect()
281
+ cache.device, cache.dtype = device, dtype
282
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
283
+ if hasattr(model, "allocate_inference_cache"):
284
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
285
+ else:
286
+ headdim = getattr(
287
+ model.config,
288
+ "head_dim",
289
+ model.config.hidden_size // model.config.num_attention_heads,
290
+ )
291
+ inf_cache = allocate_inference_cache(
292
+ batch_size,
293
+ max_seqlen,
294
+ model.config.num_attention_heads // tensor_parallel,
295
+ headdim,
296
+ model.config.num_hidden_layers,
297
+ device,
298
+ dtype,
299
+ )
300
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
301
+ cache.inference_params = InferenceParams(
302
+ max_seqlen=max_seqlen,
303
+ max_batch_size=batch_size,
304
+ seqlen_offset=seqlen_og,
305
+ key_value_memory_dict=inf_cache,
306
+ lengths_per_sample=lengths_per_sample,
307
+ )
308
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
309
+ for decoding_seqlen in decoding_seqlens:
310
+ if (batch_size, decoding_seqlen) not in cache.callables:
311
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
312
+ model,
313
+ cache.inference_params,
314
+ batch_size,
315
+ max_seqlen,
316
+ decoding_seqlen=decoding_seqlen,
317
+ mempool=cache.mempool,
318
+ n_warmups=n_warmups,
319
+ )
320
+
321
+ def dispatch(input_ids, position_ids, seqlen):
322
+ batch_size, decoding_seqlen = input_ids.shape[:2]
323
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
324
+
325
+ cache.run = dispatch
326
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
327
+ return cache
328
+
329
+
330
+ def capture_graph(
331
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
332
+ ):
333
+ device = next(iter(model.parameters())).device
334
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
335
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
336
+ seqlen_offset_og = inference_params.seqlen_offset
337
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
338
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
339
+
340
+ # Warmup before capture
341
+ s = torch.cuda.Stream()
342
+ s.wait_stream(torch.cuda.current_stream())
343
+ with torch.cuda.stream(s):
344
+ for _ in range(n_warmups):
345
+ logits = model(
346
+ input_ids,
347
+ position_ids=position_ids,
348
+ inference_params=inference_params,
349
+ num_last_tokens=decoding_seqlen,
350
+ ).logits
351
+ s.synchronize()
352
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
353
+ # which requires that graph launch and non-captured launch to not overlap (I think,
354
+ # that's how I interpret the documentation). I'm not sure if this is required.
355
+ if torch.distributed.is_initialized():
356
+ torch.distributed.barrier()
357
+ torch.cuda.current_stream().wait_stream(s)
358
+ # Captures the graph
359
+ # To allow capture, automatically sets a side stream as the current stream in the context
360
+ graph = torch.cuda.CUDAGraph()
361
+ with torch.cuda.graph(graph, pool=mempool):
362
+ logits = model(
363
+ input_ids,
364
+ position_ids=position_ids,
365
+ inference_params=inference_params,
366
+ num_last_tokens=decoding_seqlen,
367
+ ).logits
368
+
369
+ def run(new_input_ids, new_position_ids, seqlen):
370
+ inference_params.lengths_per_sample[:] = seqlen
371
+ input_ids.copy_(new_input_ids)
372
+ position_ids.copy_(new_position_ids)
373
+ graph.replay()
374
+ return logits.clone()
375
+
376
+ inference_params.seqlen_offset = seqlen_offset_og
377
+ return run
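
As a toy illustration of the nucleus (top-p) filtering that `sample` relies on, the snippet below mirrors `modify_logits_for_top_p_filtering` inline, so it runs with torch alone; the function name is illustrative.

import torch

def top_p_filter(logits, top_p):
    # Mask every token outside the smallest set whose probability mass reaches top_p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    remove = cumulative_probs <= (1 - top_p)
    remove = remove.scatter(1, sorted_indices, remove)
    return logits.masked_fill(remove, float("-inf"))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_p_filter(logits, top_p=0.9))       # the lowest-probability tokens become -inf
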
mamba_ssm/utils/hf.py ADDED
@@ -0,0 +1,23 @@
1
+ import json
2
+
3
+ import torch
4
+
5
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
+ from transformers.utils.hub import cached_file
7
+
8
+
9
+ def load_config_hf(model_name):
10
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11
+ return json.load(open(resolved_archive_file))
12
+
13
+
14
+ def load_state_dict_hf(model_name, device=None, dtype=None):
15
+ # If not fp32, then we don't want to load directly to the GPU
16
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18
+ state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19
+ # Convert dtype before moving to GPU to save memory
20
+ if dtype is not None:
21
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23
+ return state_dict
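
A sketch of how these helpers are typically called; the repository name is only an example of a Hub checkpoint that ships a config.json and pytorch_model.bin.

from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

repo = "state-spaces/mamba-130m"            # example checkpoint name
config = load_config_hf(repo)               # plain dict parsed from config.json
state_dict = load_state_dict_hf(repo, device="cpu", dtype=None)
print(len(state_dict), "tensors loaded")
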
neural_methods/__init__.py ADDED
File without changes
neural_methods/loss/NegPearsonLoss.py ADDED
@@ -0,0 +1,23 @@
1
+ from __future__ import print_function, division
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import argparse, os
5
+ import pandas as pd
6
+ import numpy as np
7
+ import random
8
+ import math
9
+ from torchvision import transforms
10
+ from torch import nn
11
+
12
+
13
+ class Neg_Pearson(nn.Module):
14
+ def __init__(self):
15
+ super(Neg_Pearson, self).__init__()
16
+ return
17
+
18
+ def forward(self, preds, labels):
19
+ cos = nn.CosineSimilarity(dim=0, eps=1e-6)
20
+ pearson = cos(preds - preds.mean(dim=0, keepdim=True), labels - labels.mean(dim=0, keepdim=True))
21
+ return torch.mean(1 - pearson)
22
+
23
+
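
A hedged usage sketch for the loss above (assuming the repo root is on the Python path); note that the temporal axis must come first, because the mean and cosine similarity are both taken along `dim=0`.

```python
import torch
from neural_methods.loss.NegPearsonLoss import Neg_Pearson

criterion = Neg_Pearson()
preds = torch.randn(160, 4)      # 160 time steps, 4 predicted rPPG signals
labels = torch.randn(160, 4)     # corresponding ground-truth BVP signals
loss = criterion(preds, labels)  # scalar in [0, 2]; 0 means perfect positive correlation
```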
neural_methods/loss/PhysFormerLossComputer.py ADDED
@@ -0,0 +1,120 @@
1
+ '''
2
+ Adapted from here: https://github.com/ZitongYu/PhysFormer/blob/main/TorchLossComputer.py
3
+ Modified based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn
4
+ '''
5
+ import math
6
+ import torch
7
+ from torch.autograd import Variable
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ import pdb
11
+ import torch.nn as nn
12
+
13
+ def normal_sampling(mean, label_k, std):
14
+ return math.exp(-(label_k-mean)**2/(2*std**2))/(math.sqrt(2*math.pi)*std)
15
+
16
+ def kl_loss(inputs, labels):
17
+ # Reshape the labels tensor to match the shape of inputs
18
+ labels = labels.view(1, -1)
19
+
20
+ # Compute the KL Div Loss
21
+ criterion = nn.KLDivLoss(reduction='sum')
22
+ loss = criterion(F.log_softmax(inputs, dim=-1), labels)
23
+ return loss
24
+
25
+ class TorchLossComputer(object):
26
+ @staticmethod
27
+ def compute_complex_absolute_given_k(output, k, N):
28
+ two_pi_n_over_N = torch.autograd.Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N
29
+ hanning = torch.autograd.Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1)
30
+
31
+ k = k.type(torch.FloatTensor).cuda()
32
+ two_pi_n_over_N = two_pi_n_over_N.cuda()
33
+ hanning = hanning.cuda()
34
+
35
+ output = output.view(1, -1) * hanning
36
+ output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
37
+ k = k.view(1, -1, 1)
38
+ two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
39
+ complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
40
+ + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2
41
+
42
+ return complex_absolute
43
+
44
+ @staticmethod
45
+ def complex_absolute(output, Fs, bpm_range=None):
46
+ output = output.view(1, -1)
47
+
48
+ N = output.size()[1]
49
+
50
+ unit_per_hz = Fs / N
51
+ feasible_bpm = bpm_range / 60.0
52
+ k = feasible_bpm / unit_per_hz
53
+
54
+ # only calculate feasible PSD range [0.7,4] Hz
55
+ complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)
56
+
57
+ return (1.0 / complex_absolute.sum()) * complex_absolute # Analogous Softmax operator
58
+
59
+ @staticmethod
60
+ def cross_entropy_power_spectrum_loss(inputs, target, Fs):
61
+ inputs = inputs.view(1, -1)
62
+ target = target.view(1, -1)
63
+ bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
64
+
65
+ complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
66
+
67
+ whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
68
+ whole_max_idx = whole_max_idx.type(torch.float)
69
+
70
+ return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)
71
+
72
+ @staticmethod
73
+ def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
74
+ inputs = inputs.view(1, -1)
75
+ target = target.view(1, -1)
76
+ bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
77
+
78
+ complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
79
+
80
+ whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
81
+ whole_max_idx = whole_max_idx.type(torch.float)
82
+
83
+ #pdb.set_trace()
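+ # NOTE: FocalLoss is not defined or imported in this file; this method will raise a NameError unless a FocalLoss implementation is provided elsewhere.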
84
+ criterion = FocalLoss(gamma=gamma)
85
+
86
+ return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)
87
+
88
+
89
+ @staticmethod
90
+ def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
91
+ inputs = inputs.view(1, -1)
92
+ bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()
93
+
94
+ complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
95
+
96
+ whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
97
+ whole_max_idx = whole_max_idx.type(torch.float)
98
+
99
+ return whole_max_idx
100
+
101
+ @staticmethod
102
+ def cross_entropy_power_spectrum_DLDL_softmax2(inputs, target, Fs, std):
103
+ target_distribution = [normal_sampling(int(target), i, std) for i in range(40, 180)]
104
+ target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
105
+ target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
106
+
107
+ inputs = inputs.view(1, -1)
108
+ target = target.view(1, -1)
109
+
110
+ bpm_range = torch.arange(40, 180, dtype=torch.float).to(torch.device('cuda'))
111
+
112
+ ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
113
+
114
+ fre_distribution = ca/torch.sum(ca)
115
+ loss_distribution_kl = kl_loss(fre_distribution, target_distribution)
116
+
117
+ whole_max_val, whole_max_idx = ca.view(-1).max(0)
118
+ whole_max_idx = whole_max_idx.type(torch.float)
119
+ return loss_distribution_kl, F.cross_entropy(ca, (target-bpm_range[0]).view(1).type(torch.long)), torch.abs(target[0]-bpm_range[0]-whole_max_idx)
120
+
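
For reference, `compute_complex_absolute_given_k` above evaluates, for every candidate heart rate, the squared magnitude of the Hann-windowed Fourier component of the predicted signal,

$$|X(k)|^2 = \Big(\sum_{n=0}^{N-1} w_n x_n \sin\frac{2\pi k n}{N}\Big)^2 + \Big(\sum_{n=0}^{N-1} w_n x_n \cos\frac{2\pi k n}{N}\Big)^2,$$

where $w_n$ is the Hann window and $k = (\mathrm{bpm}/60)/(F_s/N)$ converts a candidate rate in beats per minute into cycles per window. `complex_absolute` then normalizes these values to sum to one, so the DLDL loss compares this softmax-like distribution over 40-179 bpm against a Gaussian target distribution centred on the ground-truth heart rate.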
neural_methods/loss/PhysNetNegPearsonLoss.py ADDED
@@ -0,0 +1,43 @@
1
+ from __future__ import print_function, division
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import argparse, os
5
+ import pandas as pd
6
+ import numpy as np
7
+ import random
8
+ import math
9
+ from torchvision import transforms
10
+ from torch import nn
11
+
12
+
13
+ class Neg_Pearson(nn.Module):
14
+ """
15
+ The Neg_Pearson module is from the original author of PhysNet.
16
+ Code of 'Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks'
17
+ source: https://github.com/ZitongYu/PhysNet/blob/master/NegPearsonLoss.py
18
+ """
19
+
20
+ def __init__(self):
21
+ super(Neg_Pearson, self).__init__()
22
+ return
23
+
24
+
25
+ def forward(self, preds, labels):
26
+ loss = 0
27
+ for i in range(preds.shape[0]):
28
+ sum_x = torch.sum(preds[i])
29
+ sum_y = torch.sum(labels[i])
30
+ sum_xy = torch.sum(preds[i]*labels[i])
31
+ sum_x2 = torch.sum(torch.pow(preds[i],2))
32
+ sum_y2 = torch.sum(torch.pow(labels[i],2))
33
+ N = preds.shape[1]
34
+ pearson = (N*sum_xy - sum_x*sum_y)/(torch.sqrt((N*sum_x2 - torch.pow(sum_x,2))*(N*sum_y2 - torch.pow(sum_y,2))))
35
+ loss += 1 - pearson
36
+
37
+
38
+ loss = loss/preds.shape[0]
39
+ return loss
40
+
41
+
42
+
43
+
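
A hedged sanity-check sketch: for a single sample, the expression above reduces to the ordinary Pearson correlation coefficient, so the loss is the mean of 1 - r over the batch (here the batch axis comes first and time second, unlike the version in NegPearsonLoss.py).

```python
import torch
from neural_methods.loss.PhysNetNegPearsonLoss import Neg_Pearson

preds = torch.randn(2, 160)     # 2 samples, 160 time steps each
labels = torch.randn(2, 160)
loss = Neg_Pearson()(preds, labels)

r0 = torch.corrcoef(torch.stack([preds[0], labels[0]]))[0, 1]
r1 = torch.corrcoef(torch.stack([preds[1], labels[1]]))[0, 1]
# loss should match ((1 - r0) + (1 - r1)) / 2 up to numerical precision
```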
neural_methods/loss/RythmFormerLossComputer.py ADDED
@@ -0,0 +1,167 @@
1
+ '''
2
+ Adapted from here: https://github.com/ZitongYu/PhysFormer/TorchLossComputer.py
3
+ Modified based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn
4
+ '''
5
+ import math
6
+ import torch
7
+ from torch.autograd import Variable
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ import torch.nn as nn
11
+ from evaluation.post_process import calculate_metric_per_video
12
+
13
+ def normal_sampling(mean, label_k, std):
14
+ return math.exp(-(label_k-mean)**2/(2*std**2))/(math.sqrt(2*math.pi)*std)
15
+
16
+ def kl_loss(inputs, labels):
17
+ criterion = nn.KLDivLoss(reduce=False)
18
+ outputs = torch.log(inputs)
19
+ loss = criterion(outputs, labels)
20
+ #loss = loss.sum()/loss.shape[0]
21
+ loss = loss.sum()
22
+ return loss
23
+
24
+ class Neg_Pearson(nn.Module): # Pearson range [-1, 1] so if < 0, abs|loss| ; if >0, 1- loss
25
+ def __init__(self):
26
+ super(Neg_Pearson,self).__init__()
27
+
28
+ def forward(self, preds, labels): # all variable operation
29
+ loss = 0
30
+ for i in range(preds.shape[0]):
31
+ sum_x = torch.sum(preds[i]) # x
32
+ sum_y = torch.sum(labels[i]) # y
33
+ sum_xy = torch.sum(preds[i]*labels[i]) # xy
34
+ sum_x2 = torch.sum(torch.pow(preds[i],2)) # x^2
35
+ sum_y2 = torch.sum(torch.pow(labels[i],2)) # y^2
36
+ N = preds.shape[1]
37
+ pearson = (N*sum_xy - sum_x*sum_y)/(torch.sqrt((N*sum_x2 - torch.pow(sum_x,2))*(N*sum_y2 - torch.pow(sum_y,2))))
38
+ loss += 1 - pearson
39
+
40
+ loss = loss/preds.shape[0]
41
+ return loss
42
+
43
+ class RhythmFormer_Loss(nn.Module):
44
+ def __init__(self):
45
+ super(RhythmFormer_Loss,self).__init__()
46
+ self.criterion_Pearson = Neg_Pearson()
47
+ def forward(self, pred_ppg, labels , epoch , FS , diff_flag):
48
+ loss_time = self.criterion_Pearson(pred_ppg.view(1,-1) , labels.view(1,-1))
49
+ loss_CE , loss_distribution_kl = TorchLossComputer.Frequency_loss(pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0)
50
+ loss_hr = TorchLossComputer.HR_loss(pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0)
51
+ if torch.isnan(loss_time) :
52
+ loss_time = 0
53
+
54
+ loss = 0.2 * loss_time + 1.0 * loss_CE + 1.0 * loss_hr
55
+ return loss
56
+
57
+ class TorchLossComputer(object):
58
+ @staticmethod
59
+ def compute_complex_absolute_given_k(output, k, N):
60
+ two_pi_n_over_N = Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N
61
+ hanning = Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1)
62
+
63
+ k = k.type(torch.FloatTensor).cuda()
64
+ two_pi_n_over_N = two_pi_n_over_N.cuda()
65
+ hanning = hanning.cuda()
66
+
67
+ output = output.view(1, -1) * hanning
68
+ output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
69
+ k = k.view(1, -1, 1)
70
+ two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
71
+ complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
72
+ + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2
73
+
74
+ return complex_absolute
75
+
76
+ @staticmethod
77
+ def complex_absolute(output, Fs, bpm_range=None):
78
+ output = output.view(1, -1)
79
+
80
+ N = output.size()[1]
81
+
82
+ unit_per_hz = Fs / N
83
+ feasible_bpm = bpm_range / 60.0
84
+ k = feasible_bpm / unit_per_hz
85
+
86
+ # only calculate feasible PSD range [0.7,4]Hz
87
+ complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)
88
+
89
+ return (1.0 / complex_absolute.sum()) * complex_absolute # Analogous Softmax operator
90
+
91
+
92
+ @staticmethod
93
+ def cross_entropy_power_spectrum_loss(inputs, target, Fs):
94
+ inputs = inputs.view(1, -1)
95
+ target = target.view(1, -1)
96
+ bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
97
+ #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()
98
+
99
+ complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
100
+
101
+ whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
102
+ whole_max_idx = whole_max_idx.type(torch.float)
103
+
104
+ #pdb.set_trace()
105
+
106
+ #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2
107
+ return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)
108
+
109
+ @staticmethod
110
+ def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
111
+ inputs = inputs.view(1, -1)
112
+ target = target.view(1, -1)
113
+ bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
114
+ #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()
115
+
116
+ complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
117
+
118
+ whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
119
+ whole_max_idx = whole_max_idx.type(torch.float)
120
+
121
+ #pdb.set_trace()
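+ # NOTE: as in PhysFormerLossComputer.py, FocalLoss is not defined or imported in this file, so this method needs an external FocalLoss implementation.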
122
+ criterion = FocalLoss(gamma=gamma)
123
+
124
+ #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2
125
+ return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)
126
+
127
+
128
+ @staticmethod
129
+ def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
130
+ inputs = inputs.view(1, -1)
131
+ bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()
132
+ #bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
133
+ #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()
134
+
135
+ complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
136
+
137
+ whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
138
+ whole_max_idx = whole_max_idx.type(torch.float)
139
+
140
+ return whole_max_idx
141
+
142
+ @staticmethod
143
+ def Frequency_loss(inputs, target, diff_flag , Fs, std):
144
+ hr_gt, pred_hr_peak, SNR, macc = calculate_metric_per_video(inputs.detach().cpu(), target.detach().cpu(), diff_flag = diff_flag, fs=Fs, hr_method='FFT')
145
+ inputs = inputs.view(1, -1)
146
+ target = target.view(1, -1)
147
+ bpm_range = torch.arange(45, 150, dtype=torch.float).to(torch.device('cuda'))
148
+ ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
149
+ sa = ca/torch.sum(ca)
150
+
151
+ target_distribution = [normal_sampling(int(hr_gt), i, std) for i in range(45, 150)]
152
+ target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
153
+ target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
154
+
155
+ hr_gt = torch.tensor(hr_gt-45).view(1).type(torch.long).to(torch.device('cuda'))
156
+ return F.cross_entropy(ca, hr_gt) , kl_loss(sa , target_distribution)
157
+
158
+ @staticmethod
159
+ def HR_loss(inputs, target, diff_flag , Fs, std):
160
+ psd_gt, psd_pred, SNR, macc = calculate_metric_per_video(inputs.detach().cpu(), target.detach().cpu(), diff_flag = diff_flag, fs=Fs, hr_method='Peak')
161
+ pred_distribution = [normal_sampling(np.argmax(psd_pred), i, std) for i in range(psd_pred.size)]
162
+ pred_distribution = [i if i > 1e-15 else 1e-15 for i in pred_distribution]
163
+ pred_distribution = torch.Tensor(pred_distribution).to(torch.device('cuda'))
164
+ target_distribution = [normal_sampling(np.argmax(psd_gt), i, std) for i in range(psd_gt.size)]
165
+ target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
166
+ target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
167
+ return kl_loss(pred_distribution , target_distribution)
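
A hedged illustration of how `Frequency_loss` above turns a scalar ground-truth heart rate into the soft target distribution fed to the KL term; the 72 bpm value is only an example, and importing the module assumes `evaluation.post_process` is available on the path.

```python
import torch
from neural_methods.loss.RythmFormerLossComputer import normal_sampling

hr_gt, std = 72, 3.0
target = [normal_sampling(hr_gt, i, std) for i in range(45, 150)]
target = [p if p > 1e-15 else 1e-15 for p in target]     # floor to avoid log(0) in KL
target = torch.tensor(target)   # length-105 Gaussian over 45-149 bpm, peaked near 72
```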
neural_methods/loss/__init__.py ADDED
File without changes
neural_methods/model/BigSmall.py ADDED
@@ -0,0 +1,177 @@
1
+ """BigSmall: Multitask Network for AU / Respiration / PPG
2
+
3
+ BigSmall: Efficient Multi-Task Learning
4
+ For Physiological Measurements
5
+ Girish Narayanswamy, Yujia (Nancy) Liu, Yuzhe Yang, Chengqian (Jack) Ma,
6
+ Xin Liu, Daniel McDuff, Shwetak Patel
7
+
8
+ https://arxiv.org/abs/2303.11573
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ #####################################################
16
+ ############ Wrapping Time Shift Module #############
17
+ #####################################################
18
+ class WTSM(nn.Module):
19
+ def __init__(self, n_segment=3, fold_div=3):
20
+ super(WTSM, self).__init__()
21
+ self.n_segment = n_segment
22
+ self.fold_div = fold_div
23
+
24
+ def forward(self, x):
25
+ nt, c, h, w = x.size()
26
+ n_batch = nt // self.n_segment
27
+ x = x.view(n_batch, self.n_segment, c, h, w)
28
+ fold = c // self.fold_div
29
+ out = torch.zeros_like(x)
30
+ out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
31
+ out[:, -1, :fold] = x[:, 0, :fold] # wrap left
32
+ out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
33
+ out[:, 0, fold: 2 * fold] = x[:, -1, fold: 2 * fold] # wrap right
34
+ out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # no shift for final fold
35
+ return out.view(nt, c, h, w)
36
+
37
+
38
+
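
A hedged illustration of the wrapping time shift above: with `n_segment=3` and `fold_div=3`, the first third of the channels is shifted one frame back in time, the second third one frame forward, both with wrap-around instead of zero padding, and the remaining channels are untouched.

```python
import torch
from neural_methods.model.BigSmall import WTSM

wtsm = WTSM(n_segment=3, fold_div=3)
x = torch.arange(3 * 3 * 2 * 2, dtype=torch.float).view(3, 3, 2, 2)  # (nt, c, h, w)
y = wtsm(x)
# y[0, 0] == x[1, 0] and y[2, 0] == x[0, 0]   (shift left with wrap)
# y[0, 1] == x[2, 1] and y[1, 1] == x[0, 1]   (shift right with wrap)
# y[:, 2] == x[:, 2]                          (final fold is not shifted)
```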
39
+ #######################################################################################
40
+ ##################################### BigSmall Model ##################################
41
+ #######################################################################################
42
+ class BigSmall(nn.Module):
43
+
44
+ def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3,
45
+ dropout_rate1=0.25, dropout_rate2=0.5, dropout_rate3=0.5, pool_size1=(2, 2), pool_size2=(4,4),
46
+ nb_dense=128, out_size_bvp=1, out_size_resp=1, out_size_au=12, n_segment=3):
47
+
48
+ super(BigSmall, self).__init__()
49
+
50
+ self.in_channels = in_channels
51
+ self.kernel_size = kernel_size
52
+ self.dropout_rate1 = dropout_rate1
53
+ self.dropout_rate2 = dropout_rate2
54
+ self.dropout_rate3 = dropout_rate3
55
+ self.pool_size1 = pool_size1
56
+ self.pool_size2 = pool_size2
57
+ self.nb_filters1 = nb_filters1
58
+ self.nb_filters2 = nb_filters2
59
+ self.nb_dense = nb_dense
60
+
61
+ self.out_size_bvp = out_size_bvp
62
+ self.out_size_resp = out_size_resp
63
+ self.out_size_au = out_size_au
64
+
65
+ self.n_segment = n_segment
66
+
67
+ # Big Convolutional Layers
68
+ self.big_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
69
+ self.big_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
70
+ self.big_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
71
+ self.big_conv4 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
72
+ self.big_conv5 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
73
+ self.big_conv6 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
74
+
75
+ # Big Avg Pooling / Dropout Layers
76
+ self.big_avg_pooling1 = nn.AvgPool2d(self.pool_size1)
77
+ self.big_dropout1 = nn.Dropout(self.dropout_rate1)
78
+ self.big_avg_pooling2 = nn.AvgPool2d(self.pool_size1)
79
+ self.big_dropout2 = nn.Dropout(self.dropout_rate2)
80
+ self.big_avg_pooling3 = nn.AvgPool2d(self.pool_size2)
81
+ self.big_dropout3 = nn.Dropout(self.dropout_rate3)
82
+
83
+ # TSM layers
84
+ self.TSM_1 = WTSM(n_segment=self.n_segment)
85
+ self.TSM_2 = WTSM(n_segment=self.n_segment)
86
+ self.TSM_3 = WTSM(n_segment=self.n_segment)
87
+ self.TSM_4 = WTSM(n_segment=self.n_segment)
88
+
89
+ # Small Convolutional Layers
90
+ self.small_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True)
91
+ self.small_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True)
92
+ self.small_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True)
93
+ self.small_conv4 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1,1), bias=True)
94
+
95
+ # AU Fully Connected Layers
96
+ self.au_fc1 = nn.Linear(5184, self.nb_dense, bias=True)
97
+ self.au_fc2 = nn.Linear(self.nb_dense, self.out_size_au, bias=True)
98
+
99
+ # BVP Fully Connected Layers
100
+ self.bvp_fc1 = nn.Linear(5184, self.nb_dense, bias=True)
101
+ self.bvp_fc2 = nn.Linear(self.nb_dense, self.out_size_bvp, bias=True)
102
+
103
+ # Resp Fully Connected Layers
104
+ self.resp_fc1 = nn.Linear(5184, self.nb_dense, bias=True)
105
+ self.resp_fc2 = nn.Linear(self.nb_dense, self.out_size_resp, bias=True)
106
+
107
+
108
+ def forward(self, inputs, params=None):
109
+
110
+ big_input = inputs[0] # big res
111
+ small_input = inputs[1] # small res
112
+
113
+ # reshape Big
114
+ nt, c, h, w = big_input.size()
115
+ n_batch = nt // self.n_segment
116
+ big_input = big_input.view(n_batch, self.n_segment, c, h, w)
117
+ big_input = torch.moveaxis(big_input, 1, 2) # color channel to idx 1, sequence channel to idx 2
118
+ big_input = big_input[:, :, 0, :, :] # use only first frame in sequences
119
+
120
+
121
+ # Big Conv block 1
122
+ b1 = nn.functional.relu(self.big_conv1(big_input))
123
+ b2 = nn.functional.relu(self.big_conv2(b1))
124
+ b3 = self.big_avg_pooling1(b2)
125
+ b4 = self.big_dropout1(b3)
126
+
127
+ # Big Conv block 2
128
+ b5 = nn.functional.relu(self.big_conv3(b4))
129
+ b6 = nn.functional.relu(self.big_conv4(b5))
130
+ b7 = self.big_avg_pooling2(b6)
131
+ b8 = self.big_dropout2(b7)
132
+
133
+ # Big Conv block 3
134
+ b9 = nn.functional.relu(self.big_conv5(b8))
135
+ b10 = nn.functional.relu(self.big_conv6(b9))
136
+ b11 = self.big_avg_pooling3(b10)
137
+ b12 = self.big_dropout3(b11)
138
+
139
+ # Reformat Big Shape For Concat w/ Small Branch
140
+ b13 = torch.stack((b12, b12, b12), 2) #TODO: this is hardcoded for num_segs = 3: change this...
141
+ b14 = torch.moveaxis(b13, 1, 2)
142
+ bN, bD, bC, bH, bW = b14.size()
143
+ b15 = b14.reshape(int(bN*bD), bC, bH, bW)
144
+
145
+ # Small Conv block 1
146
+ s1 = self.TSM_1(small_input)
147
+ s2 = nn.functional.relu(self.small_conv1(s1))
148
+ s3 = self.TSM_2(s2)
149
+ s4 = nn.functional.relu(self.small_conv2(s3))
150
+
151
+ # Small Conv block 2
152
+ s5 = self.TSM_3(s4)
153
+ s6 = nn.functional.relu(self.small_conv3(s5))
154
+ s7 = self.TSM_4(s6)
155
+ s8 = nn.functional.relu(self.small_conv4(s7))
156
+
157
+ # Shared Layers
158
+ concat = b15 + s8 # sum layers
159
+
160
+ # share1 = concat.view(concat.size(0), -1) # flatten entire tensors
161
+ share1 = concat.reshape(concat.size(0), -1)
162
+
163
+ # AU Output Layers
164
+ aufc1 = nn.functional.relu(self.au_fc1(share1))
165
+ au_out = self.au_fc2(aufc1)
166
+
167
+ # BVP Output Layers
168
+ bvpfc1 = nn.functional.relu(self.bvp_fc1(share1))
169
+ bvp_out = self.bvp_fc2(bvpfc1)
170
+
171
+ # Resp Output Layers
172
+ respfc1 = nn.functional.relu(self.resp_fc1(share1))
173
+ resp_out = self.resp_fc2(respfc1)
174
+
175
+ return au_out, bvp_out, resp_out
176
+
177
+
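
A hedged shape sketch for `BigSmall.forward`: it takes a pair of tensors, a high-resolution "big" stream and a low-resolution "small" stream, both flattened over batch x n_segment frames. The 144x144 and 9x9 resolutions below are assumptions chosen to match the hard-coded 5184-unit fully connected layers, not values read from this repo's configs.

```python
import torch
from neural_methods.model.BigSmall import BigSmall

model = BigSmall(n_segment=3)
big = torch.randn(6, 3, 144, 144)    # 2 clips x 3 frames, RGB, high resolution
small = torch.randn(6, 3, 9, 9)      # the same frames at low resolution
au_out, bvp_out, resp_out = model([big, small])
# au_out: (6, 12), bvp_out: (6, 1), resp_out: (6, 1)
```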
neural_methods/model/DeepPhys.py ADDED
@@ -0,0 +1,125 @@
1
+ """DeepPhys - 2D Convolutional Attention Network.
2
+ DeepPhys: Video-Based Physiological Measurement Using Convolutional Attention Networks
3
+ ECCV, 2018
4
+ Weixuan Chen, Daniel McDuff
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class Attention_mask(nn.Module):
12
+ def __init__(self):
13
+ super(Attention_mask, self).__init__()
14
+
15
+ def forward(self, x):
16
+ xsum = torch.sum(x, dim=2, keepdim=True)
17
+ xsum = torch.sum(xsum, dim=3, keepdim=True)
18
+ xshape = tuple(x.size())
19
+ return x / xsum * xshape[2] * xshape[3] * 0.5
20
+
21
+ def get_config(self):
22
+ """May be generated manually. """
23
+ config = super(Attention_mask, self).get_config()
24
+ return config
25
+
26
+
27
+ class DeepPhys(nn.Module):
28
+
29
+ def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
30
+ dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, img_size=36):
31
+ """Definition of DeepPhys.
32
+ Args:
33
+ in_channels: the number of input channel. Default: 3
34
+ img_size: height/width of each frame. Default: 36.
35
+ Returns:
36
+ DeepPhys model.
37
+ """
38
+ super(DeepPhys, self).__init__()
39
+ self.in_channels = in_channels
40
+ self.kernel_size = kernel_size
41
+ self.dropout_rate1 = dropout_rate1
42
+ self.dropout_rate2 = dropout_rate2
43
+ self.pool_size = pool_size
44
+ self.nb_filters1 = nb_filters1
45
+ self.nb_filters2 = nb_filters2
46
+ self.nb_dense = nb_dense
47
+ # Motion branch convs
48
+ self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
49
+ bias=True)
50
+ self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
51
+ self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
52
+ bias=True)
53
+ self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
54
+ # Appearance branch convs
55
+ self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
56
+ padding=(1, 1), bias=True)
57
+ self.apperance_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
58
+ self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
59
+ padding=(1, 1), bias=True)
60
+ self.apperance_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
61
+ # Attention layers
62
+ self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
63
+ self.attn_mask_1 = Attention_mask()
64
+ self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
65
+ self.attn_mask_2 = Attention_mask()
66
+ # Avg pooling
67
+ self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
68
+ self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
69
+ self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
70
+ # Dropout layers
71
+ self.dropout_1 = nn.Dropout(self.dropout_rate1)
72
+ self.dropout_2 = nn.Dropout(self.dropout_rate1)
73
+ self.dropout_3 = nn.Dropout(self.dropout_rate1)
74
+ self.dropout_4 = nn.Dropout(self.dropout_rate2)
75
+ # Dense layers
76
+ if img_size == 36:
77
+ self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
78
+ elif img_size == 72:
79
+ self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
80
+ elif img_size == 96:
81
+ self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
82
+ else:
83
+ raise Exception('Unsupported image size')
84
+ self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
85
+
86
+ def forward(self, inputs, params=None):
87
+
88
+ diff_input = inputs[:, :3, :, :]
89
+ raw_input = inputs[:, 3:, :, :]
90
+
91
+ d1 = torch.tanh(self.motion_conv1(diff_input))
92
+ d2 = torch.tanh(self.motion_conv2(d1))
93
+
94
+ r1 = torch.tanh(self.apperance_conv1(raw_input))
95
+ r2 = torch.tanh(self.apperance_conv2(r1))
96
+
97
+ g1 = torch.sigmoid(self.apperance_att_conv1(r2))
98
+ g1 = self.attn_mask_1(g1)
99
+ gated1 = d2 * g1
100
+
101
+ d3 = self.avg_pooling_1(gated1)
102
+ d4 = self.dropout_1(d3)
103
+
104
+ r3 = self.avg_pooling_2(r2)
105
+ r4 = self.dropout_2(r3)
106
+
107
+ d5 = torch.tanh(self.motion_conv3(d4))
108
+ d6 = torch.tanh(self.motion_conv4(d5))
109
+
110
+ r5 = torch.tanh(self.apperance_conv3(r4))
111
+ r6 = torch.tanh(self.apperance_conv4(r5))
112
+
113
+ g2 = torch.sigmoid(self.apperance_att_conv2(r6))
114
+ g2 = self.attn_mask_2(g2)
115
+ gated2 = d6 * g2
116
+
117
+ d7 = self.avg_pooling_3(gated2)
118
+ d8 = self.dropout_3(d7)
119
+ d9 = d8.view(d8.size(0), -1)
120
+ d10 = torch.tanh(self.final_dense_1(d9))
121
+ d11 = self.dropout_4(d10)
122
+ out = self.final_dense_2(d11)
123
+
124
+ return out
125
+
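
A hedged usage sketch: the forward pass above splits its input into channels 0-2 (the normalized frame difference) and channels 3-5 (the raw frame), so DeepPhys expects a 6-channel tensor per frame.

```python
import torch
from neural_methods.model.DeepPhys import DeepPhys

model = DeepPhys(img_size=72)
frames = torch.randn(20, 6, 72, 72)   # 20 frames, diff + raw channels, 72x72
bvp = model(frames)                    # (20, 1): one BVP value per frame
```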
neural_methods/model/EfficientPhys.py ADDED
@@ -0,0 +1,128 @@
1
+ """EfficientPhys: Enabling Simple, Fast and Accurate Camera-Based Vitals Measurement
2
+ Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2023)
3
+ Xin Liu, Brial Hill, Ziheng Jiang, Shwetak Patel, Daniel McDuff
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class Attention_mask(nn.Module):
11
+ def __init__(self):
12
+ super(Attention_mask, self).__init__()
13
+
14
+ def forward(self, x):
15
+ xsum = torch.sum(x, dim=2, keepdim=True)
16
+ xsum = torch.sum(xsum, dim=3, keepdim=True)
17
+ xshape = tuple(x.size())
18
+ return x / xsum * xshape[2] * xshape[3] * 0.5
19
+
20
+ def get_config(self):
21
+ """May be generated manually. """
22
+ config = super(Attention_mask, self).get_config()
23
+ return config
24
+
25
+
26
+ class TSM(nn.Module):
27
+ def __init__(self, n_segment=10, fold_div=3):
28
+ super(TSM, self).__init__()
29
+ self.n_segment = n_segment
30
+ self.fold_div = fold_div
31
+
32
+ def forward(self, x):
33
+ nt, c, h, w = x.size()
34
+ n_batch = nt // self.n_segment
35
+ x = x.view(n_batch, self.n_segment, c, h, w)
36
+ fold = c // self.fold_div
37
+ out = torch.zeros_like(x)
38
+ out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
39
+ out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
40
+ out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
41
+ return out.view(nt, c, h, w)
42
+
43
+
44
+ class EfficientPhys(nn.Module):
45
+
46
+ def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
47
+ dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36, channel='raw'):
48
+ super(EfficientPhys, self).__init__()
49
+ self.in_channels = in_channels
50
+ self.kernel_size = kernel_size
51
+ self.dropout_rate1 = dropout_rate1
52
+ self.dropout_rate2 = dropout_rate2
53
+ self.pool_size = pool_size
54
+ self.nb_filters1 = nb_filters1
55
+ self.nb_filters2 = nb_filters2
56
+ self.nb_dense = nb_dense
57
+ # TSM layers
58
+ self.TSM_1 = TSM(n_segment=frame_depth)
59
+ self.TSM_2 = TSM(n_segment=frame_depth)
60
+ self.TSM_3 = TSM(n_segment=frame_depth)
61
+ self.TSM_4 = TSM(n_segment=frame_depth)
62
+ # Motion branch convs
63
+ self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
64
+ bias=True)
65
+ self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
66
+ self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
67
+ bias=True)
68
+ self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
69
+ # Attention layers
70
+ self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
71
+ self.attn_mask_1 = Attention_mask()
72
+ self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
73
+ self.attn_mask_2 = Attention_mask()
74
+ # Avg pooling
75
+ self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
76
+ self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
77
+ self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
78
+ # Dropout layers
79
+ self.dropout_1 = nn.Dropout(self.dropout_rate1)
80
+ self.dropout_2 = nn.Dropout(self.dropout_rate1)
81
+ self.dropout_3 = nn.Dropout(self.dropout_rate1)
82
+ self.dropout_4 = nn.Dropout(self.dropout_rate2)
83
+ # Dense layers
84
+ if img_size == 36:
85
+ self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
86
+ elif img_size == 72:
87
+ self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
88
+ elif img_size == 96:
89
+ self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
90
+ else:
91
+ raise Exception('Unsupported image size')
92
+ self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
93
+ self.batch_norm = nn.BatchNorm2d(3)
94
+ self.channel = channel
95
+
96
+ def forward(self, inputs, params=None):
97
+ inputs = torch.diff(inputs, dim=0)
98
+ inputs = self.batch_norm(inputs)
99
+
100
+ network_input = self.TSM_1(inputs)
101
+ d1 = torch.tanh(self.motion_conv1(network_input))
102
+ d1 = self.TSM_2(d1)
103
+ d2 = torch.tanh(self.motion_conv2(d1))
104
+
105
+ g1 = torch.sigmoid(self.apperance_att_conv1(d2))
106
+ g1 = self.attn_mask_1(g1)
107
+ gated1 = d2 * g1
108
+
109
+ d3 = self.avg_pooling_1(gated1)
110
+ d4 = self.dropout_1(d3)
111
+
112
+ d4 = self.TSM_3(d4)
113
+ d5 = torch.tanh(self.motion_conv3(d4))
114
+ d5 = self.TSM_4(d5)
115
+ d6 = torch.tanh(self.motion_conv4(d5))
116
+
117
+ g2 = torch.sigmoid(self.apperance_att_conv2(d6))
118
+ g2 = self.attn_mask_2(g2)
119
+ gated2 = d6 * g2
120
+
121
+ d7 = self.avg_pooling_3(gated2)
122
+ d8 = self.dropout_3(d7)
123
+ d9 = d8.view(d8.size(0), -1)
124
+ d10 = torch.tanh(self.final_dense_1(d9))
125
+ d11 = self.dropout_4(d10)
126
+ out = self.final_dense_2(d11)
127
+
128
+ return out
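
A hedged usage sketch: EfficientPhys differences consecutive frames inside `forward()`, so it is fed raw frames and needs one extra frame so that the differenced length stays a multiple of `frame_depth`.

```python
import torch
from neural_methods.model.EfficientPhys import EfficientPhys

frame_depth = 20
model = EfficientPhys(frame_depth=frame_depth, img_size=72)
frames = torch.randn(4 * frame_depth + 1, 3, 72, 72)   # 81 raw RGB frames
bvp = model(frames)                                     # (80, 1) after differencing
```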
neural_methods/model/FactorizePhys/FSAM.py ADDED
@@ -0,0 +1,530 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.modules.batchnorm import _BatchNorm
11
+ import numpy as np
12
+ import neurokit2 as nk
13
+
14
+
15
+ class _MatrixDecompositionBase(nn.Module):
16
+ def __init__(self, device, md_config, debug=False, dim="3D"):
17
+ super().__init__()
18
+
19
+ self.dim = dim
20
+ self.md_type = md_config["MD_TYPE"]
21
+ if dim == "3D":
22
+ self.transform = md_config["MD_TRANSFORM"]
23
+ self.S = md_config["MD_S"]
24
+ self.R = md_config["MD_R"]
25
+ self.debug = debug
26
+
27
+ self.train_steps = md_config["MD_STEPS"]
28
+ self.eval_steps = md_config["MD_STEPS"]
29
+
30
+ self.inv_t = md_config["INV_T"]
31
+ self.eta = md_config["ETA"]
32
+
33
+ self.rand_init = md_config["RAND_INIT"]
34
+ self.device = device
35
+
36
+ # print('Dimension:', self.dim)
37
+ # print('S', self.S)
38
+ # print('D', self.D)
39
+ # print('R', self.R)
40
+ # print('train_steps', self.train_steps)
41
+ # print('eval_steps', self.eval_steps)
42
+ # print('inv_t', self.inv_t)
43
+ # print('eta', self.eta)
44
+ # print('rand_init', self.rand_init)
45
+
46
+ def _build_bases(self, B, S, D, R):
47
+ raise NotImplementedError
48
+
49
+ def local_step(self, x, bases, coef):
50
+ raise NotImplementedError
51
+
52
+ @torch.no_grad()
53
+ def local_inference(self, x, bases):
54
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
55
+ coef = torch.bmm(x.transpose(1, 2), bases)
56
+ coef = F.softmax(self.inv_t * coef, dim=-1)
57
+
58
+ steps = self.train_steps if self.training else self.eval_steps
59
+ for _ in range(steps):
60
+ bases, coef = self.local_step(x, bases, coef)
61
+
62
+ return bases, coef
63
+
64
+ def compute_coef(self, x, bases, coef):
65
+ raise NotImplementedError
66
+
67
+ def forward(self, x, return_bases=False):
68
+
69
+ if self.debug:
70
+ print("Org x.shape", x.shape)
71
+
72
+ if self.dim == "3D": # (B, C, T, H, W) -> (B * S, D, N)
73
+ B, C, T, H, W = x.shape
74
+
75
+ # t = Time, k = Channels, a & B = height and width
76
+ if self.transform.lower() == "t_kab":
77
+ # # dimension of vector of our interest is T (rPPG signal as T dimension), so forming this as vector
78
+ # # From spatial and channel dimension, which are features, only 2-4 shall be enough to generate the approximated attention matrix
79
+ D = T // self.S
80
+ N = C * H * W
81
+
82
+ elif self.transform.lower() == "tk_ab":
83
+ D = T * C // self.S
84
+ N = H * W
85
+
86
+ elif self.transform.lower() == "k_tab":
87
+ D = C // self.S
88
+ N = T * H * W
89
+
90
+ else:
91
+ print("Invalid MD_TRANSFORM specified:", self.transform)
92
+ exit()
93
+
94
+ # # smoothening the temporal dimension
95
+ # x = x.view(B * self.S, N, D)
96
+ # # print("Intermediate-1 x", x.shape)
97
+
98
+ # sample_1 = x[:, :, 0].unsqueeze(2)
99
+ # sample_2 = x[:, :, -1].unsqueeze(2)
100
+ # x = torch.cat([sample_1, x, sample_2], dim=2)
101
+ # gaussian_kernel = [1.0, 1.0, 1.0]
102
+ # kernels = torch.FloatTensor([[gaussian_kernel]]).repeat(N, N, 1).to(self.device)
103
+ # bias = torch.FloatTensor(torch.zeros(N)).to(self.device)
104
+ # x = F.conv1d(x, kernels, bias=bias, padding="valid")
105
+ # x = (x - x.min()) / (x.max() - x.min())
106
+
107
+ # x = x.permute(0, 2, 1)
108
+ # # print("Intermediate-2 x", x.shape)
109
+
110
+ x = x.view(B * self.S, D, N)
111
+
112
+ elif self.dim == "2D": # (B, C, H, W) -> (B * S, D, N)
113
+ B, C, H, W = x.shape
114
+ D = C // self.S
115
+ N = H * W
116
+ x = x.view(B * self.S, D, N)
117
+
118
+ elif self.dim == "2D_TSM": # (B*frame_depth, C, H, W) -> (B, D, N)
119
+ B, C, H, W = x.shape
120
+ BN = B
121
+ B = B // self.S
122
+ D = self.S
123
+ N = C * H * W
124
+ x = x.view(B, D, N)
125
+ self.S = 1 # re-setting this for local inference
126
+
127
+ elif self.dim == "1D": # (B, C, L) -> (B * S, D, N)
128
+ B, C, L = x.shape
129
+ D = L // self.S
130
+ N = C
131
+ x = x.view(B * self.S, D, N)
132
+
133
+ else:
134
+ print("Dimension not supported")
135
+ exit()
136
+
137
+ if self.debug:
138
+ print("MD_Type", self.md_type)
139
+ print("MD_S", self.S)
140
+ print("MD_D", D)
141
+ print("MD_N", N)
142
+ print("MD_R", self.R)
143
+ print("MD_TRAIN_STEPS", self.train_steps)
144
+ print("MD_EVAL_STEPS", self.eval_steps)
145
+ print("x.view(B * self.S, D, N)", x.shape)
146
+
147
+ if not self.rand_init and not hasattr(self, 'bases'):
148
+ bases = self._build_bases(1, self.S, D, self.R)
149
+ self.register_buffer('bases', bases)
150
+
151
+ # (S, D, R) -> (B * S, D, R)
152
+ if self.rand_init:
153
+ bases = self._build_bases(B, self.S, D, self.R)
154
+ else:
155
+ bases = self.bases.repeat(B, 1, 1).to(self.device)
156
+
157
+ bases, coef = self.local_inference(x, bases)
158
+
159
+ # (B * S, N, R)
160
+ coef = self.compute_coef(x, bases, coef)
161
+
162
+ # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
163
+ x = torch.bmm(bases, coef.transpose(1, 2))
164
+
165
+
166
+ if self.dim == "3D":
167
+
168
+ apply_smoothening = False
169
+ if apply_smoothening:
170
+ # smoothening the temporal dimension
171
+ x = x.view(B, D * self.S, N) #Joining temporal dimension for contiguous smoothening
172
+ # print("Intermediate-0 x", x.shape)
173
+ x = x.permute(0, 2, 1)
174
+ # print("Intermediate-1 x", x.shape)
175
+
176
+ sample_1 = x[:, :, 0].unsqueeze(2)
177
+ # sample_2 = x[:, :, 0].unsqueeze(2)
178
+ sample_3 = x[:, :, -1].unsqueeze(2)
179
+ # sample_4 = x[:, :, -1].unsqueeze(2)
180
+ x = torch.cat([sample_1, x, sample_3], dim=2)
181
+ # x = torch.cat([sample_1, sample_2, x, sample_3, sample_4], dim=2)
182
+ # gaussian_kernel = [0.25, 0.50, 0.75, 0.50, 0.25]
183
+ # gaussian_kernel = [0.33, 0.66, 1.00, 0.66, 0.33]
184
+ # gaussian_kernel = [0.3, 0.7, 1.0, 0.7, 0.3]
185
+ # gaussian_kernel = [0.3, 1.0, 1.0, 1.0, 0.3]
186
+ # gaussian_kernel = [0.20, 0.80, 1.00, 0.80, 0.20]
187
+ # gaussian_kernel = [1.0, 1.0, 1.0]
188
+ gaussian_kernel = [0.8, 1.0, 0.8]
189
+ kernels = torch.FloatTensor([[gaussian_kernel]]).repeat(N, N, 1).to(self.device)
190
+ bias = torch.FloatTensor(torch.zeros(N)).to(self.device)
191
+ x = F.conv1d(x, kernels, bias=bias, padding="valid")
192
+ # x = (x - x.min()) / (x.max() - x.min())
193
+ # x = (x - x.mean()) / (x.std())
194
+ # x = x - x.min()
195
+ x = (x - x.min())/(x.std())
196
+
197
+ # print("Intermediate-2 x", x.shape)
198
+
199
+ # (B * S, D, N) -> (B, C, T, H, W)
200
+ x = x.view(B, C, T, H, W)
201
+ elif self.dim == "2D":
202
+ # (B * S, D, N) -> (B, C, H, W)
203
+ x = x.view(B, C, H, W)
204
+
205
+ elif self.dim == "2D_TSM":
206
+ # (B, D, N) -> (B, C, H, W)
207
+ x = x.view(BN, C, H, W)
208
+
209
+ else:
210
+ # (B * S, D, N) -> (B, C, L)
211
+ x = x.view(B, C, L)
212
+
213
+ # (B * L, D, R) -> (B, L, N, D)
214
+ bases = bases.view(B, self.S, D, self.R)
215
+
216
+ if not self.rand_init and not self.training and not return_bases:
217
+ self.online_update(bases)
218
+
219
+ # if not self.rand_init or return_bases:
220
+ # return x, bases
221
+ # else:
222
+ return x
223
+
224
+ @torch.no_grad()
225
+ def online_update(self, bases):
226
+ # (B, S, D, R) -> (S, D, R)
227
+ update = bases.mean(dim=0)
228
+ self.bases += self.eta * (update - self.bases)
229
+ self.bases = F.normalize(self.bases, dim=1)
230
+
231
+
232
+ class NMF(_MatrixDecompositionBase):
233
+ def __init__(self, device, md_config, debug=False, dim="3D"):
234
+ super().__init__(device, md_config, debug=debug, dim=dim)
235
+ self.device = device
236
+ self.inv_t = 1
237
+
238
+ def _build_bases(self, B, S, D, R):
239
+ # bases = torch.rand((B * S, D, R)).to(self.device)
240
+ bases = torch.ones((B * S, D, R)).to(self.device)
241
+ bases = F.normalize(bases, dim=1)
242
+
243
+ return bases
244
+
245
+ @torch.no_grad()
246
+ def local_step(self, x, bases, coef):
247
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
248
+ numerator = torch.bmm(x.transpose(1, 2), bases)
249
+ # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
250
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
251
+ # Multiplicative Update
252
+ coef = coef * numerator / (denominator + 1e-6)
253
+
254
+ # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
255
+ numerator = torch.bmm(x, coef)
256
+ # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
257
+ denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
258
+ # Multiplicative Update
259
+ bases = bases * numerator / (denominator + 1e-6)
260
+
261
+ return bases, coef
262
+
263
+ def compute_coef(self, x, bases, coef):
264
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
265
+ numerator = torch.bmm(x.transpose(1, 2), bases)
266
+ # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
267
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
268
+ # multiplication update
269
+ coef = coef * numerator / (denominator + 1e-6)
270
+
271
+ return coef
272
+
273
+
274
+ class VQ(_MatrixDecompositionBase):
275
+ def __init__(self, device, md_config, debug=False, dim="3D"):
276
+ super().__init__(device, md_config, debug=debug, dim=dim)
277
+ self.device = device
278
+
279
+ def _build_bases(self, B, S, D, R):
280
+ # bases = torch.randn((B * S, D, R)).to(self.device)
281
+ bases = torch.ones((B * S, D, R)).to(self.device)
282
+ bases = F.normalize(bases, dim=1)
283
+ return bases
284
+
285
+ @torch.no_grad()
286
+ def local_step(self, x, bases, _):
287
+ # (B * S, D, N), normalize x along D (for cosine similarity)
288
+ std_x = F.normalize(x, dim=1)
289
+
290
+ # (B * S, D, R), normalize bases along D (for cosine similarity)
291
+ std_bases = F.normalize(bases, dim=1, eps=1e-6)
292
+
293
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
294
+ coef = torch.bmm(std_x.transpose(1, 2), std_bases)
295
+
296
+ # softmax along R
297
+ coef = F.softmax(self.inv_t * coef, dim=-1)
298
+
299
+ # normalize along N
300
+ coef = coef / (1e-6 + coef.sum(dim=1, keepdim=True))
301
+
302
+ # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
303
+ bases = torch.bmm(x, coef)
304
+
305
+ return bases, coef
306
+
307
+
308
+ def compute_coef(self, x, bases, _):
309
+ with torch.no_grad():
310
+ # (B * S, D, N) -> (B * S, 1, N)
311
+ x_norm = x.norm(dim=1, keepdim=True)
312
+
313
+ # (B * S, D, N) / (B * S, 1, N) -> (B * S, D, N)
314
+ std_x = x / (1e-6 + x_norm)
315
+
316
+ # (B * S, D, R), normalize bases along D (for cosine similarity)
317
+ std_bases = F.normalize(bases, dim=1, eps=1e-6)
318
+
319
+ # (B * S, N, D)^T @ (B * S, D, R) -> (B * S, N, R)
320
+ coef = torch.bmm(std_x.transpose(1, 2), std_bases)
321
+
322
+ # softmax along R
323
+ coef = F.softmax(self.inv_t * coef, dim=-1)
324
+
325
+ return coef
326
+
327
+
328
+ class ConvBNReLU(nn.Module):
329
+ @classmethod
330
+ def _same_paddings(cls, kernel_size, dim):
331
+ if dim == "3D":
332
+ if kernel_size == (1, 1, 1):
333
+ return (0, 0, 0)
334
+ elif kernel_size == (3, 3, 3):
335
+ return (1, 1, 1)
336
+ elif dim == "2D" or dim == "2D_TSM":
337
+ if kernel_size == (1, 1):
338
+ return (0, 0)
339
+ elif kernel_size == (3, 3):
340
+ return (1, 1)
341
+ else:
342
+ if kernel_size == 1:
343
+ return 0
344
+ elif kernel_size == 3:
345
+ return 1
346
+
347
+ def __init__(self, in_c, out_c, dim,
348
+ kernel_size=1, stride=1, padding='same',
349
+ dilation=1, groups=1, act='relu', apply_bn=False, apply_act=True):
350
+ super().__init__()
351
+
352
+ self.apply_bn = apply_bn
353
+ self.apply_act = apply_act
354
+ self.dim = dim
355
+ if dilation == 1:
356
+ if self.dim == "3D":
357
+ dilation = (1, 1, 1)
358
+ elif self.dim == "2D" or dim == "2D_TSM":
359
+ dilation = (1, 1)
360
+ else:
361
+ dilation = 1
362
+
363
+ if kernel_size == 1:
364
+ if self.dim == "3D":
365
+ kernel_size = (1, 1, 1)
366
+ elif self.dim == "2D" or dim == "2D_TSM":
367
+ kernel_size = (1, 1)
368
+ else:
369
+ kernel_size = 1
370
+
371
+ if stride == 1:
372
+ if self.dim == "3D":
373
+ stride = (1, 1, 1)
374
+ elif self.dim == "2D" or dim == "2D_TSM":
375
+ stride = (1, 1)
376
+ else:
377
+ stride = 1
378
+
379
+ if padding == 'same':
380
+ padding = self._same_paddings(kernel_size, dim)
381
+
382
+ if self.dim == "3D":
383
+ self.conv = nn.Conv3d(in_c, out_c,
384
+ kernel_size=kernel_size, stride=stride,
385
+ padding=padding, dilation=dilation,
386
+ groups=groups,
387
+ bias=False)
388
+ elif self.dim == "2D" or dim == "2D_TSM":
389
+ self.conv = nn.Conv2d(in_c, out_c,
390
+ kernel_size=kernel_size, stride=stride,
391
+ padding=padding, dilation=dilation,
392
+ groups=groups,
393
+ bias=False)
394
+ else:
395
+ self.conv = nn.Conv1d(in_c, out_c,
396
+ kernel_size=kernel_size, stride=stride,
397
+ padding=padding, dilation=dilation,
398
+ groups=groups,
399
+ bias=False)
400
+
401
+ if act == "sigmoid":
402
+ self.act = nn.Sigmoid()
403
+ else:
404
+ self.act = nn.ReLU(inplace=True)
405
+
406
+ if self.apply_bn:
407
+ if self.dim == "3D":
408
+ self.bn = nn.InstanceNorm3d(out_c)
409
+ elif self.dim == "2D" or dim == "2D_TSM":
410
+ self.bn = nn.InstanceNorm2d(out_c)
411
+ else:
412
+ self.bn = nn.InstanceNorm1d(out_c)
413
+
414
+ def forward(self, x):
415
+ x = self.conv(x)
416
+ if self.apply_act:
417
+ x = self.act(x)
418
+ if self.apply_bn:
419
+ x = self.bn(x)
420
+ return x
421
+
422
+
423
+ class FeaturesFactorizationModule(nn.Module):
424
+ def __init__(self, inC, device, md_config, dim="3D", debug=False):
425
+ super().__init__()
426
+
427
+ self.device = device
428
+ self.dim = dim
429
+ md_type = md_config["MD_TYPE"]
430
+ align_C = md_config["align_channels"] # inC // 2 # // 2 #// 8
431
+
432
+ if self.dim == "3D":
433
+ if "nmf" in md_type.lower():
434
+ self.pre_conv_block = nn.Sequential(
435
+ nn.Conv3d(inC, align_C, (1, 1, 1)),
436
+ nn.ReLU(inplace=True))
437
+ else:
438
+ self.pre_conv_block = nn.Conv3d(inC, align_C, (1, 1, 1))
439
+ elif self.dim == "2D" or self.dim == "2D_TSM":
440
+ if "nmf" in md_type.lower():
441
+ self.pre_conv_block = nn.Sequential(
442
+ nn.Conv2d(inC, align_C, (1, 1)),
443
+ nn.ReLU(inplace=True)
444
+ )
445
+ else:
446
+ self.pre_conv_block = nn.Conv2d(inC, align_C, (1, 1))
447
+ elif self.dim == "1D":
448
+ if "nmf" in md_type.lower():
449
+ self.pre_conv_block = nn.Sequential(
450
+ nn.Conv1d(inC, align_C, 1),
451
+ nn.ReLU(inplace=True)
452
+ )
453
+ else:
454
+ self.pre_conv_block = nn.Conv1d(inC, align_C, 1)
455
+ else:
456
+ print("Dimension not supported")
457
+
458
+ if "nmf" in md_type.lower():
459
+ self.md_block = NMF(self.device, md_config, dim=self.dim, debug=debug)
460
+ elif "vq" in md_type.lower():
461
+ self.md_block = VQ(self.device, md_config, dim=self.dim, debug=debug)
462
+ else:
463
+ print("Unknown type specified for MD_TYPE:", md_type)
464
+ exit()
465
+
466
+ if self.dim == "3D":
467
+ if "nmf" in md_type.lower():
468
+ self.post_conv_block = nn.Sequential(
469
+ ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1),
470
+ nn.Conv3d(align_C, inC, 1, bias=False)
471
+ )
472
+ else:
473
+ self.post_conv_block = nn.Sequential(
474
+ ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False),
475
+ nn.Conv3d(align_C, inC, 1, bias=False)
476
+ )
477
+ elif self.dim == "2D" or self.dim == "2D_TSM":
478
+ if "nmf" in md_type.lower():
479
+ self.post_conv_block = nn.Sequential(
480
+ ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1),
481
+ nn.Conv2d(align_C, inC, 1, bias=False)
482
+ )
483
+ else:
484
+ self.post_conv_block = nn.Sequential(
485
+ ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False),
486
+ nn.Conv2d(align_C, inC, 1, bias=False)
487
+ )
488
+ else:
489
+ if "nmf" in md_type.lower():
490
+ self.post_conv_block = nn.Sequential(
491
+ ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1),
492
+ nn.Conv1d(align_C, inC, 1, bias=False)
493
+ )
494
+ else:
495
+ self.post_conv_block = nn.Sequential(
496
+ ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False),
497
+ nn.Conv1d(align_C, inC, 1, bias=False)
498
+ )
499
+
500
+ self._init_weight()
501
+
502
+
503
+ def _init_weight(self):
504
+ for m in self.modules():
505
+ if isinstance(m, nn.Conv3d):
506
+ N = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
507
+ m.weight.data.normal_(0, np.sqrt(2. / N))
508
+ elif isinstance(m, nn.Conv2d):
509
+ N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
510
+ m.weight.data.normal_(0, np.sqrt(2. / N))
511
+ elif isinstance(m, nn.Conv1d):
512
+ N = m.kernel_size[0] * m.out_channels
513
+ m.weight.data.normal_(0, np.sqrt(2. / N))
514
+ elif isinstance(m, _BatchNorm):
515
+ m.weight.data.fill_(1)
516
+ if m.bias is not None:
517
+ m.bias.data.zero_()
518
+
519
+ def forward(self, x):
520
+ x = self.pre_conv_block(x)
521
+ att = self.md_block(x)
522
+ dist = torch.dist(x, att)
523
+ att = self.post_conv_block(att)
524
+
525
+ return att, dist
526
+
527
+ def online_update(self, bases):
528
+ if hasattr(self.md_block, 'online_update'):
529
+ self.md_block.online_update(bases)
530
+
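
A hedged usage sketch for the module above: `FeaturesFactorizationModule` factorizes a voxel embedding into a low-rank attention map via NMF and returns that map together with the approximation error. The `md_config` values below mirror the defaults listed in FactorizePhys.py further down; the input is kept non-negative because NMF's multiplicative updates assume non-negative data.

```python
import torch
from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule

md_config = {"MD_TYPE": "NMF", "MD_TRANSFORM": "T_KAB", "MD_S": 1, "MD_R": 1,
             "MD_STEPS": 4, "INV_T": 1, "ETA": 0.9, "RAND_INIT": True,
             "align_channels": 8}
fsam = FeaturesFactorizationModule(inC=16, device=torch.device("cpu"),
                                   md_config=md_config, dim="3D")
voxel = torch.rand(2, 16, 160, 7, 7)        # (B, C, T, H, W), non-negative
att_mask, appx_error = fsam(voxel)
# att_mask: (2, 16, 160, 7, 7) attention over the embedding; appx_error: scalar
```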
neural_methods/model/FactorizePhys/FactorizePhys.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule
10
+
11
+ nf = [8, 12, 16]
12
+
13
+ model_config = {
14
+ "MD_FSAM": True,
15
+ "MD_TYPE": "NMF",
16
+ "MD_TRANSFORM": "T_KAB",
17
+ "MD_R": 1,
18
+ "MD_S": 1,
19
+ "MD_STEPS": 4,
20
+ "MD_INFERENCE": False,
21
+ "MD_RESIDUAL": False,
22
+ "INV_T": 1,
23
+ "ETA": 0.9,
24
+ "RAND_INIT": True,
25
+ "in_channels": 3,
26
+ "data_channels": 4,
27
+ "align_channels": nf[2] // 2,
28
+ "height": 72,
29
+ "weight": 72,
30
+ "batch_size": 4,
31
+ "frames": 160,
32
+ "debug": False,
33
+ "assess_latency": False,
34
+ "num_trials": 20,
35
+ "visualize": False,
36
+ "ckpt_path": "",
37
+ "data_path": "",
38
+ "label_path": ""
39
+ }
40
+
41
+
42
+ class ConvBlock3D(nn.Module):
43
+ def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
44
+ super(ConvBlock3D, self).__init__()
45
+ self.conv_block_3d = nn.Sequential(
46
+ nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding=padding, bias=False),
47
+ nn.Tanh(),
48
+ nn.InstanceNorm3d(out_channel),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.conv_block_3d(x)
53
+
54
+
55
+ class rPPG_FeatureExtractor(nn.Module):
56
+ def __init__(self, inCh, dropout_rate=0.1, debug=False):
57
+ super(rPPG_FeatureExtractor, self).__init__()
58
+ # inCh, out_channel, kernel_size, stride, padding
59
+
60
+ self.debug = debug
61
+ # Input: #B, inCh, 160, 72, 72
62
+ self.FeatureExtractor = nn.Sequential(
63
+ ConvBlock3D(inCh, nf[0], [3, 3, 3], [1, 1, 1], [1, 1, 1]), #B, nf[0], 160, 72, 72
64
+ ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[1], 160, 35, 35
65
+ ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 160, 33, 33
66
+ nn.Dropout3d(p=dropout_rate),
67
+
68
+ ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 160, 31, 31
69
+ ConvBlock3D(nf[1], nf[2], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[2], 160, 15, 15
70
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 13, 13
71
+ nn.Dropout3d(p=dropout_rate),
72
+ )
73
+
74
+ def forward(self, x):
75
+ voxel_embeddings = self.FeatureExtractor(x)
76
+ if self.debug:
77
+ print("rPPG Feature Extractor")
78
+ print(" voxel_embeddings.shape", voxel_embeddings.shape)
79
+ return voxel_embeddings
80
+
81
+
82
+ class BVP_Head(nn.Module):
83
+ def __init__(self, md_config, device, dropout_rate=0.1, debug=False):
84
+ super(BVP_Head, self).__init__()
85
+ self.debug = debug
86
+
87
+ self.use_fsam = md_config["MD_FSAM"]
88
+ self.md_type = md_config["MD_TYPE"]
89
+ self.md_infer = md_config["MD_INFERENCE"]
90
+ self.md_res = md_config["MD_RESIDUAL"]
91
+
92
+ self.conv_block = nn.Sequential(
93
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 11, 11
94
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 9, 9
95
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 7, 7
96
+ nn.Dropout3d(p=dropout_rate),
97
+ )
98
+
99
+ if self.use_fsam:
100
+ inC = nf[2]
101
+ self.fsam = FeaturesFactorizationModule(inC, device, md_config, dim="3D", debug=debug)
102
+ self.fsam_norm = nn.InstanceNorm3d(inC)
103
+ self.bias1 = nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device)
104
+ else:
105
+ inC = nf[2]
106
+
107
+ self.final_layer = nn.Sequential(
108
+ ConvBlock3D(inC, nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 160, 5, 5
109
+ ConvBlock3D(nf[1], nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[0], 160, 3, 3
110
+ nn.Conv3d(nf[0], 1, (3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), #B, 1, 160, 1, 1
111
+ )
112
+
113
+
114
+ def forward(self, voxel_embeddings, batch, length):
115
+
116
+ if self.debug:
117
+ print("BVP Head")
118
+ print(" voxel_embeddings.shape", voxel_embeddings.shape)
119
+
120
+ voxel_embeddings = self.conv_block(voxel_embeddings)
121
+
122
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
123
+ if "NMF" in self.md_type:
124
+ att_mask, appx_error = self.fsam(voxel_embeddings - voxel_embeddings.min()) # to make it positive (>= 0)
125
+ else:
126
+ att_mask, appx_error = self.fsam(voxel_embeddings)
127
+
128
+ if self.debug:
129
+ print("att_mask.shape", att_mask.shape)
130
+
131
+ # # directly use att_mask ---> difficult to converge without Residual connection. Needs high rank
132
+ # factorized_embeddings = self.fsam_norm(att_mask)
133
+
134
+ # # Residual connection:
135
+ # factorized_embeddings = voxel_embeddings + self.fsam_norm(att_mask)
136
+
137
+ if self.md_res:
138
+ # Multiplication with Residual connection
139
+ x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
140
+ factorized_embeddings = self.fsam_norm(x)
141
+ factorized_embeddings = voxel_embeddings + factorized_embeddings
142
+ else:
143
+ # Multiplication
144
+ x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
145
+ factorized_embeddings = self.fsam_norm(x)
146
+
147
+ # # Concatenate
148
+ # factorized_embeddings = torch.cat([voxel_embeddings, self.fsam_norm(x)], dim=1)
149
+
150
+ x = self.final_layer(factorized_embeddings)
151
+
152
+ else:
153
+ x = self.final_layer(voxel_embeddings)
154
+
155
+ rPPG = x.view(-1, length)
156
+
157
+ if self.debug:
158
+ print(" rPPG.shape", rPPG.shape)
159
+
160
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
161
+ return rPPG, factorized_embeddings, appx_error
162
+ else:
163
+ return rPPG
164
+
165
+
166
+
167
+ class FactorizePhys(nn.Module):
168
+ def __init__(self, frames, md_config, in_channels=3, dropout=0.1, device=torch.device("cpu"), debug=False):
169
+ super(FactorizePhys, self).__init__()
170
+ self.debug = debug
171
+
172
+ self.in_channels = in_channels
173
+ if self.in_channels == 1 or self.in_channels == 3:
174
+ self.norm = nn.InstanceNorm3d(self.in_channels)
175
+ elif self.in_channels == 4:
176
+ self.rgb_norm = nn.InstanceNorm3d(3)
177
+ self.thermal_norm = nn.InstanceNorm3d(1)
178
+ else:
179
+ print("Unsupported input channels")
180
+
181
+ self.use_fsam = md_config["MD_FSAM"]
182
+ self.md_infer = md_config["MD_INFERENCE"]
183
+
184
+ for key in model_config:
185
+ if key not in md_config:
186
+ md_config[key] = model_config[key]
187
+
188
+ if self.debug:
189
+ print("nf:", nf)
190
+
191
+ self.rppg_feature_extractor = rPPG_FeatureExtractor(self.in_channels, dropout_rate=dropout, debug=debug)
192
+
193
+ self.rppg_head = BVP_Head(md_config, device=device, dropout_rate=dropout, debug=debug)
194
+
195
+
196
+ def forward(self, x): # [batch, Features=3, Temp=frames, Width=72, Height=72]
197
+
198
+ [batch, channel, length, width, height] = x.shape
199
+
200
+ # if self.in_channels == 1:
201
+ # x = x[:, :, :-1, :, :]
202
+ # else:
203
+ # x = torch.diff(x, dim=2)
204
+
205
+ x = torch.diff(x, dim=2)
206
+
207
+ if self.debug:
208
+ print("Input.shape", x.shape)
209
+
210
+ if self.in_channels == 1:
211
+ x = self.norm(x[:, -1:, :, :, :])
212
+ elif self.in_channels == 3:
213
+ x = self.norm(x[:, :3, :, :, :])
214
+ elif self.in_channels == 4:
215
+ rgb_x = self.rgb_norm(x[:, :3, :, :, :])
216
+ thermal_x = self.thermal_norm(x[:, -1:, :, :, :])
217
+ x = torch.concat([rgb_x, thermal_x], dim = 1)
218
+ else:
219
+ try:
220
+ print("Specified input channels:", self.in_channels)
221
+ print("Data channels", channel)
222
+ assert self.in_channels <= channel
223
+ except:
224
+ print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels")
225
+ print("Default or specified channels:", self.in_channels)
226
+ print("Data channels [B, C, N, W, H]", x.shape)
227
+ print("Exiting")
228
+ exit()
229
+
230
+ if self.debug:
231
+ print("Diff Normalized shape", x.shape)
232
+
233
+ voxel_embeddings = self.rppg_feature_extractor(x)
234
+
235
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
236
+ rPPG, factorized_embeddings, appx_error = self.rppg_head(voxel_embeddings, batch, length-1)
237
+ else:
238
+ rPPG = self.rppg_head(voxel_embeddings, batch, length-1)
239
+
240
+ # if self.debug:
241
+ # print("rppg_feats.shape", rppg_feats.shape)
242
+
243
+ # rPPG = rppg_feats.view(-1, length-1)
244
+
245
+ if self.debug:
246
+ print("rPPG.shape", rPPG.shape)
247
+
248
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
249
+ return rPPG, voxel_embeddings, factorized_embeddings, appx_error
250
+ else:
251
+ return rPPG, voxel_embeddings
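
The forward pass above differentiates along the temporal axis, so a chunk of `frames` samples is fed to the model as `frames + 1` raw frames and the predicted rPPG signal comes back with `frames` samples. A minimal forward-pass sketch, mirroring the test_FactorizePhys.py script further down in this commit; the md_config values are illustrative, with FSAM disabled to keep the example self-contained:

import torch
from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys

# Illustrative configuration; any missing keys are filled in from the file's model_config defaults.
md_config = {"MD_FSAM": False, "MD_TYPE": "NMF",
             "MD_INFERENCE": False, "MD_RESIDUAL": False}
model = FactorizePhys(frames=160, md_config=md_config, in_channels=3)
model.eval()

x = torch.rand(1, 3, 161, 72, 72)   # [B, C, T+1, H, W]; one extra frame for torch.diff
with torch.no_grad():
    rppg, voxel_embeddings = model(x)
print(rppg.shape)                    # expected: torch.Size([1, 160])
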
neural_methods/model/FactorizePhys/FactorizePhysBig.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule
10
+
11
+ nf = [8, 12, 16]
12
+
13
+ model_config = {
14
+ "MD_FSAM": True,
15
+ "MD_TYPE": "NMF",
16
+ "MD_TRANSFORM": "T_KAB",
17
+ "MD_R": 1,
18
+ "MD_S": 1,
19
+ "MD_STEPS": 4,
20
+ "MD_INFERENCE": False,
21
+ "MD_RESIDUAL": False,
22
+ "INV_T": 1,
23
+ "ETA": 0.9,
24
+ "RAND_INIT": True,
25
+ "in_channels": 3,
26
+ "data_channels": 4,
27
+ "align_channels": nf[2] // 2,
28
+ "height": 128,
29
+ "weight": 128,
30
+ "batch_size": 4,
31
+ "frames": 240,
32
+ "debug": False,
33
+ "assess_latency": False,
34
+ "num_trials": 20,
35
+ "visualize": False,
36
+ "ckpt_path": "",
37
+ "data_path": "",
38
+ "label_path": ""
39
+ }
40
+
41
+
42
+ class ConvBlock3D(nn.Module):
43
+ def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
44
+ super(ConvBlock3D, self).__init__()
45
+ self.conv_block_3d = nn.Sequential(
46
+ nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding=padding, bias=False),
47
+ nn.Tanh(),
48
+ nn.InstanceNorm3d(out_channel),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.conv_block_3d(x)
53
+
54
+
55
+ class rPPG_FeatureExtractor(nn.Module):
56
+ def __init__(self, inCh, dropout_rate=0.1, debug=False):
57
+ super(rPPG_FeatureExtractor, self).__init__()
58
+ # inCh, out_channel, kernel_size, stride, padding
59
+
60
+ self.debug = debug
61
+ # Input: #B, inCh, 240, 128, 128
62
+ self.FeatureExtractor = nn.Sequential(
63
+ ConvBlock3D(inCh, nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[0], 240, 126, 126
64
+ ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[1], 240, 62, 62
65
+ ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 240, 60, 60
66
+ nn.Dropout3d(p=dropout_rate),
67
+
68
+ ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[1], 240, 29, 29
69
+ ConvBlock3D(nf[1], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 27, 27
70
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 25, 25
71
+ nn.Dropout3d(p=dropout_rate),
72
+ )
73
+
74
+ def forward(self, x):
75
+ voxel_embeddings = self.FeatureExtractor(x)
76
+ if self.debug:
77
+ print("rPPG Feature Extractor")
78
+ print(" voxel_embeddings.shape", voxel_embeddings.shape)
79
+ return voxel_embeddings
80
+
81
+
82
+ class BVP_Head(nn.Module):
83
+ def __init__(self, md_config, device, dropout_rate=0.1, debug=False):
84
+ super(BVP_Head, self).__init__()
85
+ self.debug = debug
86
+
87
+ self.use_fsam = md_config["MD_FSAM"]
88
+ self.md_type = md_config["MD_TYPE"]
89
+ self.md_infer = md_config["MD_INFERENCE"]
90
+ self.md_res = md_config["MD_RESIDUAL"]
91
+
92
+ self.conv_block = nn.Sequential(
93
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[2], 240, 12, 12
94
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 10, 10
95
+ ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 8, 8
96
+ nn.Dropout3d(p=dropout_rate),
97
+ )
98
+
99
+ if self.use_fsam:
100
+ inC = nf[2]
101
+ self.fsam = FeaturesFactorizationModule(inC, device, md_config, dim="3D", debug=debug)
102
+ self.fsam_norm = nn.InstanceNorm3d(inC)
103
+ self.bias1 = nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device)
104
+ else:
105
+ inC = nf[2]
106
+
107
+ self.final_layer = nn.Sequential(
108
+ ConvBlock3D(inC, nf[1], [3, 4, 4], [1, 1, 1], [1, 0, 0]), #B, nf[1], 240, 5, 5
109
+ ConvBlock3D(nf[1], nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[0], 240, 3, 3
110
+ nn.Conv3d(nf[0], 1, (3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), #B, 1, 240, 1, 1
111
+ )
112
+
113
+
114
+ def forward(self, voxel_embeddings, batch, length):
115
+
116
+ if self.debug:
117
+ print("BVP Head")
118
+ print(" voxel_embeddings.shape", voxel_embeddings.shape)
119
+
120
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
121
+ if "NMF" in self.md_type:
122
+ att_mask, appx_error = self.fsam(voxel_embeddings - voxel_embeddings.min()) # to make it positive (>= 0)
123
+ else:
124
+ att_mask, appx_error = self.fsam(voxel_embeddings)
125
+
126
+ if self.debug:
127
+ print("att_mask.shape", att_mask.shape)
128
+
129
+ # # directly use att_mask ---> difficult to converge without Residual connection. Needs high rank
130
+ # factorized_embeddings = self.fsam_norm(att_mask)
131
+
132
+ # # Residual connection:
133
+ # factorized_embeddings = voxel_embeddings + self.fsam_norm(att_mask)
134
+
135
+ if self.md_res:
136
+ # Multiplication with Residual connection
137
+ x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
138
+ factorized_embeddings = self.fsam_norm(x)
139
+ factorized_embeddings = voxel_embeddings + factorized_embeddings
140
+ else:
141
+ # Multiplication
142
+ x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
143
+ factorized_embeddings = self.fsam_norm(x)
144
+
145
+ # # Concatenate
146
+ # factorized_embeddings = torch.cat([voxel_embeddings, self.fsam_norm(x)], dim=1)
147
+
148
+ x = self.conv_block(factorized_embeddings)
149
+ x = self.final_layer(x)
150
+
151
+ else:
152
+ x = self.conv_block(voxel_embeddings)
153
+ x = self.final_layer(x)
154
+
155
+ rPPG = x.view(-1, length)
156
+
157
+ if self.debug:
158
+ print(" rPPG.shape", rPPG.shape)
159
+
160
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
161
+ return rPPG, factorized_embeddings, appx_error
162
+ else:
163
+ return rPPG
164
+
165
+
166
+
167
+ class FactorizePhysBig(nn.Module):
168
+ def __init__(self, frames, md_config, in_channels=3, dropout=0.1, device=torch.device("cpu"), debug=False):
169
+ super(FactorizePhysBig, self).__init__()
170
+ self.debug = debug
171
+
172
+ self.in_channels = in_channels
173
+ if self.in_channels == 1 or self.in_channels == 3:
174
+ self.norm = nn.InstanceNorm3d(self.in_channels)
175
+ elif self.in_channels == 4:
176
+ self.rgb_norm = nn.InstanceNorm3d(3)
177
+ self.thermal_norm = nn.InstanceNorm3d(1)
178
+ else:
179
+ print("Unsupported input channels")
180
+
181
+ self.use_fsam = md_config["MD_FSAM"]
182
+ self.md_infer = md_config["MD_INFERENCE"]
183
+
184
+ for key in model_config:
185
+ if key not in md_config:
186
+ md_config[key] = model_config[key]
187
+
188
+ if self.debug:
189
+ print("nf:", nf)
190
+
191
+ self.rppg_feature_extractor = rPPG_FeatureExtractor(self.in_channels, dropout_rate=dropout, debug=debug)
192
+
193
+ self.rppg_head = BVP_Head(md_config, device=device, dropout_rate=dropout, debug=debug)
194
+
195
+
196
+ def forward(self, x): # [batch, Features=3, Temp=frames, Width=128, Height=128]
197
+
198
+ [batch, channel, length, width, height] = x.shape
199
+
200
+ # if self.in_channels == 1:
201
+ # x = x[:, :, :-1, :, :]
202
+ # else:
203
+ # x = torch.diff(x, dim=2)
204
+
205
+ x = torch.diff(x, dim=2)
206
+
207
+ if self.debug:
208
+ print("Input.shape", x.shape)
209
+
210
+ if self.in_channels == 1:
211
+ x = self.norm(x[:, -1:, :, :, :])
212
+ elif self.in_channels == 3:
213
+ x = self.norm(x[:, :3, :, :, :])
214
+ elif self.in_channels == 4:
215
+ rgb_x = self.rgb_norm(x[:, :3, :, :, :])
216
+ thermal_x = self.thermal_norm(x[:, -1:, :, :, :])
217
+ x = torch.concat([rgb_x, thermal_x], dim = 1)
218
+ else:
219
+ try:
220
+ print("Specified input channels:", self.in_channels)
221
+ print("Data channels", channel)
222
+ assert self.in_channels <= channel
223
+ except:
224
+ print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels")
225
+ print("Default or specified channels:", self.in_channels)
226
+ print("Data channels [B, C, N, W, H]", x.shape)
227
+ print("Exiting")
228
+ exit()
229
+
230
+ if self.debug:
231
+ print("Diff Normalized shape", x.shape)
232
+
233
+ voxel_embeddings = self.rppg_feature_extractor(x)
234
+
235
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
236
+ rPPG, factorized_embeddings, appx_error = self.rppg_head(voxel_embeddings, batch, length-1)
237
+ else:
238
+ rPPG = self.rppg_head(voxel_embeddings, batch, length-1)
239
+
240
+ # if self.debug:
241
+ # print("rppg_feats.shape", rppg_feats.shape)
242
+
243
+ # rPPG = rppg_feats.view(-1, length-1)
244
+
245
+ if self.debug:
246
+ print("rPPG.shape", rPPG.shape)
247
+
248
+ if (self.md_infer or self.training or self.debug) and self.use_fsam:
249
+ return rPPG, voxel_embeddings, factorized_embeddings, appx_error
250
+ else:
251
+ return rPPG, voxel_embeddings
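
FactorizePhysBig follows the same calling convention as FactorizePhys above; it expects the higher-resolution preprocessing (240-frame chunks at 128x128) and adds a further stride-2 spatial downsampling stage in the BVP head. A minimal shape check under the same illustrative md_config as before:

import torch
from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig

md_config = {"MD_FSAM": False, "MD_TYPE": "NMF",
             "MD_INFERENCE": False, "MD_RESIDUAL": False}   # illustrative values
model = FactorizePhysBig(frames=240, md_config=md_config, in_channels=3)
model.eval()

x = torch.rand(1, 3, 241, 128, 128)   # [B, C, T+1, H, W]; one extra frame for torch.diff
with torch.no_grad():
    rppg, voxel_embeddings = model(x)
print(rppg.shape)                      # expected: torch.Size([1, 240])
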
neural_methods/model/FactorizePhys/__init__.py ADDED
File without changes
neural_methods/model/FactorizePhys/test_FactorizePhys.py ADDED
@@ -0,0 +1,286 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import time
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.cm as cm
12
+ from scipy.signal import resample
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys
17
+ # from torch.utils.tensorboard import SummaryWriter
18
+
19
+ model_config = {
20
+ "MD_FSAM": True,
21
+ "MD_TYPE": "NMF",
22
+ "MD_TRANSFORM": "T_KAB",
23
+ "MD_R": 1,
24
+ "MD_S": 1,
25
+ "MD_STEPS": 4,
26
+ "MD_INFERENCE": True,
27
+ "MD_RESIDUAL": True,
28
+ "in_channels": 3,
29
+ "data_channels": 4,
30
+ "height": 72,
31
+ "weight": 72,
32
+ "batch_size": 2,
33
+ "frames": 160,
34
+ "debug": True,
35
+ "assess_latency": False,
36
+ "num_trials": 20,
37
+ "visualize": False,
38
+ "ckpt_path": "./final_model_release/iBVP_FactorizePhys_FSAM_Res.pth",
39
+ "data_path": "/mnt/sda/data/prep/iBVP_Dataset/iBVP_RGB_160_72x72",
40
+ "label_path": "/mnt/sda/data/prep/iBVP_Dataset/iBVP_RGB_160_72x72"
41
+ }
42
+
43
+
44
+ class TestFactorizePhys(object):
45
+ def __init__(self) -> None:
46
+ self.ckpt_path = Path(model_config["ckpt_path"])
47
+ self.data_path = Path(model_config["data_path"])
48
+ self.label_path = Path(model_config["label_path"])
49
+
50
+ self.use_fsam = model_config["MD_FSAM"]
51
+ self.md_infer = model_config["MD_INFERENCE"]
52
+
53
+ self.batch_size = model_config["batch_size"]
54
+ self.frames = model_config["frames"]
55
+ self.in_channels = model_config["in_channels"]
56
+ self.data_channels = model_config["data_channels"]
57
+ self.height = model_config["height"]
58
+ self.width = model_config["weight"]
59
+ self.debug = bool(model_config["debug"])
60
+ self.assess_latency = bool(model_config["assess_latency"])
61
+ self.visualize = model_config["visualize"]
62
+
63
+ if self.visualize:
64
+ self.data_files = list(sorted(self.data_path.rglob("*input*.npy")))
65
+ self.label_files = list(sorted(self.data_path.rglob("*label*.npy")))
66
+ self.num_trials = len(self.data_files)
67
+
68
+ self.plot_dir = Path.cwd().joinpath("plots").joinpath("inference")
69
+ self.plot_dir.mkdir(parents=True, exist_ok=True)
70
+
71
+ self.attention_map_dir = self.plot_dir.joinpath("attention_maps").joinpath(self.data_path.name).joinpath(self.ckpt_path.name)
72
+ self.attention_map_dir.mkdir(parents=True, exist_ok=True)
73
+
74
+ else:
75
+ if self.assess_latency:
76
+ self.num_trials = model_config["num_trials"]
77
+ else:
78
+ self.num_trials = 1
79
+
80
+ if torch.cuda.is_available():
81
+ self.device = torch.device(0)
82
+ else:
83
+ self.device = torch.device("cpu")
84
+
85
+ md_config = {}
86
+ md_config["FRAME_NUM"] = model_config["frames"]
87
+ md_config["MD_S"] = model_config["MD_S"]
88
+ md_config["MD_R"] = model_config["MD_R"]
89
+ md_config["MD_STEPS"] = model_config["MD_STEPS"]
90
+ md_config["MD_FSAM"] = model_config["MD_FSAM"]
91
+ md_config["MD_TYPE"] = model_config["MD_TYPE"]
92
+ md_config["MD_TRANSFORM"] = model_config["MD_TRANSFORM"]
93
+ md_config["MD_INFERENCE"] = model_config["MD_INFERENCE"]
94
+ md_config["MD_RESIDUAL"] = model_config["MD_RESIDUAL"]
95
+
96
+ if self.visualize:
97
+ self.net = nn.DataParallel(FactorizePhys(frames=self.frames, md_config=md_config,
98
+ device=self.device, in_channels=self.in_channels, debug=self.debug), device_ids=[0]).to(self.device)
99
+ self.net.load_state_dict(torch.load(str(self.ckpt_path), map_location=self.device))
100
+ else:
101
+ self.net = FactorizePhys(frames=self.frames, md_config=md_config,
102
+ device=self.device, in_channels=self.in_channels, debug=self.debug).to(self.device)
103
+
104
+ self.net.eval()
105
+ if self.assess_latency:
106
+ self.time_vec = []
107
+
108
+ if self.debug:
109
+ self.appx_error_list = []
110
+
111
+
112
+ def load_data(self, num_trial):
113
+
114
+ if self.visualize:
115
+ self.np_data = np.load(str(self.data_files[num_trial]))
116
+ self.np_label = np.load(str(self.label_files[num_trial]))
117
+ self.np_label = np.expand_dims(self.np_label, 0)
118
+ self.np_label = torch.tensor(self.np_label)
119
+
120
+ # print("Chunk data shape", self.np_data.shape)
121
+ # print("Chunk label shape", self.np_label.shape)
122
+ # print("Min Max of input data:", np.min(self.np_data), np.max(self.np_data))
123
+ # exit()
124
+
125
+ self.test_data = np.transpose(self.np_data, (3, 0, 1, 2))
126
+ self.test_data = torch.from_numpy(self.test_data)
127
+ self.test_data = self.test_data.unsqueeze(0)
128
+
129
+ last_frame = torch.unsqueeze(self.test_data[:, :, -1, :, :], 2).repeat(1, 1, 1, 1, 1)
130
+ self.test_data = torch.cat((self.test_data, last_frame), 2)
131
+ self.test_data = self.test_data.to(torch.float32).to(self.device)
132
+ else:
133
+ self.test_data = torch.rand(self.batch_size, self.data_channels, self.frames + 1, self.height, self.width)
134
+ self.test_data = self.test_data.to(torch.float32).to(self.device)
135
+
136
+
137
+ def run_inference(self, num_trial):
138
+
139
+ if self.visualize:
140
+ print("Processing:", self.data_files[num_trial].name)
141
+ if self.assess_latency:
142
+ t0 = time.time()
143
+
144
+ if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
145
+ self.pred, self.vox_embed, self.factorized_embed, self.appx_error = self.net(self.test_data)
146
+ else:
147
+ self.pred, self.vox_embed = self.net(self.test_data)
148
+
149
+ if self.assess_latency:
150
+ t1 = time.time()
151
+ self.time_vec.append(t1-t0)
152
+
153
+ if self.debug:
154
+ print("pred.shape", self.pred.shape)
155
+ if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
156
+ self.appx_error_list.append(self.appx_error.item())
157
+
158
+ if self.visualize:
159
+ self.save_attention_maps(num_trial)
160
+
161
+
162
+ def save_attention_maps(self, num_trial):
163
+ b, channels, enc_frames, enc_height, enc_width = self.vox_embed.shape
164
+ label_matrix = self.np_label.unsqueeze(0).repeat(1, channels, 1).unsqueeze(
165
+ 2).unsqueeze(2).permute(0, 1, 4, 3, 2).repeat(1, 1, 1, enc_height, enc_width)
166
+ label_matrix = label_matrix.to(device=self.device)
167
+ corr_matrix = F.cosine_similarity(self.vox_embed, label_matrix, dim=2).abs()
168
+
169
+ # avg_emb = torch.mean(self.vox_embed, dim=1)
170
+ # b, enc_frames, enc_height, enc_width = avg_emb.shape
171
+ # label_matrix = np_label.unsqueeze(0).unsqueeze(2).permute(0, 3, 2, 1).repeat(1, 1, enc_height, enc_width)
172
+ # label_matrix = label_matrix.to(device=device)
173
+ # corr_matrix = F.cosine_similarity(avg_emb, label_matrix, dim=1)
174
+
175
+ if self.debug:
176
+ print("corr_matrix.shape", corr_matrix.shape)
177
+ print("self.test_data.shape:", self.test_data.shape)
178
+ print("self.vox_embed.shape:", self.vox_embed.shape)
179
+
180
+ self.test_data = self.test_data.detach().cpu().numpy()
181
+ self.vox_embed = self.vox_embed.detach().cpu().numpy()
182
+ corr_matrix = corr_matrix.detach().cpu().numpy()
183
+
184
+ fig, ax = plt.subplots(4, 4, figsize=[16, 16])
185
+ fig.tight_layout()
186
+
187
+ ax[0, 0].imshow(self.np_data[enc_frames//2, ...].astype(np.uint8))
188
+ ax[0, 0].axis('off')
189
+ cmap = "coolwarm"
190
+
191
+ ch = 0
192
+ ax[0, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
193
+ ax[0, 1].axis('off')
194
+
195
+ ch = 1
196
+ ax[0, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
197
+ ax[0, 2].axis('off')
198
+
199
+ ch = 2
200
+ ax[0, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
201
+ ax[0, 3].axis('off')
202
+
203
+ ch = 3
204
+ ax[1, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
205
+ ax[1, 0].axis('off')
206
+
207
+ ch = 4
208
+ ax[1, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
209
+ ax[1, 1].axis('off')
210
+
211
+ ch = 5
212
+ ax[1, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
213
+ ax[1, 2].axis('off')
214
+
215
+ ch = 6
216
+ ax[1, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
217
+ ax[1, 3].axis('off')
218
+
219
+ ch = 7
220
+ ax[2, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
221
+ ax[2, 0].axis('off')
222
+
223
+ ch = 8
224
+ ax[2, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
225
+ ax[2, 1].axis('off')
226
+
227
+ ch = 9
228
+ ax[2, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
229
+ ax[2, 2].axis('off')
230
+
231
+ ch = 10
232
+ ax[2, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
233
+ ax[2, 3].axis('off')
234
+
235
+ ch = 11
236
+ ax[3, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
237
+ ax[3, 0].axis('off')
238
+
239
+ ch = 12
240
+ ax[3, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
241
+ ax[3, 1].axis('off')
242
+
243
+ ch = 13
244
+ ax[3, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
245
+ ax[3, 2].axis('off')
246
+
247
+ ch = 14
248
+ ax[3, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
249
+ ax[3, 3].axis('off')
250
+
251
+ # plt.show()
252
+ plt.savefig(str(self.attention_map_dir.joinpath(str(self.data_files[num_trial].name.replace(".npy", "_attention_map.jpg")))))
253
+ plt.close(fig)
254
+
255
+
256
+ def output_summary_results(self):
257
+ if self.assess_latency:
258
+ print("Median time: ", np.median(self.time_vec))
259
+ plt.plot(self.time_vec)
260
+ plt.savefig(str(self.plot_dir.joinpath("Latency.jpg")))
261
+
262
+ if self.debug:
263
+ if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
264
+ print("Median error:", np.median(self.appx_error_list))
265
+
266
+ pytorch_total_params = sum(p.numel() for p in self.net.parameters())
267
+ print("Total parameters = ", pytorch_total_params)
268
+
269
+ pytorch_trainable_params = sum(p.numel()
270
+ for p in self.net.parameters() if p.requires_grad)
271
+ print("Trainable parameters = ", pytorch_trainable_params)
272
+
273
+
274
+ if __name__ == "__main__":
275
+
276
+ testObj = TestFactorizePhys()
277
+
278
+ print("testObj.num_trials:", testObj.num_trials)
279
+ for trial_num in range(testObj.num_trials):
280
+ testObj.load_data(trial_num)
281
+ testObj.run_inference(trial_num)
282
+
283
+ testObj.output_summary_results()
284
+
285
+ # writer.add_graph(net, test_data)
286
+ # writer.close()
neural_methods/model/FactorizePhys/test_FactorizePhysBig.py ADDED
@@ -0,0 +1,292 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import time
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.cm as cm
12
+ from scipy.signal import resample
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig
18
+ # from torch.utils.tensorboard import SummaryWriter
19
+
20
+ model_config = {
21
+ "MD_FSAM": True,
22
+ "MD_TYPE": "NMF",
23
+ "MD_TRANSFORM": "T_KAB",
24
+ "MD_R": 1,
25
+ "MD_S": 1,
26
+ "MD_STEPS": 4,
27
+ "MD_INFERENCE": True,
28
+ "MD_RESIDUAL": True,
29
+ "in_channels": 3,
30
+ "data_channels": 4,
31
+ "height": 128,
32
+ "weight": 128,
33
+ "batch_size": 1,
34
+ "frames": 240,
35
+ "debug": True,
36
+ "assess_latency": False,
37
+ "num_trials": 20,
38
+ "visualize": False,
39
+ # "ckpt_path": "./final_model_release/UBFC-rPPG_Intra_FactorizePhys_Base_HighRes.pth",
40
+ "ckpt_path": "./final_model_release/UBFC-rPPG_Intra_FactorizePhys_FSAM_Res_HighRes.pth",
41
+ "data_path": "/mnt/sda/data/prep/UBFC-rPPG/UBFC-rPPG_Raw_240_128x128",
42
+ "label_path": "/mnt/sda/data/prep/UBFC-rPPG/UBFC-rPPG_Raw_240_128x128"
43
+ }
44
+
45
+ # default `log_dir` is "runs" - we'll be more specific here
46
+ # writer = SummaryWriter('runs/FactorizePhys')
47
+
48
+ class TestFactorizePhysBig(object):
49
+ def __init__(self) -> None:
50
+ self.ckpt_path = Path(model_config["ckpt_path"])
51
+ self.data_path = Path(model_config["data_path"])
52
+ self.label_path = Path(model_config["label_path"])
53
+
54
+ self.use_fsam = model_config["MD_FSAM"]
55
+ self.md_infer = model_config["MD_INFERENCE"]
56
+
57
+ self.batch_size = model_config["batch_size"]
58
+ self.frames = model_config["frames"]
59
+ self.in_channels = model_config["in_channels"]
60
+ self.data_channels = model_config["data_channels"]
61
+ self.height = model_config["height"]
62
+ self.width = model_config["weight"]
63
+ self.debug = bool(model_config["debug"])
64
+ self.assess_latency = bool(model_config["assess_latency"])
65
+ self.visualize = model_config["visualize"]
66
+
67
+ if self.visualize:
68
+ # self.data_files = list(sorted(self.data_path.rglob("*subject12*input*.npy")))
69
+ # self.label_files = list(sorted(self.data_path.rglob("*subject12*label*.npy")))
70
+ self.data_files = list(sorted(self.data_path.rglob("*input*.npy")))
71
+ self.label_files = list(sorted(self.data_path.rglob("*label*.npy")))
72
+ self.num_trials = len(self.data_files)
73
+
74
+ self.plot_dir = Path.cwd().joinpath("plots").joinpath("inference")
75
+ self.plot_dir.mkdir(parents=True, exist_ok=True)
76
+
77
+ self.attention_map_dir = self.plot_dir.joinpath("attention_maps").joinpath(self.data_path.name).joinpath(self.ckpt_path.name)
78
+ self.attention_map_dir.mkdir(parents=True, exist_ok=True)
79
+
80
+ else:
81
+ if self.assess_latency:
82
+ self.num_trials = model_config["num_trials"]
83
+ else:
84
+ self.num_trials = 1
85
+
86
+ if torch.cuda.is_available():
87
+ self.device = torch.device(0)
88
+ else:
89
+ self.device = torch.device("cpu")
90
+
91
+ md_config = {}
92
+ md_config["FRAME_NUM"] = model_config["frames"]
93
+ md_config["MD_S"] = model_config["MD_S"]
94
+ md_config["MD_R"] = model_config["MD_R"]
95
+ md_config["MD_STEPS"] = model_config["MD_STEPS"]
96
+ md_config["MD_FSAM"] = model_config["MD_FSAM"]
97
+ md_config["MD_TYPE"] = model_config["MD_TYPE"]
98
+ md_config["MD_TRANSFORM"] = model_config["MD_TRANSFORM"]
99
+ md_config["MD_INFERENCE"] = model_config["MD_INFERENCE"]
100
+ md_config["MD_RESIDUAL"] = model_config["MD_RESIDUAL"]
101
+
102
+ if self.visualize:
103
+ self.net = nn.DataParallel(FactorizePhysBig(frames=self.frames, md_config=md_config,
104
+ device=self.device, in_channels=self.in_channels, debug=self.debug), device_ids=[0]).to(self.device)
105
+ self.net.load_state_dict(torch.load(str(self.ckpt_path), map_location=self.device))
106
+ else:
107
+ self.net = FactorizePhysBig(frames=self.frames, md_config=md_config,
108
+ device=self.device, in_channels=self.in_channels, debug=self.debug).to(self.device)
109
+
110
+ self.net.eval()
111
+ if self.assess_latency:
112
+ self.time_vec = []
113
+
114
+ if self.debug:
115
+ self.appx_error_list = []
116
+
117
+
118
+ def load_data(self, num_trial):
119
+
120
+ if self.visualize:
121
+ self.np_data = np.load(str(self.data_files[num_trial]))
122
+ self.np_label = np.load(str(self.label_files[num_trial]))
123
+ self.np_label = np.expand_dims(self.np_label, 0)
124
+ self.np_label = torch.tensor(self.np_label)
125
+
126
+ # print("Chunk data shape", self.np_data.shape)
127
+ # print("Chunk label shape", self.np_label.shape)
128
+ # print("Min Max of input data:", np.min(self.np_data), np.max(self.np_data))
129
+ # exit()
130
+
131
+ self.test_data = np.transpose(self.np_data, (3, 0, 1, 2))
132
+ self.test_data = torch.from_numpy(self.test_data)
133
+ self.test_data = self.test_data.unsqueeze(0)
134
+
135
+ last_frame = torch.unsqueeze(self.test_data[:, :, -1, :, :], 2).repeat(1, 1, 1, 1, 1)
136
+ self.test_data = torch.cat((self.test_data, last_frame), 2)
137
+ self.test_data = self.test_data.to(torch.float32).to(self.device)
138
+ else:
139
+ self.test_data = torch.rand(self.batch_size, self.data_channels, self.frames + 1, self.height, self.width)
140
+ self.test_data = self.test_data.to(torch.float32).to(self.device)
141
+
142
+
143
+ def run_inference(self, num_trial):
144
+
145
+ if self.visualize:
146
+ print("Processing:", self.data_files[num_trial].name)
147
+ if self.assess_latency:
148
+ t0 = time.time()
149
+
150
+ if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
151
+ self.pred, self.vox_embed, self.factorized_embed, self.appx_error = self.net(self.test_data)
152
+ else:
153
+ self.pred, self.vox_embed = self.net(self.test_data)
154
+
155
+ if self.assess_latency:
156
+ t1 = time.time()
157
+ self.time_vec.append(t1-t0)
158
+
159
+ if self.debug:
160
+ print("pred.shape", self.pred.shape)
161
+ if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
162
+ self.appx_error_list.append(self.appx_error.item())
163
+
164
+ if self.visualize:
165
+ self.save_attention_maps(num_trial)
166
+
167
+
168
+ def save_attention_maps(self, num_trial):
169
+ b, channels, enc_frames, enc_height, enc_width = self.vox_embed.shape
170
+ label_matrix = self.np_label.unsqueeze(0).repeat(1, channels, 1).unsqueeze(
171
+ 2).unsqueeze(2).permute(0, 1, 4, 3, 2).repeat(1, 1, 1, enc_height, enc_width)
172
+ label_matrix = label_matrix.to(device=self.device)
173
+ corr_matrix = F.cosine_similarity(self.vox_embed, label_matrix, dim=2).abs()
174
+
175
+ # avg_emb = torch.mean(self.vox_embed, dim=1)
176
+ # b, enc_frames, enc_height, enc_width = avg_emb.shape
177
+ # label_matrix = np_label.unsqueeze(0).unsqueeze(2).permute(0, 3, 2, 1).repeat(1, 1, enc_height, enc_width)
178
+ # label_matrix = label_matrix.to(device=device)
179
+ # corr_matrix = F.cosine_similarity(avg_emb, label_matrix, dim=1)
180
+
181
+ if self.debug:
182
+ print("corr_matrix.shape", corr_matrix.shape)
183
+ print("self.test_data.shape:", self.test_data.shape)
184
+ print("self.vox_embed.shape:", self.vox_embed.shape)
185
+
186
+ self.test_data = self.test_data.detach().cpu().numpy()
187
+ self.vox_embed = self.vox_embed.detach().cpu().numpy()
188
+ corr_matrix = corr_matrix.detach().cpu().numpy()
189
+
190
+ fig, ax = plt.subplots(4, 4, figsize=[16, 16])
191
+ fig.tight_layout()
192
+ cmap = "coolwarm"
193
+
194
+ ax[0, 0].imshow(self.np_data[enc_frames//2, ...].astype(np.uint8))
195
+ ax[0, 0].axis('off')
196
+
197
+ ch = 0
198
+ ax[0, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
199
+ ax[0, 1].axis('off')
200
+
201
+ ch = 1
202
+ ax[0, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
203
+ ax[0, 2].axis('off')
204
+
205
+ ch = 2
206
+ ax[0, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
207
+ ax[0, 3].axis('off')
208
+
209
+ ch = 3
210
+ ax[1, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
211
+ ax[1, 0].axis('off')
212
+
213
+ ch = 4
214
+ ax[1, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
215
+ ax[1, 1].axis('off')
216
+
217
+ ch = 5
218
+ ax[1, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
219
+ ax[1, 2].axis('off')
220
+
221
+ ch = 6
222
+ ax[1, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
223
+ ax[1, 3].axis('off')
224
+
225
+ ch = 7
226
+ ax[2, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
227
+ ax[2, 0].axis('off')
228
+
229
+ ch = 8
230
+ ax[2, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
231
+ ax[2, 1].axis('off')
232
+
233
+ ch = 9
234
+ ax[2, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
235
+ ax[2, 2].axis('off')
236
+
237
+ ch = 10
238
+ ax[2, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
239
+ ax[2, 3].axis('off')
240
+
241
+ ch = 11
242
+ ax[3, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
243
+ ax[3, 0].axis('off')
244
+
245
+ ch = 12
246
+ ax[3, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
247
+ ax[3, 1].axis('off')
248
+
249
+ ch = 13
250
+ ax[3, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
251
+ ax[3, 2].axis('off')
252
+
253
+ ch = 14
254
+ ax[3, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
255
+ ax[3, 3].axis('off')
256
+
257
+ # plt.show()
258
+ plt.savefig(str(self.attention_map_dir.joinpath(str(self.data_files[num_trial].name.replace(".npy", "_attention_map.jpg")))))
259
+ plt.close(fig)
260
+
261
+
262
+ def output_summary_results(self):
263
+ if self.assess_latency:
264
+ print("Median time: ", np.median(self.time_vec))
265
+ plt.plot(self.time_vec)
266
+ plt.savefig(str(self.plot_dir.joinpath("Latency.jpg")))
267
+
268
+ if self.debug:
269
+ if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
270
+ print("Median error:", np.median(self.appx_error_list))
271
+
272
+ pytorch_total_params = sum(p.numel() for p in self.net.parameters())
273
+ print("Total parameters = ", pytorch_total_params)
274
+
275
+ pytorch_trainable_params = sum(p.numel()
276
+ for p in self.net.parameters() if p.requires_grad)
277
+ print("Trainable parameters = ", pytorch_trainable_params)
278
+
279
+
280
+ if __name__ == "__main__":
281
+
282
+ testObj = TestFactorizePhysBig()
283
+
284
+ print("testObj.num_trials:", testObj.num_trials)
285
+ for trial_num in range(testObj.num_trials):
286
+ testObj.load_data(trial_num)
287
+ testObj.run_inference(trial_num)
288
+
289
+ testObj.output_summary_results()
290
+
291
+ # writer.add_graph(net, test_data)
292
+ # writer.close()
neural_methods/model/PhysFormer.py ADDED
@@ -0,0 +1,313 @@
1
+ """This file is a combination of Physformer.py and transformer_layer.py
2
+ in the official PhysFormer implementation here:
3
+ https://github.com/ZitongYu/PhysFormer
4
+
5
+ model.py - Model and module class for ViT.
6
+ They are built to mirror those in the official Jax implementation.
7
+ """
8
+
9
+ import numpy as np
10
+ from typing import Optional
11
+ import torch
12
+ from torch import nn
13
+ from torch import Tensor
14
+ from torch.nn import functional as F
15
+ import math
16
+
17
+ def as_tuple(x):
18
+ return x if isinstance(x, tuple) else (x, x)
19
+
20
+ '''
21
+ Temporal Center-difference based Convolutional layer (3D version)
22
+ theta: controls the ratio of the original convolution to the central-difference convolution
23
+ '''
24
+ class CDC_T(nn.Module):
25
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
26
+ padding=1, dilation=1, groups=1, bias=False, theta=0.6):
27
+
28
+ super(CDC_T, self).__init__()
29
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
30
+ dilation=dilation, groups=groups, bias=bias)
31
+ self.theta = theta
32
+
33
+ def forward(self, x):
34
+ out_normal = self.conv(x)
35
+
36
+ if math.fabs(self.theta - 0.0) < 1e-8:
37
+ return out_normal
38
+ else:
39
+ [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape
40
+
41
+ # only CD works on temporal kernel size>1
42
+ if self.conv.weight.shape[2] > 1:
43
+ kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum(
44
+ 2).sum(2)
45
+ kernel_diff = kernel_diff[:, :, None, None, None]
46
+ out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
47
+ padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
48
+ return out_normal - self.theta * out_diff
49
+
50
+ else:
51
+ return out_normal
52
+
53
+
54
+ def split_last(x, shape):
55
+ "split the last dimension to given shape"
56
+ shape = list(shape)
57
+ assert shape.count(-1) <= 1
58
+ if -1 in shape:
59
+ shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
60
+ return x.view(*x.size()[:-1], *shape)
61
+
62
+
63
+ def merge_last(x, n_dims):
64
+ "merge the last n_dims to a dimension"
65
+ s = x.size()
66
+ assert n_dims > 1 and n_dims < len(s)
67
+ return x.view(*s[:-n_dims], -1)
68
+
69
+ class MultiHeadedSelfAttention_TDC_gra_sharp(nn.Module):
70
+ """Multi-Headed Dot Product Attention with depth-wise Conv3d"""
71
+ def __init__(self, dim, num_heads, dropout, theta):
72
+ super().__init__()
73
+
74
+ self.proj_q = nn.Sequential(
75
+ CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta),
76
+ nn.BatchNorm3d(dim),
77
+ )
78
+ self.proj_k = nn.Sequential(
79
+ CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta),
80
+ nn.BatchNorm3d(dim),
81
+ )
82
+ self.proj_v = nn.Sequential(
83
+ nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False),
84
+ )
85
+
86
+ self.drop = nn.Dropout(dropout)
87
+ self.n_heads = num_heads
88
+ self.scores = None # for visualization
89
+
90
+ def forward(self, x, gra_sharp): # [B, 4*4*40, 128]
91
+ """
92
+ x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
93
+ mask : (B(batch_size) x S(seq_len))
94
+ * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
95
+ """
96
+ # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
97
+
98
+ [B, P, C]=x.shape
99
+ x = x.transpose(1, 2).view(B, C, P//16, 4, 4) # [B, dim, 40, 4, 4]
100
+ q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
101
+ q = q.flatten(2).transpose(1, 2) # [B, 4*4*40, dim]
102
+ k = k.flatten(2).transpose(1, 2) # [B, 4*4*40, dim]
103
+ v = v.flatten(2).transpose(1, 2) # [B, 4*4*40, dim]
104
+
105
+ q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
106
+ # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
107
+ scores = q @ k.transpose(-2, -1) / gra_sharp
108
+
109
+ scores = self.drop(F.softmax(scores, dim=-1))
110
+ # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
111
+ h = (scores @ v).transpose(1, 2).contiguous()
112
+ # -merge-> (B, S, D)
113
+ h = merge_last(h, 2)
114
+ self.scores = scores
115
+ return h, scores
116
+
117
+
118
+
119
+
120
+ class PositionWiseFeedForward_ST(nn.Module):
121
+ """FeedForward Neural Networks for each position"""
122
+ def __init__(self, dim, ff_dim):
123
+ super().__init__()
124
+
125
+ self.fc1 = nn.Sequential(
126
+ nn.Conv3d(dim, ff_dim, 1, stride=1, padding=0, bias=False),
127
+ nn.BatchNorm3d(ff_dim),
128
+ nn.ELU(),
129
+ )
130
+
131
+ self.STConv = nn.Sequential(
132
+ nn.Conv3d(ff_dim, ff_dim, 3, stride=1, padding=1, groups=ff_dim, bias=False),
133
+ nn.BatchNorm3d(ff_dim),
134
+ nn.ELU(),
135
+ )
136
+
137
+ self.fc2 = nn.Sequential(
138
+ nn.Conv3d(ff_dim, dim, 1, stride=1, padding=0, bias=False),
139
+ nn.BatchNorm3d(dim),
140
+ )
141
+
142
+ def forward(self, x): # [B, 4*4*40, 128]
143
+ [B, P, C]=x.shape
144
+ x = x.transpose(1, 2).view(B, C, P//16, 4, 4) # [B, dim, 40, 4, 4]
145
+ x = self.fc1(x) # x [B, ff_dim, 40, 4, 4]
146
+ x = self.STConv(x) # x [B, ff_dim, 40, 4, 4]
147
+ x = self.fc2(x) # x [B, dim, 40, 4, 4]
148
+ x = x.flatten(2).transpose(1, 2) # [B, 4*4*40, dim]
149
+
150
+ return x
151
+
152
+ class Block_ST_TDC_gra_sharp(nn.Module):
153
+ """Transformer Block"""
154
+ def __init__(self, dim, num_heads, ff_dim, dropout, theta):
155
+ super().__init__()
156
+ self.attn = MultiHeadedSelfAttention_TDC_gra_sharp(dim, num_heads, dropout, theta)
157
+ self.proj = nn.Linear(dim, dim)
158
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
159
+ self.pwff = PositionWiseFeedForward_ST(dim, ff_dim)
160
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
161
+ self.drop = nn.Dropout(dropout)
162
+
163
+ def forward(self, x, gra_sharp):
164
+ Atten, Score = self.attn(self.norm1(x), gra_sharp)
165
+ h = self.drop(self.proj(Atten))
166
+ x = x + h
167
+ h = self.drop(self.pwff(self.norm2(x)))
168
+ x = x + h
169
+ return x, Score
170
+
171
+ class Transformer_ST_TDC_gra_sharp(nn.Module):
172
+ """Transformer with Self-Attentive Blocks"""
173
+ def __init__(self, num_layers, dim, num_heads, ff_dim, dropout, theta):
174
+ super().__init__()
175
+ self.blocks = nn.ModuleList([
176
+ Block_ST_TDC_gra_sharp(dim, num_heads, ff_dim, dropout, theta) for _ in range(num_layers)])
177
+
178
+ def forward(self, x, gra_sharp):
179
+ for block in self.blocks:
180
+ x, Score = block(x, gra_sharp)
181
+ return x, Score
182
+
183
+ # stem_3DCNN + ST-ViT with local Depthwise Spatio-Temporal MLP
184
+ class ViT_ST_ST_Compact3_TDC_gra_sharp(nn.Module):
185
+
186
+ def __init__(
187
+ self,
188
+ name: Optional[str] = None,
189
+ pretrained: bool = False,
190
+ patches: int = 16,
191
+ dim: int = 768,
192
+ ff_dim: int = 3072,
193
+ num_heads: int = 12,
194
+ num_layers: int = 12,
195
+ attention_dropout_rate: float = 0.0,
196
+ dropout_rate: float = 0.2,
197
+ representation_size: Optional[int] = None,
198
+ load_repr_layer: bool = False,
199
+ classifier: str = 'token',
200
+ #positional_embedding: str = '1d',
201
+ in_channels: int = 3,
202
+ frame: int = 160,
203
+ theta: float = 0.2,
204
+ image_size: Optional[int] = None,
205
+ ):
206
+ super().__init__()
207
+
208
+
209
+ self.image_size = image_size
210
+ self.frame = frame
211
+ self.dim = dim
212
+
213
+ # Image and patch sizes
214
+ t, h, w = as_tuple(image_size) # tube sizes
215
+ ft, fh, fw = as_tuple(patches) # patch sizes, ft = 4 ==> 160/4=40
216
+ gt, gh, gw = t//ft, h // fh, w // fw # number of patches
217
+ seq_len = gh * gw * gt
218
+
219
+ # Patch embedding [4x16x16]conv
220
+ self.patch_embedding = nn.Conv3d(dim, dim, kernel_size=(ft, fh, fw), stride=(ft, fh, fw))
221
+
222
+ # Transformer
223
+ self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
224
+ ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
225
+ # Transformer
226
+ self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
227
+ ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
228
+ # Transformer
229
+ self.transformer3 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
230
+ ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
231
+
232
+
233
+
234
+ self.Stem0 = nn.Sequential(
235
+ nn.Conv3d(3, dim//4, [1, 5, 5], stride=1, padding=[0,2,2]),
236
+ nn.BatchNorm3d(dim//4),
237
+ nn.ReLU(inplace=True),
238
+ nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
239
+ )
240
+
241
+ self.Stem1 = nn.Sequential(
242
+ nn.Conv3d(dim//4, dim//2, [3, 3, 3], stride=1, padding=1),
243
+ nn.BatchNorm3d(dim//2),
244
+ nn.ReLU(inplace=True),
245
+ nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
246
+ )
247
+ self.Stem2 = nn.Sequential(
248
+ nn.Conv3d(dim//2, dim, [3, 3, 3], stride=1, padding=1),
249
+ nn.BatchNorm3d(dim),
250
+ nn.ReLU(inplace=True),
251
+ nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
252
+ )
253
+
254
+ self.upsample = nn.Sequential(
255
+ nn.Upsample(scale_factor=(2,1,1)),
256
+ nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1,0,0)),
257
+ nn.BatchNorm3d(dim),
258
+ nn.ELU(),
259
+ )
260
+ self.upsample2 = nn.Sequential(
261
+ nn.Upsample(scale_factor=(2,1,1)),
262
+ nn.Conv3d(dim, dim//2, [3, 1, 1], stride=1, padding=(1,0,0)),
263
+ nn.BatchNorm3d(dim//2),
264
+ nn.ELU(),
265
+ )
266
+
267
+ self.ConvBlockLast = nn.Conv1d(dim//2, 1, 1,stride=1, padding=0)
268
+
269
+
270
+ # Initialize weights
271
+ self.init_weights()
272
+
273
+ @torch.no_grad()
274
+ def init_weights(self):
275
+ def _init(m):
276
+ if isinstance(m, nn.Linear):
277
+ nn.init.xavier_uniform_(m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal
278
+ if hasattr(m, 'bias') and m.bias is not None:
279
+ nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0)
280
+ self.apply(_init)
281
+
282
+
283
+ def forward(self, x, gra_sharp):
284
+
285
+ # b is batch number, c channels, t frame, fh frame height, and fw frame width
286
+ b, c, t, fh, fw = x.shape
287
+
288
+ x = self.Stem0(x)
289
+ x = self.Stem1(x)
290
+ x = self.Stem2(x) # [B, 64, 160, 64, 64]
291
+
292
+ x = self.patch_embedding(x) # [B, 64, 40, 4, 4]
293
+ x = x.flatten(2).transpose(1, 2) # [B, 40*4*4, 64]
294
+
295
+
296
+ Trans_features, Score1 = self.transformer1(x, gra_sharp) # [B, 4*4*40, 64]
297
+ Trans_features2, Score2 = self.transformer2(Trans_features, gra_sharp) # [B, 4*4*40, 64]
298
+ Trans_features3, Score3 = self.transformer3(Trans_features2, gra_sharp) # [B, 4*4*40, 64]
299
+
300
+ # upsampling heads
301
+ #features_last = Trans_features3.transpose(1, 2).view(b, self.dim, 40, 4, 4) # [B, 64, 40, 4, 4]
302
+ features_last = Trans_features3.transpose(1, 2).view(b, self.dim, t//4, 4, 4) # [B, 64, 40, 4, 4]
303
+
304
+ features_last = self.upsample(features_last) # x [B, 64, 7*7, 80]
305
+ features_last = self.upsample2(features_last) # x [B, 32, 7*7, 160]
306
+
307
+ features_last = torch.mean(features_last,3) # x [B, 32, 160, 4]
308
+ features_last = torch.mean(features_last,3) # x [B, 32, 160]
309
+ rPPG = self.ConvBlockLast(features_last) # x [B, 1, 160]
310
+
311
+ rPPG = rPPG.squeeze(1)
312
+
313
+ return rPPG, Score1, Score2, Score3
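
A minimal construction sketch for this class; the hyperparameter values below (4x4x4 tube patches on 160-frame, 128x128 clips, dim=96, ff_dim=144, 4 heads, 12 layers, theta=0.7) and gra_sharp=2.0 are illustrative assumptions rather than values fixed by this file:

import torch
from neural_methods.model.PhysFormer import ViT_ST_ST_Compact3_TDC_gra_sharp

model = ViT_ST_ST_Compact3_TDC_gra_sharp(
    image_size=(160, 128, 128), patches=(4, 4, 4), dim=96, ff_dim=144,
    num_heads=4, num_layers=12, dropout_rate=0.1, theta=0.7)
model.eval()

x = torch.rand(1, 3, 160, 128, 128)   # [B, C, T, H, W]; 128x128 gives the 4x4 token grid assumed by the attention reshape
with torch.no_grad():
    rppg, score1, score2, score3 = model(x, gra_sharp=2.0)
print(rppg.shape)                       # expected: torch.Size([1, 160])
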
neural_methods/model/PhysMamba.py ADDED
@@ -0,0 +1,246 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from timm.models.layers import trunc_normal_, DropPath
5
+ from mamba_ssm import Mamba
6
+ from torch.nn import functional as F
7
+
8
+ class ChannelAttention3D(nn.Module):
9
+ def __init__(self, in_channels, reduction):
10
+ super(ChannelAttention3D, self).__init__()
11
+ self.avg_pool = nn.AdaptiveAvgPool3d(1)
12
+ self.max_pool = nn.AdaptiveMaxPool3d(1)
13
+
14
+ self.fc = nn.Sequential(
15
+ nn.Conv3d(in_channels, in_channels // reduction, 1, bias=False),
16
+ nn.ReLU(),
17
+ nn.Conv3d(in_channels // reduction, in_channels, 1, bias=False)
18
+ )
19
+ self.sigmoid = nn.Sigmoid()
20
+
21
+ def forward(self, x):
22
+ avg_out = self.fc(self.avg_pool(x))
23
+ max_out = self.fc(self.max_pool(x))
24
+ out = avg_out + max_out
25
+ attention = self.sigmoid(out)
26
+ return x*attention
27
+
28
+ class LateralConnection(nn.Module):
29
+ def __init__(self, fast_channels=32, slow_channels=64):
30
+ super(LateralConnection, self).__init__()
31
+ self.conv = nn.Sequential(
32
+ nn.Conv3d(fast_channels, slow_channels, [3, 1, 1], stride=[2, 1, 1], padding=[1,0,0]),
33
+ nn.BatchNorm3d(slow_channels),
34
+ nn.ReLU(),
35
+ )
36
+
37
+ def forward(self, slow_path, fast_path):
38
+ fast_path = self.conv(fast_path)
39
+ return fast_path + slow_path
40
+
41
+ class CDC_T(nn.Module):
42
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
43
+ padding=1, dilation=1, groups=1, bias=False, theta=0.2):
44
+
45
+ super(CDC_T, self).__init__()
46
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
47
+ dilation=dilation, groups=groups, bias=bias)
48
+ self.theta = theta
49
+
50
+ def forward(self, x):
51
+
52
+ out_normal = self.conv(x)
53
+
54
+ if math.fabs(self.theta - 0.0) < 1e-8:
55
+ return out_normal
56
+ else:
57
+ [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape
58
+
59
+ # only CD works on temporal kernel size>1
60
+ if self.conv.weight.shape[2] > 1:
61
+ kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum(
62
+ 2).sum(2)
63
+ kernel_diff = kernel_diff[:, :, None, None, None]
64
+ out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
65
+ padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
66
+ return out_normal - self.theta * out_diff
67
+
68
+ else:
69
+ return out_normal
70
+
71
+ class MambaLayer(nn.Module):
72
+ def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2, channel_token = False):
73
+ super(MambaLayer, self).__init__()
74
+ self.dim = dim
75
+ self.norm1 = nn.LayerNorm(dim)
76
+ self.norm2 = nn.LayerNorm(dim)
77
+ drop_path = 0
78
+ self.mamba = Mamba(
79
+ d_model=dim, # Model dimension d_model
80
+ d_state=d_state, # SSM state expansion factor
81
+ d_conv=d_conv, # Local convolution width
82
+ expand=expand, # Block expansion factor
83
+ bimamba=True,
84
+ )
85
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
86
+ self.apply(self._init_weights)
87
+
88
+ def _init_weights(self, m):
89
+ if isinstance(m, nn.Linear):
90
+ trunc_normal_(m.weight, std=.02)
91
+ if isinstance(m, nn.Linear) and m.bias is not None:
92
+ nn.init.constant_(m.bias, 0)
93
+ elif isinstance(m, nn.LayerNorm):
94
+ nn.init.constant_(m.bias, 0)
95
+ nn.init.constant_(m.weight, 1.0)
96
+ elif isinstance(m, nn.Conv2d):
97
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98
+ fan_out //= m.groups
99
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
100
+ if m.bias is not None:
101
+ m.bias.data.zero_()
102
+
103
+ def forward_patch_token(self, x):
104
+ B, C, nf, H, W = x.shape
105
+ B, d_model = x.shape[:2]
106
+ assert d_model == self.dim
107
+ n_tokens = x.shape[2:].numel()
108
+ img_dims = x.shape[2:]
109
+ x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2)
110
+ x_norm = self.norm1(x_flat)
111
+ x_mamba = self.mamba(x_norm)
112
+ x_out = self.norm2(x_flat + self.drop_path(x_mamba))
113
+ out = x_out.transpose(-1, -2).reshape(B, d_model, *img_dims)
114
+ return out
115
+
116
+ def forward(self, x):
117
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
118
+ x = x.type(torch.float32)
119
+ out = self.forward_patch_token(x)
120
+ return out
121
+
122
+ def conv_block(in_channels, out_channels, kernel_size, stride, padding, bn=True, activation='relu'):
123
+ layers = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)]
124
+ if bn:
125
+ layers.append(nn.BatchNorm3d(out_channels))
126
+ if activation == 'relu':
127
+ layers.append(nn.ReLU(inplace=True))
128
+ elif activation == 'elu':
129
+ layers.append(nn.ELU(inplace=True))
130
+ return nn.Sequential(*layers)
131
+
132
+
133
+ class PhysMamba(nn.Module):
134
+ def __init__(self, theta=0.5, drop_rate1=0.25, drop_rate2=0.5, frames=128):
135
+ super(PhysMamba, self).__init__()
136
+
137
+ self.ConvBlock1 = conv_block(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2])
138
+ self.ConvBlock2 = conv_block(16, 32, [3, 3, 3], stride=1, padding=1)
139
+ self.ConvBlock3 = conv_block(32, 64, [3, 3, 3], stride=1, padding=1)
140
+ self.ConvBlock4 = conv_block(64, 64, [4, 1, 1], stride=[4, 1, 1], padding=0)
141
+ self.ConvBlock5 = conv_block(64, 32, [2, 1, 1], stride=[2, 1, 1], padding=0)
142
+ self.ConvBlock6 = conv_block(32, 32, [3, 1, 1], stride=1, padding=[1, 0, 0], activation='elu')
143
+
144
+ # Temporal Difference Mamba Blocks
145
+ # Slow Stream
146
+ self.Block1 = self._build_block(64, theta)
147
+ self.Block2 = self._build_block(64, theta)
148
+ self.Block3 = self._build_block(64, theta)
149
+ # Fast Stream
150
+ self.Block4 = self._build_block(32, theta)
151
+ self.Block5 = self._build_block(32, theta)
152
+ self.Block6 = self._build_block(32, theta)
153
+
154
+ # Upsampling
155
+ self.upsample1 = nn.Sequential(
156
+ nn.Upsample(scale_factor=(2,1,1)),
157
+ nn.Conv3d(64, 64, [3, 1, 1], stride=1, padding=(1,0,0)),
158
+ nn.BatchNorm3d(64),
159
+ nn.ELU(),
160
+ )
161
+ self.upsample2 = nn.Sequential(
162
+ nn.Upsample(scale_factor=(2,1,1)),
163
+ nn.Conv3d(96, 48, [3, 1, 1], stride=1, padding=(1,0,0)),
164
+ nn.BatchNorm3d(48),
165
+ nn.ELU(),
166
+ )
167
+
168
+ self.ConvBlockLast = nn.Conv3d(48, 1, [1, 1, 1], stride=1, padding=0)
169
+ self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
170
+ self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2)
171
+
172
+ self.fuse_1 = LateralConnection(fast_channels=32, slow_channels=64)
173
+ self.fuse_2 = LateralConnection(fast_channels=32, slow_channels=64)
174
+
175
+ self.drop_1 = nn.Dropout(drop_rate1)
176
+ self.drop_2 = nn.Dropout(drop_rate1)
177
+ self.drop_3 = nn.Dropout(drop_rate2)
178
+ self.drop_4 = nn.Dropout(drop_rate2)
179
+ self.drop_5 = nn.Dropout(drop_rate2)
180
+ self.drop_6 = nn.Dropout(drop_rate2)
181
+
182
+ self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))
183
+
184
+ def _build_block(self, channels, theta):
185
+ return nn.Sequential(
186
+ CDC_T(channels, channels, theta=theta),
187
+ nn.BatchNorm3d(channels),
188
+ nn.ReLU(),
189
+ MambaLayer(dim=channels),
190
+ ChannelAttention3D(in_channels=channels, reduction=2),
191
+ )
192
+
193
+ def forward(self, x):
194
+ [batch, channel, length, width, height] = x.shape
195
+
196
+ x = self.ConvBlock1(x)
197
+ x = self.MaxpoolSpa(x)
198
+ x = self.ConvBlock2(x)
199
+ x = self.ConvBlock3(x)
200
+ x = self.MaxpoolSpa(x)
201
+
202
+ # Process streams
203
+ s_x = self.ConvBlock4(x) # Slow stream
204
+ f_x = self.ConvBlock5(x) # Fast stream
205
+
206
+ # First set of blocks and fusion
207
+ s_x1 = self.Block1(s_x)
208
+ s_x1 = self.MaxpoolSpa(s_x1)
209
+ s_x1 = self.drop_1(s_x1)
210
+
211
+ f_x1 = self.Block4(f_x)
212
+ f_x1 = self.MaxpoolSpa(f_x1)
213
+ f_x1 = self.drop_2(f_x1)
214
+
215
+ s_x1 = self.fuse_1(s_x1,f_x1) # LateralConnection
216
+
217
+ # Second set of blocks and fusion
218
+ s_x2 = self.Block2(s_x1)
219
+ s_x2 = self.MaxpoolSpa(s_x2)
220
+ s_x2 = self.drop_3(s_x2)
221
+
222
+ f_x2 = self.Block5(f_x1)
223
+ f_x2 = self.MaxpoolSpa(f_x2)
224
+ f_x2 = self.drop_4(f_x2)
225
+
226
+ s_x2 = self.fuse_2(s_x2,f_x2) # LateralConnection
227
+
228
+ # Third blocks and upsampling
229
+ s_x3 = self.Block3(s_x2)
230
+ s_x3 = self.upsample1(s_x3)
231
+ s_x3 = self.drop_5(s_x3)
232
+
233
+ f_x3 = self.Block6(f_x2)
234
+ f_x3 = self.ConvBlock6(f_x3)
235
+ f_x3 = self.drop_6(f_x3)
236
+
237
+ # Final fusion and upsampling
238
+ x_fusion = torch.cat((f_x3, s_x3), dim=1)
239
+ x_final = self.upsample2(x_fusion)
240
+
241
+ x_final = self.poolspa(x_final)
242
+ x_final = self.ConvBlockLast(x_final)
243
+
244
+ rPPG = x_final.view(-1, length)
245
+
246
+ return rPPG
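
A usage sketch, not part of the committed file: it shows the expected input/output shapes of the two-stream PhysMamba model above. The class name, the `frames` keyword, and the 128x128 spatial resolution are assumptions read off this diff (the file name, the `frames` argument used by `poolspa`, and the toolbox's usual preprocessing); the bundled `mamba_ssm` selective-scan kernels generally require a CUDA device.

    # Hedged sketch -- class name and constructor kwarg are assumptions, see note above.
    import torch
    from neural_methods.model.PhysMamba import PhysMamba

    model = PhysMamba(frames=128).cuda().eval()    # frames must equal the clip length T
    clip = torch.rand(1, 3, 128, 128, 128).cuda()  # [B, C=3, T=128, H=128, W=128]
    with torch.no_grad():
        rppg = model(clip)                         # [B, T] rPPG waveform
    print(rppg.shape)                              # expected: torch.Size([1, 128])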
neural_methods/model/PhysNet.py ADDED
@@ -0,0 +1,124 @@
1
+ """ PhysNet
2
+ We replicate the network pipeline of the original paper, but set the input to diff-normalized data.
3
+ original source:
4
+ Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks
5
+ British Machine Vision Conference (BMVC), 2019
6
+ By Zitong Yu, 2019/05/05
7
+ Only for research purpose, and commercial use is not allowed.
8
+ MIT License
9
+ Copyright (c) 2019
10
+ """
11
+
12
+ import math
13
+ import pdb
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.nn.modules.utils import _triple
18
+
19
+
20
+ class PhysNet_padding_Encoder_Decoder_MAX(nn.Module):
21
+ def __init__(self, frames=128):
22
+ super(PhysNet_padding_Encoder_Decoder_MAX, self).__init__()
23
+
24
+ self.ConvBlock1 = nn.Sequential(
25
+ nn.Conv3d(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2]),
26
+ nn.BatchNorm3d(16),
27
+ nn.ReLU(inplace=True),
28
+ )
29
+
30
+ self.ConvBlock2 = nn.Sequential(
31
+ nn.Conv3d(16, 32, [3, 3, 3], stride=1, padding=1),
32
+ nn.BatchNorm3d(32),
33
+ nn.ReLU(inplace=True),
34
+ )
35
+ self.ConvBlock3 = nn.Sequential(
36
+ nn.Conv3d(32, 64, [3, 3, 3], stride=1, padding=1),
37
+ nn.BatchNorm3d(64),
38
+ nn.ReLU(inplace=True),
39
+ )
40
+
41
+ self.ConvBlock4 = nn.Sequential(
42
+ nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
43
+ nn.BatchNorm3d(64),
44
+ nn.ReLU(inplace=True),
45
+ )
46
+ self.ConvBlock5 = nn.Sequential(
47
+ nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
48
+ nn.BatchNorm3d(64),
49
+ nn.ReLU(inplace=True),
50
+ )
51
+ self.ConvBlock6 = nn.Sequential(
52
+ nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
53
+ nn.BatchNorm3d(64),
54
+ nn.ReLU(inplace=True),
55
+ )
56
+ self.ConvBlock7 = nn.Sequential(
57
+ nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
58
+ nn.BatchNorm3d(64),
59
+ nn.ReLU(inplace=True),
60
+ )
61
+ self.ConvBlock8 = nn.Sequential(
62
+ nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
63
+ nn.BatchNorm3d(64),
64
+ nn.ReLU(inplace=True),
65
+ )
66
+ self.ConvBlock9 = nn.Sequential(
67
+ nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
68
+ nn.BatchNorm3d(64),
69
+ nn.ReLU(inplace=True),
70
+ )
71
+
72
+ self.upsample = nn.Sequential(
73
+ nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[
74
+ 4, 1, 1], stride=[2, 1, 1], padding=[1, 0, 0]), # [1, 128, 32]
75
+ nn.BatchNorm3d(64),
76
+ nn.ELU(),
77
+ )
78
+ self.upsample2 = nn.Sequential(
79
+ nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[
80
+ 4, 1, 1], stride=[2, 1, 1], padding=[1, 0, 0]), # [1, 128, 32]
81
+ nn.BatchNorm3d(64),
82
+ nn.ELU(),
83
+ )
84
+
85
+ self.ConvBlock10 = nn.Conv3d(64, 1, [1, 1, 1], stride=1, padding=0)
86
+
87
+ self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
88
+ self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2)
89
+
90
+ # self.poolspa = nn.AdaptiveMaxPool3d((frames,1,1)) # pool only spatial space
91
+ self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))
92
+
93
+ def forward(self, x): # Batch_size*[3, T, 128,128]
94
+ x_visual = x
95
+ [batch, channel, length, width, height] = x.shape
96
+
97
+ x = self.ConvBlock1(x) # x [3, T, 128,128]
98
+ x = self.MaxpoolSpa(x) # x [16, T, 64,64]
99
+
100
+ x = self.ConvBlock2(x) # x [32, T, 64,64]
101
+ x_visual6464 = self.ConvBlock3(x) # x [32, T, 64,64]
102
+ # x [32, T/2, 32,32] Temporal halve
103
+ x = self.MaxpoolSpaTem(x_visual6464)
104
+
105
+ x = self.ConvBlock4(x) # x [64, T/2, 32,32]
106
+ x_visual3232 = self.ConvBlock5(x) # x [64, T/2, 32,32]
107
+ x = self.MaxpoolSpaTem(x_visual3232) # x [64, T/4, 16,16]
108
+
109
+ x = self.ConvBlock6(x) # x [64, T/4, 16,16]
110
+ x_visual1616 = self.ConvBlock7(x) # x [64, T/4, 16,16]
111
+ x = self.MaxpoolSpa(x_visual1616) # x [64, T/4, 8,8]
112
+
113
+ x = self.ConvBlock8(x) # x [64, T/4, 8, 8]
114
+ x = self.ConvBlock9(x) # x [64, T/4, 8, 8]
115
+ x = self.upsample(x) # x [64, T/2, 8, 8]
116
+ x = self.upsample2(x) # x [64, T, 8, 8]
117
+
118
+ # x [64, T, 1,1] --> groundtruth left and right - 7
119
+ x = self.poolspa(x)
120
+ x = self.ConvBlock10(x) # x [1, T, 1,1]
121
+
122
+ rPPG = x.view(-1, length)
123
+
124
+ return rPPG, x_visual, x_visual3232, x_visual1616
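
A quick shape check for the PhysNet port above (not part of the commit). The encoder-decoder maps a [B, 3, T, H, W] clip to a [B, T] waveform; T must equal the `frames` argument, while the spatial size is flexible because of the adaptive average pooling. This runs on CPU with plain PyTorch.

    import torch
    from neural_methods.model.PhysNet import PhysNet_padding_Encoder_Decoder_MAX

    net = PhysNet_padding_Encoder_Decoder_MAX(frames=128).eval()
    clip = torch.rand(1, 3, 128, 72, 72)       # [B, C, T, H, W]
    with torch.no_grad():
        rppg, x_vis, x_32, x_16 = net(clip)
    print(rppg.shape)                          # torch.Size([1, 128])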
neural_methods/model/RhythmFormer.py ADDED
@@ -0,0 +1,418 @@
1
+ """
2
+ RhythmFormer: Extracting rPPG Signals Based on Hierarchical Temporal Periodic Transformer
3
+ """
4
+ from typing import Optional
5
+ import torch
6
+ from torch import nn, Tensor, LongTensor
7
+ from torch.nn import functional as F
8
+ import math
9
+ from typing import Tuple, Union
10
+ from timm.models.layers import trunc_normal_, DropPath
11
+
12
+
13
+
14
+ """
15
+ Adapted from here: https://github.com/rayleizhu/BiFormer
16
+ """
17
+ import torch
18
+ from torch import Tensor, LongTensor , nn
19
+ import torch.nn.functional as F
20
+ from typing import Optional, Tuple
21
+
22
+ def _grid2seq(x:Tensor, region_size:Tuple[int], num_heads:int):
23
+ """
24
+ Args:
25
+ x: BCTHW tensor
26
+ region size: int
27
+ num_heads: number of attention heads
28
+ Return:
29
+ out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
30
+ region_t, region_h, region_w: number of regions per t/col/row
31
+ """
32
+ B, C, T, H, W = x.size()
33
+ region_t ,region_h, region_w = T//region_size[0], H//region_size[1], W//region_size[2]
34
+ x = x.view(B, num_heads, C//num_heads, region_t, region_size[0],region_h, region_size[1], region_w, region_size[2])
35
+ x = torch.einsum('bmdtohpwq->bmthwopqd', x).flatten(2, 4).flatten(-4, -2) # (bs, nhead, nregion, reg_size, head_dim)
36
+ return x, region_t, region_h, region_w
37
+
38
+
39
+ def _seq2grid(x:Tensor, region_t:int, region_h:int, region_w:int, region_size:Tuple[int]):
40
+ """
41
+ Args:
42
+ x: (bs, nhead, nregion, reg_size^2, head_dim)
43
+ Return:
44
+ x: (bs, C, T, H, W)
45
+ """
46
+ bs, nhead, nregion, reg_size_square, head_dim = x.size()
47
+ x = x.view(bs, nhead, region_t, region_h, region_w, region_size[0], region_size[1], region_size[2], head_dim)
48
+ x = torch.einsum('bmthwopqd->bmdtohpwq', x).reshape(bs, nhead*head_dim,
49
+ region_t*region_size[0],region_h*region_size[1], region_w*region_size[2])
50
+ return x
51
+
52
+
53
+ def video_regional_routing_attention_torch(
54
+ query:Tensor, key:Tensor, value:Tensor, scale:float,
55
+ region_graph:LongTensor, region_size:Tuple[int],
56
+ kv_region_size:Optional[Tuple[int]]=None,
57
+ auto_pad=False)->Tensor:
58
+ """
59
+ Args:
60
+ query, key, value: (B, C, T, H, W) tensor
61
+ scale: the scale/temperature for dot product attention
62
+ region_graph: (B, nhead, t_q*h_q*w_q, topk) tensor, topk <= t_k*h_k*w_k
63
+ region_size: region/window size for queries, (rt, rh, rw)
64
+ key_region_size: optional, if None, key_region_size=region_size
65
+ Return:
66
+ output: (B, C, T, H, W) tensor
67
+ attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_size) attention matrix
68
+ """
69
+ kv_region_size = kv_region_size or region_size
70
+ bs, nhead, q_nregion, topk = region_graph.size()
71
+
72
+ # # Auto pad to deal with any input size
73
+ # q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
74
+ # if auto_pad:
75
+ # _, _, Hq, Wq = query.size()
76
+ # q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
77
+ # q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
78
+ # if (q_pad_b > 0 or q_pad_r > 0):
79
+ # query = F.pad(query, (0, q_pad_r, 0, q_pad_b)) # zero padding
80
+
81
+ # _, _, Hk, Wk = key.size()
82
+ # kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
83
+ # kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
84
+ # if (kv_pad_r > 0 or kv_pad_b > 0):
85
+ # key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b)) # zero padding
86
+ # value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b)) # zero padding
87
+
88
+ # to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
89
+ query, q_region_t, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
90
+ key, _, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
91
+ value, _, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)
92
+
93
+ # gather key and values.
94
+ # torch.gather does not support broadcasting, hence we do it manually
95
+ bs, nhead, kv_nregion, kv_region_size, head_dim = key.size()
96
+ broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1).\
97
+ expand(-1, -1, -1, -1, kv_region_size, head_dim)
98
+ key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\
99
+ expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
100
+ index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
101
+ value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\
102
+ expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
103
+ index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
104
+
105
+ # token-to-token attention
106
+ # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_size)
107
+ # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_size)
108
+ attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
109
+ attn = torch.softmax(attn, dim=-1)
110
+ # (bs, nhead, q_nregion, reg_size, topk*kv_region_size) @ (bs, nhead, q_nregion, topk*kv_region_size, head_dim)
111
+ # -> (bs, nhead, q_nregion, reg_size, head_dim)
112
+ output = attn @ value_g.flatten(-3, -2)
113
+
114
+ # to BCTHW format
115
+ output = _seq2grid(output, region_t=q_region_t, region_h=q_region_h, region_w=q_region_w, region_size=region_size)
116
+
117
+ # remove paddings if needed
118
+ # if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
119
+ # output = output[:, :, :Hq, :Wq]
120
+
121
+ return output, attn
122
+
123
+
124
+
125
+
126
+ class CDC_T(nn.Module):
127
+ """
128
+ The CDC_T Module is from here: https://github.com/ZitongYu/PhysFormer/model/transformer_layer.py
129
+ """
130
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
131
+ padding=1, dilation=1, groups=1, bias=False, theta=0.6):
132
+
133
+ super(CDC_T, self).__init__()
134
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
135
+ dilation=dilation, groups=groups, bias=bias)
136
+ self.theta = theta
137
+
138
+ def forward(self, x):
139
+ out_normal = self.conv(x)
140
+
141
+ if math.fabs(self.theta - 0.0) < 1e-8:
142
+ return out_normal
143
+ else:
144
+ # pdb.set_trace()
145
+ [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape
146
+
147
+ # only CD works on temporal kernel size>1
148
+ if self.conv.weight.shape[2] > 1:
149
+ kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum(
150
+ 2).sum(2)
151
+ kernel_diff = kernel_diff[:, :, None, None, None]
152
+ out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
153
+ padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
154
+ return out_normal - self.theta * out_diff
155
+
156
+ else:
157
+ return out_normal
158
+
159
+ class video_BRA(nn.Module):
160
+
161
+ def __init__(self, dim, num_heads=8, t_patch=8, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False, attn_backend='torch'):
162
+ super().__init__()
163
+
164
+ self.dim = dim
165
+ self.num_heads = num_heads
166
+ assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
167
+ self.head_dim = self.dim // self.num_heads
168
+ self.scale = qk_scale or self.dim ** -0.5
169
+ self.topk = topk
170
+ self.t_patch = t_patch # frame of patch
171
+ ################side_dwconv (i.e. LCE in Shunted Transformer)###########
172
+ self.lepe = nn.Conv3d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
173
+ lambda x: torch.zeros_like(x)
174
+ ##########################################
175
+ self.qkv_linear = nn.Conv3d(self.dim, 3*self.dim, kernel_size=1)
176
+ self.output_linear = nn.Conv3d(self.dim, self.dim, kernel_size=1)
177
+ self.proj_q = nn.Sequential(
178
+ CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2),
179
+ nn.BatchNorm3d(dim),
180
+ )
181
+ self.proj_k = nn.Sequential(
182
+ CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2),
183
+ nn.BatchNorm3d(dim),
184
+ )
185
+ self.proj_v = nn.Sequential(
186
+ nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False),
187
+ )
188
+ if attn_backend == 'torch':
189
+ self.attn_fn = video_regional_routing_attention_torch
190
+ else:
191
+ raise ValueError('CUDA implementation is not available yet. Please stay tuned.')
192
+
193
+ def forward(self, x:Tensor):
194
+
195
+ N, C, T, H, W = x.size()
196
+ t_region = max(4 // self.t_patch , 1)
197
+ region_size = (t_region, H//4 , W//4)
198
+
199
+ # STEP 1: linear projection
200
+ q , k , v = self.proj_q(x) , self.proj_k(x) ,self.proj_v(x)
201
+
202
+ # STEP 2: pre attention
203
+ q_r = F.avg_pool3d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
204
+ k_r = F.avg_pool3d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False) # ncthw
205
+ q_r:Tensor = q_r.permute(0, 2, 3, 4, 1).flatten(1, 3) # n(thw)c
206
+ k_r:Tensor = k_r.flatten(2, 4) # nc(thw)
207
+ a_r = q_r @ k_r # n(thw)(thw)
208
+ _, idx_r = torch.topk(a_r, k=self.topk, dim=-1) # n(thw)k
209
+ idx_r:LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)
210
+
211
+ # STEP 3: refined attention
212
+ output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
213
+ region_graph=idx_r, region_size=region_size)
214
+
215
+ output = output + self.lepe(v) # nctHW
216
+ output = self.output_linear(output) # nctHW
217
+
218
+ return output
219
+
220
+ class video_BiFormerBlock(nn.Module):
221
+ def __init__(self, dim, drop_path=0., num_heads=4, t_patch=1,qk_scale=None, topk=4, mlp_ratio=2, side_dwconv=5):
222
+ super().__init__()
223
+ self.t_patch = t_patch
224
+ self.norm1 = nn.BatchNorm3d(dim)
225
+ self.attn = video_BRA(dim=dim, num_heads=num_heads, t_patch=t_patch,qk_scale=qk_scale, topk=topk, side_dwconv=side_dwconv)
226
+ self.norm2 = nn.BatchNorm3d(dim)
227
+ self.mlp = nn.Sequential(nn.Conv3d(dim, int(mlp_ratio*dim), kernel_size=1),
228
+ nn.BatchNorm3d(int(mlp_ratio*dim)),
229
+ nn.GELU(),
230
+ nn.Conv3d(int(mlp_ratio*dim), int(mlp_ratio*dim), 3, stride=1, padding=1),
231
+ nn.BatchNorm3d(int(mlp_ratio*dim)),
232
+ nn.GELU(),
233
+ nn.Conv3d(int(mlp_ratio*dim), dim, kernel_size=1),
234
+ nn.BatchNorm3d(dim),
235
+ )
236
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
237
+
238
+ def forward(self, x):
239
+ x = x + self.drop_path(self.attn(self.norm1(x)))
240
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
241
+ return x
242
+
243
+ class Fusion_Stem(nn.Module):
244
+ def __init__(self,apha=0.5,belta=0.5):
245
+ super(Fusion_Stem, self).__init__()
246
+
247
+ self.stem11 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
248
+ nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
249
+ nn.ReLU(inplace=True),
250
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
251
+ )
252
+
253
+ self.stem12 = nn.Sequential(nn.Conv2d(12, 64, kernel_size=5, stride=2, padding=2),
254
+ nn.BatchNorm2d(64),
255
+ nn.ReLU(inplace=True),
256
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
257
+ )
258
+
259
+ self.stem21 =nn.Sequential(
260
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
261
+ nn.BatchNorm2d(64),
262
+ nn.ReLU(inplace=True),
263
+ )
264
+
265
+ self.stem22 =nn.Sequential(
266
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
267
+ nn.BatchNorm2d(64),
268
+ nn.ReLU(inplace=True),
269
+ )
270
+
271
+ self.apha = apha
272
+ self.belta = belta
273
+
274
+ def forward(self, x):
275
+ """Definition of Fusion_Stem.
276
+ Args:
277
+ x [N,D,C,H,W]
278
+ Returns:
279
+ fusion_x [N*D,C,H/4,W/4]
280
+ """
281
+ N, D, C, H, W = x.shape
282
+ x1 = torch.cat([x[:,:1,:,:,:],x[:,:1,:,:,:],x[:,:D-2,:,:,:]],1)
283
+ x2 = torch.cat([x[:,:1,:,:,:],x[:,:D-1,:,:,:]],1)
284
+ x3 = x
285
+ x4 = torch.cat([x[:,1:,:,:,:],x[:,D-1:,:,:,:]],1)
286
+ x5 = torch.cat([x[:,2:,:,:,:],x[:,D-1:,:,:,:],x[:,D-1:,:,:,:]],1)
287
+ x_diff = self.stem12(torch.cat([x2-x1,x3-x2,x4-x3,x5-x4],2).view(N * D, 12, H, W))
288
+ x3 = x3.contiguous().view(N * D, C, H, W)
289
+ x = self.stem11(x3)
290
+
291
+ #fusion layer1
292
+ x_path1 = self.apha*x + self.belta*x_diff
293
+ x_path1 = self.stem21(x_path1)
294
+ #fusion layer2
295
+ x_path2 = self.stem22(x_diff)
296
+ x = self.apha*x_path1 + self.belta*x_path2
297
+
298
+ return x
299
+
300
+ class TPT_Block(nn.Module):
301
+ def __init__(self, dim, depth, num_heads, t_patch, topk,
302
+ mlp_ratio=4., drop_path=0., side_dwconv=5):
303
+ super().__init__()
304
+ self.dim = dim
305
+ self.depth = depth
306
+ ############ downsample layers & upsample layers #####################
307
+ self.downsample_layers = nn.ModuleList()
308
+ self.upsample_layers = nn.ModuleList()
309
+ self.layer_n = int(math.log(t_patch,2))
310
+ for i in range(self.layer_n):
311
+ downsample_layer = nn.Sequential(
312
+ nn.BatchNorm3d(dim),
313
+ nn.Conv3d(dim , dim , kernel_size=(2, 1, 1), stride=(2, 1, 1)),
314
+ )
315
+ self.downsample_layers.append(downsample_layer)
316
+ upsample_layer = nn.Sequential(
317
+ nn.Upsample(scale_factor=(2, 1, 1)),
318
+ nn.Conv3d(dim , dim , [3, 1, 1], stride=1, padding=(1, 0, 0)),
319
+ nn.BatchNorm3d(dim),
320
+ nn.ELU(),
321
+ )
322
+ self.upsample_layers.append(upsample_layer)
323
+ ######################################################################
324
+ self.blocks = nn.ModuleList([
325
+ video_BiFormerBlock(
326
+ dim=dim,
327
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
328
+ num_heads=num_heads,
329
+ t_patch=t_patch,
330
+ topk=topk,
331
+ mlp_ratio=mlp_ratio,
332
+ side_dwconv=side_dwconv,
333
+ )
334
+ for i in range(depth)
335
+ ])
336
+ def forward(self, x:torch.Tensor):
337
+ """Definition of TPT_Block.
338
+ Args:
339
+ x [N,C,D,H,W]
340
+ Returns:
341
+ x [N,C,D,H,W]
342
+ """
343
+ for i in range(self.layer_n) :
344
+ x = self.downsample_layers[i](x)
345
+ for blk in self.blocks:
346
+ x = blk(x)
347
+ for i in range(self.layer_n) :
348
+ x = self.upsample_layers[i](x)
349
+
350
+ return x
351
+
352
+ class RhythmFormer(nn.Module):
353
+
354
+ def __init__(
355
+ self,
356
+ name: Optional[str] = None,
357
+ pretrained: bool = False,
358
+ dim: int = 64, frame: int = 160,
359
+ image_size: Optional[int] = (160,128,128),
360
+ in_chans=64, head_dim=16,
361
+ stage_n = 3,
362
+ embed_dim=[64, 64, 64], mlp_ratios=[1.5, 1.5, 1.5],
363
+ depth=[2, 2, 2],
364
+ t_patchs:Union[int, Tuple[int]]=(2, 4, 8),
365
+ topks:Union[int, Tuple[int]]=(40, 40, 40),
366
+ side_dwconv:int=3,
367
+ drop_path_rate=0.,
368
+ use_checkpoint_stages=[],
369
+ ):
370
+ super().__init__()
371
+
372
+ self.image_size = image_size
373
+ self.frame = frame
374
+ self.dim = dim
375
+ self.stage_n = stage_n
376
+
377
+ self.Fusion_Stem = Fusion_Stem()
378
+ self.patch_embedding = nn.Conv3d(in_chans,embed_dim[0], kernel_size=(1, 4, 4), stride=(1, 4, 4))
379
+ self.ConvBlockLast = nn.Conv1d(embed_dim[-1], 1, kernel_size=1,stride=1, padding=0)
380
+
381
+ ##########################################################################
382
+ self.stages = nn.ModuleList()
383
+ nheads= [dim // head_dim for dim in embed_dim]
384
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
385
+ for i in range(stage_n):
386
+ stage = TPT_Block(dim=embed_dim[i],
387
+ depth=depth[i],
388
+ num_heads=nheads[i],
389
+ mlp_ratio=mlp_ratios[i],
390
+ drop_path=dp_rates[sum(depth[:i]):sum(depth[:i+1])],
391
+ t_patch=t_patchs[i], topk=topks[i], side_dwconv=side_dwconv
392
+ )
393
+ self.stages.append(stage)
394
+ ##########################################################################
395
+
396
+ self.apply(self._init_weights)
397
+
398
+ def _init_weights(self, m):
399
+ if isinstance(m, nn.Linear):
400
+ trunc_normal_(m.weight, std=.02)
401
+ if isinstance(m, nn.Linear) and m.bias is not None:
402
+ nn.init.constant_(m.bias, 0)
403
+ elif isinstance(m, nn.LayerNorm):
404
+ nn.init.constant_(m.bias, 0)
405
+ nn.init.constant_(m.weight, 1.0)
406
+
407
+ def forward(self, x):
408
+ N, D, C, H, W = x.shape
409
+ x = self.Fusion_Stem(x) #[N*D 64 H/4 W/4]
410
+ x = x.view(N,D,64,H//4,W//4).permute(0,2,1,3,4)
411
+ x = self.patch_embedding(x) #[N 64 D 8 8]
412
+ for i in range(3):
413
+ x = self.stages[i](x) #[N 64 D 8 8]
414
+ features_last = torch.mean(x,3) #[N, 64, D, 8]
415
+ features_last = torch.mean(features_last,3) #[N, 64, D]
416
+ rPPG = self.ConvBlockLast(features_last) #[N, 1, D]
417
+ rPPG = rPPG.squeeze(1)
418
+ return rPPG
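
A usage sketch for RhythmFormer (not part of the commit). Unlike the 3D-CNN models, it expects frame-first clips of shape [N, D, C, H, W]; with the defaults above (frame=160, 128x128 crops) D should be divisible by the largest temporal patch size (8) and H, W by 16, which the defaults satisfy. `timm` is required for DropPath/trunc_normal_; CPU inference works but is slow.

    import torch
    from neural_methods.model.RhythmFormer import RhythmFormer

    model = RhythmFormer().eval()
    clip = torch.rand(1, 160, 3, 128, 128)     # [N, D, C, H, W]
    with torch.no_grad():
        rppg = model(clip)                     # [N, D]
    print(rppg.shape)                          # torch.Size([1, 160])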
neural_methods/model/TS_CAN.py ADDED
@@ -0,0 +1,269 @@
1
+ """Temporal Shift Convolutional Attention Network (TS-CAN).
2
+ Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement
3
+ NeurIPS, 2020
4
+ Xin Liu, Josh Fromm, Shwetak Patel, Daniel McDuff
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class Attention_mask(nn.Module):
12
+ def __init__(self):
13
+ super(Attention_mask, self).__init__()
14
+
15
+ def forward(self, x):
16
+ xsum = torch.sum(x, dim=2, keepdim=True)
17
+ xsum = torch.sum(xsum, dim=3, keepdim=True)
18
+ xshape = tuple(x.size())
19
+ return x / xsum * xshape[2] * xshape[3] * 0.5
20
+
21
+ def get_config(self):
22
+ """May be generated manually. """
23
+ config = super(Attention_mask, self).get_config()
24
+ return config
25
+
26
+
27
+ class TSM(nn.Module):
28
+ def __init__(self, n_segment=10, fold_div=3):
29
+ super(TSM, self).__init__()
30
+ self.n_segment = n_segment
31
+ self.fold_div = fold_div
32
+
33
+ def forward(self, x):
34
+ nt, c, h, w = x.size()
35
+ n_batch = nt // self.n_segment
36
+ x = x.view(n_batch, self.n_segment, c, h, w)
37
+ fold = c // self.fold_div
38
+ out = torch.zeros_like(x)
39
+ out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
40
+ out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
41
+ out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
42
+ return out.view(nt, c, h, w)
43
+
44
+
45
+ class TSCAN(nn.Module):
46
+
47
+ def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
48
+ dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36):
49
+ """Definition of TS_CAN.
50
+ Args:
51
+ in_channels: the number of input channels. Default: 3
52
+ frame_depth: the number of frames (window size) used in the temporal shift. Default: 20
53
+ img_size: height/width of each frame. Default: 36.
54
+ Returns:
55
+ TS_CAN model.
56
+ """
57
+ super(TSCAN, self).__init__()
58
+ self.in_channels = in_channels
59
+ self.kernel_size = kernel_size
60
+ self.dropout_rate1 = dropout_rate1
61
+ self.dropout_rate2 = dropout_rate2
62
+ self.pool_size = pool_size
63
+ self.nb_filters1 = nb_filters1
64
+ self.nb_filters2 = nb_filters2
65
+ self.nb_dense = nb_dense
66
+ # TSM layers
67
+ self.TSM_1 = TSM(n_segment=frame_depth)
68
+ self.TSM_2 = TSM(n_segment=frame_depth)
69
+ self.TSM_3 = TSM(n_segment=frame_depth)
70
+ self.TSM_4 = TSM(n_segment=frame_depth)
71
+ # Motion branch convs
72
+ self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
73
+ bias=True)
74
+ self.motion_conv2 = nn.Conv2d(
75
+ self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
76
+ self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
77
+ bias=True)
78
+ self.motion_conv4 = nn.Conv2d(
79
+ self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
80
+ # Apperance branch convs
81
+ self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
82
+ padding=(1, 1), bias=True)
83
+ self.apperance_conv2 = nn.Conv2d(
84
+ self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
85
+ self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
86
+ padding=(1, 1), bias=True)
87
+ self.apperance_conv4 = nn.Conv2d(
88
+ self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
89
+ # Attention layers
90
+ self.apperance_att_conv1 = nn.Conv2d(
91
+ self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
92
+ self.attn_mask_1 = Attention_mask()
93
+ self.apperance_att_conv2 = nn.Conv2d(
94
+ self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
95
+ self.attn_mask_2 = Attention_mask()
96
+ # Avg pooling
97
+ self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
98
+ self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
99
+ self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
100
+ # Dropout layers
101
+ self.dropout_1 = nn.Dropout(self.dropout_rate1)
102
+ self.dropout_2 = nn.Dropout(self.dropout_rate1)
103
+ self.dropout_3 = nn.Dropout(self.dropout_rate1)
104
+ self.dropout_4 = nn.Dropout(self.dropout_rate2)
105
+ # Dense layers
106
+ if img_size == 36:
107
+ self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
108
+ elif img_size == 72:
109
+ self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
110
+ elif img_size == 96:
111
+ self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
112
+ elif img_size == 128:
113
+ self.final_dense_1 = nn.Linear(57600, self.nb_dense, bias=True)
114
+ else:
115
+ raise Exception('Unsupported image size')
116
+ self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
117
+
118
+ def forward(self, inputs, params=None):
119
+ diff_input = inputs[:, :3, :, :]
120
+ raw_input = inputs[:, 3:, :, :]
121
+
122
+ diff_input = self.TSM_1(diff_input)
123
+ d1 = torch.tanh(self.motion_conv1(diff_input))
124
+ d1 = self.TSM_2(d1)
125
+ d2 = torch.tanh(self.motion_conv2(d1))
126
+
127
+ r1 = torch.tanh(self.apperance_conv1(raw_input))
128
+ r2 = torch.tanh(self.apperance_conv2(r1))
129
+
130
+ g1 = torch.sigmoid(self.apperance_att_conv1(r2))
131
+ g1 = self.attn_mask_1(g1)
132
+ gated1 = d2 * g1
133
+
134
+ d3 = self.avg_pooling_1(gated1)
135
+ d4 = self.dropout_1(d3)
136
+
137
+ r3 = self.avg_pooling_2(r2)
138
+ r4 = self.dropout_2(r3)
139
+
140
+ d4 = self.TSM_3(d4)
141
+ d5 = torch.tanh(self.motion_conv3(d4))
142
+ d5 = self.TSM_4(d5)
143
+ d6 = torch.tanh(self.motion_conv4(d5))
144
+
145
+ r5 = torch.tanh(self.apperance_conv3(r4))
146
+ r6 = torch.tanh(self.apperance_conv4(r5))
147
+
148
+ g2 = torch.sigmoid(self.apperance_att_conv2(r6))
149
+ g2 = self.attn_mask_2(g2)
150
+ gated2 = d6 * g2
151
+
152
+ d7 = self.avg_pooling_3(gated2)
153
+ d8 = self.dropout_3(d7)
154
+ d9 = d8.view(d8.size(0), -1)
155
+ d10 = torch.tanh(self.final_dense_1(d9))
156
+ d11 = self.dropout_4(d10)
157
+ out = self.final_dense_2(d11)
158
+
159
+ return out
160
+
161
+
162
+ class MTTS_CAN(nn.Module):
163
+ """MTTS_CAN is the multi-task (respiration) version of TS-CAN"""
164
+
165
+ def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
166
+ dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20):
167
+ super(MTTS_CAN, self).__init__()
168
+ self.in_channels = in_channels
169
+ self.kernel_size = kernel_size
170
+ self.dropout_rate1 = dropout_rate1
171
+ self.dropout_rate2 = dropout_rate2
172
+ self.pool_size = pool_size
173
+ self.nb_filters1 = nb_filters1
174
+ self.nb_filters2 = nb_filters2
175
+ self.nb_dense = nb_dense
176
+ # TSM layers
177
+ self.TSM_1 = TSM(n_segment=frame_depth)
178
+ self.TSM_2 = TSM(n_segment=frame_depth)
179
+ self.TSM_3 = TSM(n_segment=frame_depth)
180
+ self.TSM_4 = TSM(n_segment=frame_depth)
181
+ # Motion branch convs
182
+ self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
183
+ bias=True)
184
+ self.motion_conv2 = nn.Conv2d(
185
+ self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
186
+ self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
187
+ bias=True)
188
+ self.motion_conv4 = nn.Conv2d(
189
+ self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
190
+ # Apperance branch convs
191
+ self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
192
+ padding=(1, 1), bias=True)
193
+ self.apperance_conv2 = nn.Conv2d(
194
+ self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
195
+ self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
196
+ padding=(1, 1), bias=True)
197
+ self.apperance_conv4 = nn.Conv2d(
198
+ self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
199
+ # Attention layers
200
+ self.apperance_att_conv1 = nn.Conv2d(
201
+ self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
202
+ self.attn_mask_1 = Attention_mask()
203
+ self.apperance_att_conv2 = nn.Conv2d(
204
+ self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
205
+ self.attn_mask_2 = Attention_mask()
206
+ # Avg pooling
207
+ self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
208
+ self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
209
+ self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
210
+ # Dropout layers
211
+ self.dropout_1 = nn.Dropout(self.dropout_rate1)
212
+ self.dropout_2 = nn.Dropout(self.dropout_rate1)
213
+ self.dropout_3 = nn.Dropout(self.dropout_rate1)
214
+ self.dropout_4_y = nn.Dropout(self.dropout_rate2)
215
+ self.dropout_4_r = nn.Dropout(self.dropout_rate2)
216
+
217
+ # Dense layers
218
+ self.final_dense_1_y = nn.Linear(16384, self.nb_dense, bias=True)
219
+ self.final_dense_2_y = nn.Linear(self.nb_dense, 1, bias=True)
220
+ self.final_dense_1_r = nn.Linear(16384, self.nb_dense, bias=True)
221
+ self.final_dense_2_r = nn.Linear(self.nb_dense, 1, bias=True)
222
+
223
+ def forward(self, inputs, params=None):
224
+ diff_input = inputs[:, :3, :, :]
225
+ raw_input = inputs[:, 3:, :, :]
226
+
227
+ diff_input = self.TSM_1(diff_input)
228
+ d1 = torch.tanh(self.motion_conv1(diff_input))
229
+ d1 = self.TSM_2(d1)
230
+ d2 = torch.tanh(self.motion_conv2(d1))
231
+
232
+ r1 = torch.tanh(self.apperance_conv1(raw_input))
233
+ r2 = torch.tanh(self.apperance_conv2(r1))
234
+
235
+ g1 = torch.sigmoid(self.apperance_att_conv1(r2))
236
+ g1 = self.attn_mask_1(g1)
237
+ gated1 = d2 * g1
238
+
239
+ d3 = self.avg_pooling_1(gated1)
240
+ d4 = self.dropout_1(d3)
241
+
242
+ r3 = self.avg_pooling_2(r2)
243
+ r4 = self.dropout_2(r3)
244
+
245
+ d4 = self.TSM_3(d4)
246
+ d5 = torch.tanh(self.motion_conv3(d4))
247
+ d5 = self.TSM_4(d5)
248
+ d6 = torch.tanh(self.motion_conv4(d5))
249
+
250
+ r5 = torch.tanh(self.apperance_conv3(r4))
251
+ r6 = torch.tanh(self.apperance_conv4(r5))
252
+
253
+ g2 = torch.sigmoid(self.apperance_att_conv2(r6))
254
+ g2 = self.attn_mask_2(g2)
255
+ gated2 = d6 * g2
256
+
257
+ d7 = self.avg_pooling_3(gated2)
258
+ d8 = self.dropout_3(d7)
259
+ d9 = d8.view(d8.size(0), -1)
260
+
261
+ d10 = torch.tanh(self.final_dense_1_y(d9))
262
+ d11 = self.dropout_4_y(d10)
263
+ out_y = self.final_dense_2_y(d11)
264
+
265
+ d10 = torch.tanh(self.final_dense_1_r(d9))
266
+ d11 = self.dropout_4_r(d10)
267
+ out_r = self.final_dense_2_r(d11)
268
+
269
+ return out_y, out_r
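
A usage sketch for TSCAN (not part of the commit). The network consumes frame-stacked 2D input of shape [N*D, 6, H, W]: channels 0-2 are the diff-normalized frames, channels 3-5 the corresponding raw frames, and N*D should be a multiple of `frame_depth` (the TSM segment length). With `img_size=36` the flattened feature size matches the 3136-unit dense layer.

    import torch
    from neural_methods.model.TS_CAN import TSCAN

    model = TSCAN(frame_depth=20, img_size=36).eval()
    chunk = torch.rand(20, 6, 36, 36)          # one 20-frame chunk, diff + raw stacked on channels
    with torch.no_grad():
        out = model(chunk)                     # per-frame BVP estimates, shape [N*D, 1]
    print(out.shape)                           # torch.Size([20, 1])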
neural_methods/model/__init__.py ADDED
File without changes
neural_methods/model/iBVPNet.py ADDED
@@ -0,0 +1,194 @@
1
+ """iBVPNet - 3D Convolutional Network.
2
+ Proposed along with the iBVP Dataset, see https://doi.org/10.3390/electronics13071334
3
+
4
+ Joshi, Jitesh, and Youngjun Cho. 2024. "iBVP Dataset: RGB-Thermal rPPG Dataset with High Resolution Signal Quality Labels" Electronics 13, no. 7: 1334.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class ConvBlock3D(nn.Module):
13
+ def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
14
+ super(ConvBlock3D, self).__init__()
15
+ self.conv_block_3d = nn.Sequential(
16
+ nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding),
17
+ nn.Tanh(),
18
+ nn.InstanceNorm3d(out_channel),
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.conv_block_3d(x)
23
+
24
+
25
+ class DeConvBlock3D(nn.Module):
26
+ def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
27
+ super(DeConvBlock3D, self).__init__()
28
+ k_t, k_s1, k_s2 = kernel_size
29
+ s_t, s_s1, s_s2 = stride
30
+ self.deconv_block_3d = nn.Sequential(
31
+ nn.ConvTranspose3d(in_channel, in_channel, (k_t, 1, 1), (s_t, 1, 1), padding),
32
+ nn.Tanh(),
33
+ nn.InstanceNorm3d(in_channel),
34
+
35
+ nn.Conv3d(in_channel, out_channel, (1, k_s1, k_s2), (1, s_s1, s_s2), padding),
36
+ nn.Tanh(),
37
+ nn.InstanceNorm3d(out_channel),
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.deconv_block_3d(x)
42
+
43
+ # num_filters
44
+ nf = [8, 16, 24, 40, 64]
45
+
46
+ class encoder_block(nn.Module):
47
+ def __init__(self, in_channel, debug=False):
48
+ super(encoder_block, self).__init__()
49
+ # in_channel, out_channel, kernel_size, stride, padding
50
+
51
+ self.debug = debug
52
+ self.spatio_temporal_encoder = nn.Sequential(
53
+ ConvBlock3D(in_channel, nf[0], [1, 3, 3], [1, 1, 1], [0, 1, 1]),
54
+ ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 1, 1], [1, 1, 1]),
55
+ nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
56
+ ConvBlock3D(nf[1], nf[2], [1, 3, 3], [1, 1, 1], [0, 1, 1]),
57
+ ConvBlock3D(nf[2], nf[3], [3, 3, 3], [1, 1, 1], [1, 1, 1]),
58
+ nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
59
+ ConvBlock3D(nf[3], nf[4], [1, 3, 3], [1, 1, 1], [0, 1, 1]),
60
+ ConvBlock3D(nf[4], nf[4], [3, 3, 3], [1, 1, 1], [1, 1, 1]),
61
+ )
62
+
63
+ self.temporal_encoder = nn.Sequential(
64
+ ConvBlock3D(nf[4], nf[4], [11, 1, 1], [1, 1, 1], [5, 0, 0]),
65
+ ConvBlock3D(nf[4], nf[4], [11, 3, 3], [1, 1, 1], [5, 1, 1]),
66
+ nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),
67
+ ConvBlock3D(nf[4], nf[4], [11, 1, 1], [1, 1, 1], [5, 0, 0]),
68
+ ConvBlock3D(nf[4], nf[4], [11, 3, 3], [1, 1, 1], [5, 1, 1]),
69
+ nn.MaxPool3d((2, 2, 2), stride=(2, 1, 1)),
70
+ ConvBlock3D(nf[4], nf[4], [7, 1, 1], [1, 1, 1], [3, 0, 0]),
71
+ ConvBlock3D(nf[4], nf[4], [7, 3, 3], [1, 1, 1], [3, 1, 1])
72
+ )
73
+
74
+ def forward(self, x):
75
+ if self.debug:
76
+ print("Encoder")
77
+ print("x.shape", x.shape)
78
+ st_x = self.spatio_temporal_encoder(x)
79
+ if self.debug:
80
+ print("st_x.shape", st_x.shape)
81
+ t_x = self.temporal_encoder(st_x)
82
+ if self.debug:
83
+ print("t_x.shape", t_x.shape)
84
+ return t_x
85
+
86
+
87
+ class decoder_block(nn.Module):
88
+ def __init__(self, debug=False):
89
+ super(decoder_block, self).__init__()
90
+ self.debug = debug
91
+ self.decoder_block = nn.Sequential(
92
+ DeConvBlock3D(nf[4], nf[3], [7, 3, 3], [2, 2, 2], [2, 1, 1]),
93
+ DeConvBlock3D(nf[3], nf[2], [7, 3, 3], [2, 2, 2], [2, 1, 1])
94
+ )
95
+
96
+ def forward(self, x):
97
+ if self.debug:
98
+ print("Decoder")
99
+ print("x.shape", x.shape)
100
+ x = self.decoder_block(x)
101
+ if self.debug:
102
+ print("x.shape", x.shape)
103
+ return x
104
+
105
+
106
+
107
+ class iBVPNet(nn.Module):
108
+ def __init__(self, frames, in_channels=3, debug=False):
109
+ super(iBVPNet, self).__init__()
110
+ self.debug = debug
111
+
112
+ self.in_channels = in_channels
113
+ if self.in_channels == 1 or self.in_channels == 3:
114
+ self.norm = nn.InstanceNorm3d(self.in_channels)
115
+ elif self.in_channels == 4:
116
+ self.rgb_norm = nn.InstanceNorm3d(3)
117
+ self.thermal_norm = nn.InstanceNorm3d(1)
118
+ else:
119
+ print("Unsupported input channels")
120
+
121
+ self.ibvpnet = nn.Sequential(
122
+ encoder_block(in_channels, debug),
123
+ decoder_block(debug),
124
+ # spatial adaptive pooling
125
+ nn.AdaptiveMaxPool3d((frames, 1, 1)),
126
+ nn.Conv3d(nf[2], 1, [1, 1, 1], stride=1, padding=0)
127
+ )
128
+
129
+
130
+ def forward(self, x): # [batch, Features=3, Temp=frames, Width=32, Height=32]
131
+
132
+ [batch, channel, length, width, height] = x.shape
133
+
134
+ x = torch.diff(x, dim=2)
135
+
136
+ if self.debug:
137
+ print("Input.shape", x.shape)
138
+
139
+ if self.in_channels == 1:
140
+ x = self.norm(x[:, -1:, :, :, :])
141
+ elif self.in_channels == 3:
142
+ x = self.norm(x[:, :3, :, :, :])
143
+ elif self.in_channels == 4:
144
+ rgb_x = self.rgb_norm(x[:, :3, :, :, :])
145
+ thermal_x = self.thermal_norm(x[:, -1:, :, :, :])
146
+ x = torch.concat([rgb_x, thermal_x], dim = 1)
147
+ else:
148
+ try:
149
+ print("Specified input channels:", self.in_channels)
150
+ print("Data channels", channel)
151
+ assert self.in_channels <= channel
152
+ except:
153
+ print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels")
154
+ print("Default or specified channels:", self.in_channels)
155
+ print("Data channels [B, C, N, W, H]", x.shape)
156
+ print("Exiting")
157
+ exit()
158
+
159
+ if self.debug:
160
+ print("Diff Normalized shape", x.shape)
161
+
162
+ feats = self.ibvpnet(x)
163
+ if self.debug:
164
+ print("feats.shape", feats.shape)
165
+ rPPG = feats.view(-1, length-1)
166
+ return rPPG
167
+
168
+
169
+ if __name__ == "__main__":
170
+ import torch
171
+ from torch.utils.tensorboard import SummaryWriter
172
+
173
+ # default `log_dir` is "runs" - we'll be more specific here
174
+ writer = SummaryWriter('runs/iBVPNet')
175
+
176
+ duration = 8
177
+ fs = 25
178
+ batch_size = 4
179
+ frames = duration*fs
180
+ in_channels = 1
181
+ height = 64
182
+ width = 64
183
+ test_data = torch.rand(batch_size, in_channels, frames, height, width)
184
+
185
+ net = iBVPNet(in_channels=in_channels, frames=frames, debug=True)
186
+ # print("-"*100)
187
+ # print(net)
188
+ # print("-"*100)
189
+ pred = net(test_data)
190
+
191
+ print(pred.shape)
192
+
193
+ writer.add_graph(net, test_data)
194
+ writer.close()
neural_methods/trainer/BaseTrainer.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib.ticker import ScalarFormatter, MaxNLocator
5
+ import os
6
+ import pickle
7
+
8
+
9
+ class BaseTrainer:
10
+ @staticmethod
11
+ def add_trainer_args(parser):
12
+ """Adds arguments to Paser for training process"""
13
+ parser.add_argument('--lr', default=None, type=float)
14
+ parser.add_argument('--model_file_name', default=None, type=str)
15
+ return parser
16
+
17
+ def __init__(self):
18
+ pass
19
+
20
+ def train(self, data_loader):
21
+ pass
22
+
23
+ def valid(self, data_loader):
24
+ pass
25
+
26
+ def test(self):
27
+ pass
28
+
29
+ def save_test_outputs(self, predictions, labels, config):
30
+
31
+ output_dir = config.TEST.OUTPUT_SAVE_DIR
32
+ if not os.path.exists(output_dir):
33
+ os.makedirs(output_dir, exist_ok=True)
34
+
35
+ # Filename ID to be used in any output files that get saved
36
+ if config.TOOLBOX_MODE == 'train_and_test':
37
+ filename_id = self.model_file_name
38
+ elif config.TOOLBOX_MODE == 'only_test':
39
+ model_file_root = config.INFERENCE.MODEL_PATH.split("/")[-1].split(".pth")[0]
40
+ filename_id = model_file_root + "_" + config.TEST.DATA.DATASET
41
+ else:
42
+ raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!')
43
+ output_path = os.path.join(output_dir, filename_id + '_outputs.pickle')
44
+
45
+ data = dict()
46
+ data['predictions'] = predictions
47
+ data['labels'] = labels
48
+ data['label_type'] = config.TEST.DATA.PREPROCESS.LABEL_TYPE
49
+ data['fs'] = config.TEST.DATA.FS
50
+
51
+ with open(output_path, 'wb') as handle: # save out frame dict pickle file
52
+ pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
53
+
54
+ print('Saving outputs to:', output_path)
55
+
56
+ def plot_losses_and_lrs(self, train_loss, valid_loss, lrs, config):
57
+
58
+ output_dir = os.path.join(config.LOG.PATH, config.TRAIN.DATA.EXP_DATA_NAME, 'plots')
59
+ if not os.path.exists(output_dir):
60
+ os.makedirs(output_dir, exist_ok=True)
61
+
62
+ # Filename ID to be used in plots that get saved
63
+ if config.TOOLBOX_MODE == 'train_and_test':
64
+ filename_id = self.model_file_name
65
+ else:
66
+ raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!')
67
+
68
+ # Create a single plot for training and validation losses
69
+ plt.figure(figsize=(10, 6))
70
+ epochs = range(0, len(train_loss)) # Integer values for x-axis
71
+ plt.plot(epochs, train_loss, label='Training Loss')
72
+ if len(valid_loss) > 0:
73
+ plt.plot(epochs, valid_loss, label='Validation Loss')
74
+ else:
75
+ print("The list of validation losses is empty. The validation loss will not be plotted!")
76
+ plt.xlabel('Epoch')
77
+ plt.ylabel('Loss')
78
+ plt.title(f'{filename_id} Losses')
79
+ plt.legend()
80
+ plt.xticks(epochs)
81
+
82
+ # Set y-axis ticks with more granularity
83
+ ax = plt.gca()
84
+ ax.yaxis.set_major_locator(MaxNLocator(integer=False, prune='both'))
85
+
86
+ loss_plot_filename = os.path.join(output_dir, filename_id + '_losses.pdf')
87
+ plt.savefig(loss_plot_filename, dpi=300)
88
+ plt.close()
89
+
90
+ # Create a separate plot for learning rates
91
+ plt.figure(figsize=(6, 4))
92
+ scheduler_steps = range(0, len(lrs))
93
+ plt.plot(scheduler_steps, lrs, label='Learning Rate')
94
+ plt.xlabel('Scheduler Step')
95
+ plt.ylabel('Learning Rate')
96
+ plt.title(f'{filename_id} LR Schedule')
97
+ plt.legend()
98
+
99
+ # Set y-axis values in scientific notation
100
+ ax = plt.gca()
101
+ ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True, useOffset=False))
102
+ ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0)) # Force scientific notation
103
+
104
+ lr_plot_filename = os.path.join(output_dir, filename_id + '_learning_rates.pdf')
105
+ plt.savefig(lr_plot_filename, bbox_inches='tight', dpi=300)
106
+ plt.close()
107
+
108
+ print('Saving plots of losses and learning rates to:', output_dir)
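
For reference, a short sketch (not part of the commit) of how the pickle written by `save_test_outputs` above can be read back; the path is hypothetical and follows the `<filename_id>_outputs.pickle` pattern used in that method.

    import pickle

    # Hypothetical path -- substitute the actual TEST.OUTPUT_SAVE_DIR and filename_id.
    with open('saved_test_outputs/PURE_PhysMamba_outputs.pickle', 'rb') as handle:
        outputs = pickle.load(handle)
    print(outputs.keys())   # dict_keys(['predictions', 'labels', 'label_type', 'fs'])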
neural_methods/trainer/BigSmallTrainer.py ADDED
@@ -0,0 +1,484 @@
1
+ """Trainer for BigSmall Multitask Models"""
2
+
3
+ # Training / Eval Imports
4
+ import torch
5
+ import torch.optim as optim
6
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
7
+ from neural_methods import loss
8
+ from neural_methods.model.BigSmall import BigSmall
9
+ from evaluation.bigsmall_multitask_metrics import (calculate_bvp_metrics,
10
+ calculate_resp_metrics,
11
+ calculate_bp4d_au_metrics)
12
+
13
+ # Other Imports
14
+ from collections import OrderedDict
15
+ import numpy as np
16
+ import os
17
+ from tqdm import tqdm
18
+
19
+ class BigSmallTrainer(BaseTrainer):
20
+
21
+ def define_model(self, config):
22
+
23
+ # BigSmall Model
24
+ model = BigSmall(n_segment=3)
25
+
26
+ if self.using_TSM:
27
+ self.frame_depth = config.MODEL.BIGSMALL.FRAME_DEPTH
28
+ self.base_len = self.num_of_gpu * self.frame_depth
29
+
30
+ return model
31
+
32
+ def format_data_shape(self, data, labels):
33
+ # reshape big data
34
+ data_big = data[0]
35
+ N, D, C, H, W = data_big.shape
36
+ data_big = data_big.view(N * D, C, H, W)
37
+
38
+ # reshape small data
39
+ data_small = data[1]
40
+ N, D, C, H, W = data_small.shape
41
+ data_small = data_small.view(N * D, C, H, W)
42
+
43
+ # reshape labels
44
+ if len(labels.shape) != 3: # this training format requires labels that are of shape N_label, D_label, C_label
45
+ labels = torch.unsqueeze(labels, dim=-1)
46
+ N_label, D_label, C_label = labels.shape
47
+ labels = labels.view(N_label * D_label, C_label)
48
+
49
+ # If using temporal shift module
50
+ if self.using_TSM:
51
+ data_big = data_big[:(N * D) // self.base_len * self.base_len]
52
+ data_small = data_small[:(N * D) // self.base_len * self.base_len]
53
+ labels = labels[:(N * D) // self.base_len * self.base_len]
54
+
55
+ data[0] = data_big
56
+ data[1] = data_small
57
+ labels = torch.unsqueeze(labels, dim=-1)
58
+
59
+ return data, labels
60
+
61
+
62
+ def send_data_to_device(self, data, labels):
63
+ big_data = data[0].to(self.device)
64
+ small_data = data[1].to(self.device)
65
+ labels = labels.to(self.device)
66
+ data = (big_data, small_data)
67
+ return data, labels
68
+
69
+
70
+ def get_label_idxs(self, label_list, used_labels):
71
+ label_idxs = []
72
+ for l in used_labels:
73
+ idx = label_list.index(l)
74
+ label_idxs.append(idx)
75
+ return label_idxs
76
+
77
+
78
+ def remove_data_parallel(self, old_state_dict):
79
+ new_state_dict = OrderedDict()
80
+
81
+ for k, v in old_state_dict.items():
82
+ name = k[7:] # remove `module.`
83
+ new_state_dict[name] = v
84
+
85
+ return new_state_dict
86
+
87
+
88
+ def save_model(self, index):
89
+ if not os.path.exists(self.model_dir):
90
+ os.makedirs(self.model_dir)
91
+ model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
92
+ torch.save(self.model.state_dict(), model_path)
93
+ print('Saved Model Path: ', model_path)
94
+ print('')
95
+
96
+
97
+ def __init__(self, config, data_loader):
98
+
99
+ print('')
100
+ print('Init BigSmall Multitask Trainer\n\n')
101
+
102
+ self.config = config # save config file
103
+
104
+ # Set up GPU/CPU compute device
105
+ if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0:
106
+ self.device = torch.device(config.DEVICE) # set device to primary GPU
107
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN # set number of used GPUs
108
+ else:
109
+ self.device = "cpu" # if no GPUs set device is CPU
110
+ self.num_of_gpu = 0 # no GPUs used
111
+
112
+ # Defining model
113
+ self.using_TSM = True
114
+ self.model = self.define_model(config) # define the model
115
+
116
+ if torch.cuda.device_count() > 1 and config.NUM_OF_GPU_TRAIN > 1: # distribute model across GPUs
117
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) # data parallel model
118
+
119
+ self.model = self.model.to(self.device) # send model to primary GPU
120
+
121
+ # Training parameters
122
+ self.batch_size = config.TRAIN.BATCH_SIZE
123
+ self.max_epoch_num = config.TRAIN.EPOCHS
124
+ self.LR = config.TRAIN.LR
125
+
126
+ # Set Loss and Optimizer
127
+ AU_weights = torch.as_tensor([9.64, 11.74, 16.77, 1.05, 0.53, 0.56,
128
+ 0.75, 0.69, 8.51, 6.94, 5.03, 25.00]).to(self.device)
129
+
130
+ self.criterionAU = torch.nn.BCEWithLogitsLoss(pos_weight=AU_weights).to(self.device)
131
+ self.criterionBVP = torch.nn.MSELoss().to(self.device)
132
+ self.criterionRESP = torch.nn.MSELoss().to(self.device)
133
+ self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0)
134
+
135
+ # self.scaler = torch.cuda.amp.GradScaler() # Loss scalar
136
+
137
+ # Model info (save dir, chunk len, best epoch, etc.)
138
+ self.model_dir = config.MODEL.MODEL_DIR
139
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
140
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
141
+
142
+ # Epoch To Use For Test
143
+ self.used_epoch = 0
144
+
145
+ # Indices corresponding to used labels
146
+ label_list = ['bp_wave', 'HR_bpm', 'systolic_bp', 'diastolic_bp', 'mean_bp',
147
+ 'resp_wave', 'resp_bpm', 'eda',
148
+ 'AU01', 'AU02', 'AU04', 'AU05', 'AU06', 'AU06int', 'AU07', 'AU09', 'AU10', 'AU10int',
149
+ 'AU11', 'AU12', 'AU12int', 'AU13', 'AU14', 'AU14int', 'AU15', 'AU16', 'AU17', 'AU17int',
150
+ 'AU18', 'AU19', 'AU20', 'AU22', 'AU23', 'AU24', 'AU27', 'AU28', 'AU29', 'AU30', 'AU31',
151
+ 'AU32', 'AU33', 'AU34', 'AU35', 'AU36', 'AU37', 'AU38', 'AU39',
152
+ 'pos_bvp','pos_env_norm_bvp']
153
+
154
+ used_labels = ['bp_wave', 'AU01', 'AU02', 'AU04', 'AU06', 'AU07', 'AU10', 'AU12',
155
+ 'AU14', 'AU15', 'AU17', 'AU23', 'AU24',
156
+ 'pos_env_norm_bvp', 'resp_wave']
157
+
158
+ # Get indices for labels from npy array
159
+ au_label_list = [label for label in used_labels if 'AU' in label]
160
+ bvp_label_list_train = [label for label in used_labels if 'bvp' in label]
161
+ bvp_label_list_test = [label for label in used_labels if 'bp_wave' in label]
162
+ resp_label_list = [label for label in used_labels if 'resp' in label]
163
+
164
+ self.label_idx_train_au = self.get_label_idxs(label_list, au_label_list)
165
+ self.label_idx_valid_au = self.get_label_idxs(label_list, au_label_list)
166
+ self.label_idx_test_au = self.get_label_idxs(label_list, au_label_list)
167
+
168
+ self.label_idx_train_bvp = self.get_label_idxs(label_list, bvp_label_list_train)
169
+ self.label_idx_valid_bvp = self.get_label_idxs(label_list, bvp_label_list_train)
170
+ self.label_idx_test_bvp = self.get_label_idxs(label_list, bvp_label_list_test)
171
+
172
+ self.label_idx_train_resp = self.get_label_idxs(label_list, resp_label_list)
173
+ self.label_idx_valid_resp = self.get_label_idxs(label_list, resp_label_list)
174
+ self.label_idx_test_resp = self.get_label_idxs(label_list, resp_label_list)
175
+
176
+
177
+ def train(self, data_loader):
178
+ """Model Training"""
179
+
180
+ if data_loader["train"] is None:
181
+ raise ValueError("No data for train")
182
+
183
+ print('Starting Training Routine')
184
+ print('')
185
+
186
+ # Init min validation loss as infinity
187
+ min_valid_loss = np.inf # minimum validation loss
188
+
189
+ # ARRAYS TO SAVE (LOSS ARRAYS)
190
+ train_loss_dict = dict()
191
+ train_au_loss_dict = dict()
192
+ train_bvp_loss_dict = dict()
193
+ train_resp_loss_dict = dict()
194
+
195
+ val_loss_dict = dict()
196
+ val_au_loss_dict = dict()
197
+ val_bvp_loss_dict = dict()
198
+ val_resp_loss_dict = dict()
199
+
200
+ # TODO: Expand tracking and subsequent plotting of these losses for BigSmall
201
+ mean_training_losses = []
202
+ mean_valid_losses = []
203
+ lrs = []
204
+
205
+ # ITERATE THROUGH EPOCHS
206
+ for epoch in range(self.max_epoch_num):
207
+ print(f"====Training Epoch: {epoch}====")
208
+
209
+ # INIT PARAMS FOR TRAINING
210
+ running_loss = 0.0 # tracks avg loss over mini batches of 100
211
+ train_loss = []
212
+ train_au_loss = []
213
+ train_bvp_loss = []
214
+ train_resp_loss = []
215
+ self.model.train() # put model in train mode
216
+
217
+ # MODEL TRAINING
218
+ tbar = tqdm(data_loader["train"], ncols=80)
219
+ for idx, batch in enumerate(tbar):
220
+ tbar.set_description("Train epoch %s" % epoch)
221
+
222
+ # GATHER AND FORMAT BATCH DATA
223
+ data, labels = batch[0], batch[1]
224
+ data, labels = self.format_data_shape(data, labels)
225
+ data, labels = self.send_data_to_device(data, labels)
226
+
227
+ # FORWARD AND BACKPROPAGATE THROUGH MODEL
228
+ self.optimizer.zero_grad()
229
+ au_out, bvp_out, resp_out = self.model(data)
230
+ au_loss = self.criterionAU(au_out, labels[:, self.label_idx_train_au, 0]) # au loss
231
+ bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_train_bvp, 0]) # bvp loss
232
+ resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_train_resp, 0]) # resp loss
233
+ loss = au_loss + bvp_loss + resp_loss # sum losses
234
+ loss.backward()
235
+
236
+ # Append the current learning rate to the list
237
+ lrs.append(self.optimizer.param_groups[0]["lr"]) # no LR scheduler is defined in this trainer
238
+
239
+ self.optimizer.step()
240
+ # self.scaler.scale(loss).backward() # Loss scaling
241
+ # self.scaler.step(self.optimizer)
242
+ # self.scaler.update()
243
+
244
+
245
+
246
+
247
+ # UPDATE RUNNING LOSS, TERMINAL OUTPUT, AND SAVED LOSSES
248
+ train_loss.append(loss.item())
249
+ train_au_loss.append(au_loss.item())
250
+ train_bvp_loss.append(bvp_loss.item())
251
+ train_resp_loss.append(resp_loss.item())
252
+
253
+ running_loss += loss.item()
254
+ if idx % 100 == 99: # print every 100 mini-batches
255
+ print(f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
256
+ running_loss = 0.0
257
+
258
+
259
+ tbar.set_postfix({"loss:": loss.item(), "lr:": self.optimizer.param_groups[0]["lr"]})
260
+
261
+ # APPEND EPOCH LOSS LIST TO TRAINING LOSS DICTIONARY
262
+ train_loss_dict[epoch] = train_loss
263
+ train_au_loss_dict[epoch] = train_au_loss
264
+ train_bvp_loss_dict[epoch] = train_bvp_loss
265
+ train_resp_loss_dict[epoch] = train_resp_loss
266
+
267
+ print('')
268
+
269
+ # Append the mean training loss for the epoch
270
+ mean_training_losses.append(np.mean(train_loss))
271
+
272
+ # SAVE MODEL FOR THIS EPOCH
273
+ self.save_model(epoch)
274
+
275
+ # VALIDATION (IF ENABLED)
276
+ if not self.config.TEST.USE_LAST_EPOCH:
277
+
278
+ # Get validation losses
279
+ valid_loss, valid_au_loss, valid_bvp_loss, valid_resp_loss = self.valid(data_loader)
280
+ mean_valid_losses.append(valid_loss)
281
+ val_loss_dict[epoch] = valid_loss
282
+ val_au_loss_dict[epoch] = valid_au_loss
283
+ val_bvp_loss_dict[epoch] = valid_bvp_loss
284
+ val_resp_loss_dict[epoch] = valid_resp_loss
285
+ print('validation loss: ', valid_loss)
286
+
287
+ # Update used model
288
+ if self.model_to_use == 'best_epoch' and (valid_loss < min_valid_loss):
289
+ min_valid_loss = valid_loss
290
+ self.used_epoch = epoch
291
+ print("Update best model! Best epoch: {}".format(self.used_epoch))
292
+ elif self.model_to_use == 'last_epoch':
293
+ self.used_epoch = epoch
294
+
295
+ # VALIDATION (NOT ENABLED)
296
+ else:
297
+ self.used_epoch = epoch
298
+
299
+ print('')
300
+
301
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
302
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
303
+
304
+ # PRINT MODEL TO BE USED FOR TESTING
305
+ print("Used model trained epoch:{}, val_loss:{}".format(self.used_epoch, min_valid_loss))
306
+ print('')
307
+
308
+
309
+
310
+ def valid(self, data_loader):
311
+ """ Model evaluation on the validation dataset."""
312
+
313
+ if data_loader["valid"] is None:
314
+ raise ValueError("No data for valid")
315
+
316
+ print("===Validating===")
317
+
318
+ # INIT PARAMS FOR VALIDATION
319
+ valid_loss = []
320
+ valid_au_loss = []
321
+ valid_bvp_loss = []
322
+ valid_resp_loss = []
323
+ self.model.eval()
324
+
325
+ # MODEL VALIDATION
326
+ with torch.no_grad():
327
+ vbar = tqdm(data_loader["valid"], ncols=80)
328
+ for valid_idx, valid_batch in enumerate(vbar):
329
+ vbar.set_description("Validation")
330
+
331
+ # GATHER AND FORMAT BATCH DATA
332
+ data, labels = valid_batch[0], valid_batch[1]
333
+ data, labels = self.format_data_shape(data, labels)
334
+ data, labels = self.send_data_to_device(data, labels)
335
+
336
+ au_out, bvp_out, resp_out = self.model(data)
337
+ au_loss = self.criterionAU(au_out, labels[:, self.label_idx_valid_au, 0]) # au loss
338
+ bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_valid_bvp, 0]) # bvp loss
339
+ resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_valid_resp, 0]) # resp loss
340
+ loss = au_loss + bvp_loss + resp_loss # sum losses
341
+
342
+ # APPEND VAL LOSS
343
+ valid_loss.append(loss.item())
344
+ valid_au_loss.append(au_loss.item())
345
+ valid_bvp_loss.append(bvp_loss.item())
346
+ valid_resp_loss.append(resp_loss.item())
347
+ vbar.set_postfix(loss=loss.item())
348
+
349
+ valid_loss = np.asarray(valid_loss)
350
+ valid_au_loss = np.asarray(valid_au_loss)
351
+ valid_bvp_loss = np.asarray(valid_bvp_loss)
352
+ valid_resp_loss = np.asarray(valid_resp_loss)
353
+ return np.mean(valid_loss), np.mean(valid_au_loss), np.mean(valid_bvp_loss), np.mean(valid_resp_loss)
354
+
355
+
356
+
357
+ def test(self, data_loader):
358
+ """ Model evaluation on the testing dataset."""
359
+
360
+ print("===Testing===")
361
+ print('')
362
+
363
+ # SETUP
364
+ if data_loader["test"] is None:
365
+ raise ValueError("No data for test")
366
+
367
+ # Change chunk length to be test chunk length
368
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
369
+
370
+ # ARRAYS TO SAVE (PREDICTIONS AND METRICS ARRAYS)
371
+ preds_dict_au = dict()
372
+ labels_dict_au = dict()
373
+ preds_dict_bvp = dict()
374
+ labels_dict_bvp = dict()
375
+ preds_dict_resp = dict()
376
+ labels_dict_resp = dict()
377
+
378
+ # IF ONLY_TEST MODE LOAD PRETRAINED MODEL
379
+ if self.config.TOOLBOX_MODE == "only_test":
380
+ model_path = self.config.INFERENCE.MODEL_PATH
381
+ print("Testing uses pretrained model!")
382
+ print('Model path:', model_path)
383
+ if not os.path.exists(model_path):
384
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
385
+
386
+ # IF USING MODEL FROM TRAINING
387
+ else:
388
+ model_path = os.path.join(self.model_dir,
389
+ self.model_file_name + '_Epoch' + str(self.used_epoch) + '.pth')
390
+ print("Testing uses non-pretrained model!")
391
+ print('Model path:', model_path)
392
+ if not os.path.exists(model_path):
393
+ raise ValueError("Something went wrong... cant find trained model...")
394
+ print('')
395
+
396
+ # LOAD THE MODEL SPECIFIED ABOVE FOR TESTING
397
+ self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
398
+ self.model = self.model.to(self.device)
399
+ self.model.eval()
400
+
401
+ # MODEL TESTING
402
+ print("Running model evaluation on the testing dataset!")
403
+ with torch.no_grad():
404
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
405
+
406
+ # PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA
407
+ batch_size = test_batch[1].shape[0] # get batch size
408
+
409
+ # GATHER AND FORMAT BATCH DATA
410
+ data, labels = test_batch[0], test_batch[1]
411
+ data, labels = self.format_data_shape(data, labels)
412
+ data, labels = self.send_data_to_device(data, labels)
413
+
414
+ # Weird dataloader bug can leave the final batch with size 0, so skip it...
415
+ if labels.shape[0] == 0:
416
+ continue
417
+
418
+ # GET MODEL PREDICTIONS
419
+ au_out, bvp_out, resp_out = self.model(data)
420
+ au_out = torch.sigmoid(au_out)
421
+
422
+ # GATHER AND SLICE LABELS USED FOR TEST DATASET
423
+ TEST_AU = False
424
+ if len(self.label_idx_test_au) > 0: # if test dataset has AU
425
+ TEST_AU = True
426
+ labels_au = labels[:, self.label_idx_test_au]
427
+ else: # if not set whole AU labels array to -1
428
+ labels_au = np.ones((batch_size, len(self.label_idx_train_au)))
429
+ labels_au = -1 * labels_au
430
+ # labels_au = torch.from_numpy(labels_au)
431
+
432
+ TEST_BVP = False
433
+ if len(self.label_idx_test_bvp) > 0: # if test dataset has BVP
434
+ TEST_BVP = True
435
+ labels_bvp = labels[:, self.label_idx_test_bvp]
436
+ else: # if not set whole BVP labels array to -1
437
+ labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp)))
438
+ labels_bvp = -1 * labels_bvp
439
+ # labels_bvp = torch.from_numpy(labels_bvp)
440
+
441
+ TEST_RESP = False
442
+ if len(self.label_idx_test_resp) > 0: # if test dataset has RESP
443
+ TEST_RESP = True
444
+ labels_resp = labels[:, self.label_idx_test_resp]
445
+ else: # if not set whole RESP labels array to -1
446
+ labels_resp = np.ones((batch_size, len(self.label_idx_train_resp)))
447
+ labels_resp = -1 * labels_resp
448
+ # labels_resp = torch.from_numpy(labels_resp)
449
+
450
+ # ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY
451
+ for idx in range(batch_size):
452
+
453
+ # if the labels are cut off due to TSM data formatting
454
+ if idx * self.chunk_len >= labels.shape[0] and self.using_TSM:
455
+ continue
456
+
457
+ subj_index = test_batch[2][idx]
458
+ sort_index = int(test_batch[3][idx])
459
+
460
+ # add subject to prediction / label arrays
461
+ if subj_index not in preds_dict_bvp.keys():
462
+ preds_dict_au[subj_index] = dict()
463
+ labels_dict_au[subj_index] = dict()
464
+ preds_dict_bvp[subj_index] = dict()
465
+ labels_dict_bvp[subj_index] = dict()
466
+ preds_dict_resp[subj_index] = dict()
467
+ labels_dict_resp[subj_index] = dict()
468
+
469
+ # append predictions and labels to subject dict
470
+ preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
471
+ labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len]
472
+ preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
473
+ labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
474
+ preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
475
+ labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
476
+
477
+ # Calculate Eval Metrics
478
+ bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config)
479
+ resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config)
480
+ au_metric_dict = calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config)
481
+
482
+
483
+
484
+
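The test() routine above stores predictions and labels in nested dictionaries keyed first by subject and then by chunk sort index, and the metric functions are expected to stitch those chunks back together in order. A minimal sketch of that reassembly step, using a hypothetical helper name and toy NumPy data rather than the repository's evaluation code:

```python
import numpy as np

def flatten_chunk_dict(chunk_dict):
    """Concatenate each subject's chunks in sort-index order (illustrative helper)."""
    flattened = {}
    for subj, chunks in chunk_dict.items():
        ordered = [np.asarray(chunks[k]) for k in sorted(chunks.keys())]
        flattened[subj] = np.concatenate(ordered, axis=0)
    return flattened

# Toy data shaped like the trainer's per-chunk output (chunk_len = 4):
toy_preds = {"subj1": {1: np.ones((4, 1)), 0: np.zeros((4, 1))}}
print(flatten_chunk_dict(toy_preds)["subj1"].shape)  # (8, 1)
```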
neural_methods/trainer/BigSmallTrainer.py.backup ADDED
@@ -0,0 +1,484 @@
1
+ """Trainer for BigSmall Multitask Models"""
2
+
3
+ # Training / Eval Imports
4
+ import torch
5
+ import torch.optim as optim
6
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
7
+ from neural_methods import loss
8
+ from neural_methods.model.BigSmall import BigSmall
9
+ from evaluation.bigsmall_multitask_metrics import (calculate_bvp_metrics,
10
+ calculate_resp_metrics,
11
+ calculate_bp4d_au_metrics)
12
+
13
+ # Other Imports
14
+ from collections import OrderedDict
15
+ import numpy as np
16
+ import os
17
+ from tqdm import tqdm
18
+
19
+ class BigSmallTrainer(BaseTrainer):
20
+
21
+ def define_model(self, config):
22
+
23
+ # BigSmall Model
24
+ model = BigSmall(n_segment=3)
25
+
26
+ if self.using_TSM:
27
+ self.frame_depth = config.MODEL.BIGSMALL.FRAME_DEPTH
28
+ self.base_len = self.num_of_gpu * self.frame_depth
29
+
30
+ return model
31
+
32
+ def format_data_shape(self, data, labels):
33
+ # reshape big data
34
+ data_big = data[0]
35
+ N, D, C, H, W = data_big.shape
36
+ data_big = data_big.view(N * D, C, H, W)
37
+
38
+ # reshape small data
39
+ data_small = data[1]
40
+ N, D, C, H, W = data_small.shape
41
+ data_small = data_small.view(N * D, C, H, W)
42
+
43
+ # reshape labels
44
+ if len(labels.shape) != 3: # this training format requires labels that are of shape N_label, D_label, C_label
45
+ labels = torch.unsqueeze(labels, dim=-1)
46
+ N_label, D_label, C_label = labels.shape
47
+ labels = labels.view(N_label * D_label, C_label)
48
+
49
+ # If using temporal shift module
50
+ if self.using_TSM:
51
+ data_big = data_big[:(N * D) // self.base_len * self.base_len]
52
+ data_small = data_small[:(N * D) // self.base_len * self.base_len]
53
+ labels = labels[:(N * D) // self.base_len * self.base_len]
54
+
55
+ data[0] = data_big
56
+ data[1] = data_small
57
+ labels = torch.unsqueeze(labels, dim=-1)
58
+
59
+ return data, labels
60
+
61
+
62
+ def send_data_to_device(self, data, labels):
63
+ big_data = data[0].to(self.device)
64
+ small_data = data[1].to(self.device)
65
+ labels = labels.to(self.device)
66
+ data = (big_data, small_data)
67
+ return data, labels
68
+
69
+
70
+ def get_label_idxs(self, label_list, used_labels):
71
+ label_idxs = []
72
+ for l in used_labels:
73
+ idx = label_list.index(l)
74
+ label_idxs.append(idx)
75
+ return label_idxs
76
+
77
+
78
+ def remove_data_parallel(self, old_state_dict):
79
+ new_state_dict = OrderedDict()
80
+
81
+ for k, v in old_state_dict.items():
82
+ name = k[7:] # remove `module.`
83
+ new_state_dict[name] = v
84
+
85
+ return new_state_dict
86
+
87
+
88
+ def save_model(self, index):
89
+ if not os.path.exists(self.model_dir):
90
+ os.makedirs(self.model_dir)
91
+ model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
92
+ torch.save(self.model.state_dict(), model_path)
93
+ print('Saved Model Path: ', model_path)
94
+ print('')
95
+
96
+
97
+ def __init__(self, config, data_loader):
98
+
99
+ print('')
100
+ print('Init BigSmall Multitask Trainer\n\n')
101
+
102
+ self.config = config # save config file
103
+
104
+ # Set up GPU/CPU compute device
105
+ if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0:
106
+ self.device = torch.device(config.DEVICE) # set device to primary GPU
107
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN # set number of used GPUs
108
+ else:
109
+ self.device = "cpu" # if no GPUs set device is CPU
110
+ self.num_of_gpu = 0 # no GPUs used
111
+
112
+ # Defining model
113
+ self.using_TSM = True
114
+ self.model = self.define_model(config) # define the model
115
+
116
+ if torch.cuda.device_count() > 1 and config.NUM_OF_GPU_TRAIN > 1: # distribute model across GPUs
117
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) # data parallel model
118
+
119
+ self.model = self.model.to(self.device) # send model to primary GPU
120
+
121
+ # Training parameters
122
+ self.batch_size = config.TRAIN.BATCH_SIZE
123
+ self.max_epoch_num = config.TRAIN.EPOCHS
124
+ self.LR = config.TRAIN.LR
125
+
126
+ # Set Loss and Optimizer
127
+ AU_weights = torch.as_tensor([9.64, 11.74, 16.77, 1.05, 0.53, 0.56,
128
+ 0.75, 0.69, 8.51, 6.94, 5.03, 25.00]).to(self.device)
129
+
130
+ self.criterionAU = torch.nn.BCEWithLogitsLoss(pos_weight=AU_weights).to(self.device)
131
+ self.criterionBVP = torch.nn.MSELoss().to(self.device)
132
+ self.criterionRESP = torch.nn.MSELoss().to(self.device)
133
+ self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0)
134
+
135
+ # self.scaler = torch.cuda.amp.GradScaler() # Loss scalar
136
+
137
+ # Model info (saved model dir, chunk len, best epoch, etc.)
138
+ self.model_dir = config.MODEL.MODEL_DIR
139
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
140
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
141
+
142
+ # Epoch To Use For Test
143
+ self.used_epoch = 0
144
+
145
+ # Indices corresponding to used labels
146
+ label_list = ['bp_wave', 'HR_bpm', 'systolic_bp', 'diastolic_bp', 'mean_bp',
147
+ 'resp_wave', 'resp_bpm', 'eda',
148
+ 'AU01', 'AU02', 'AU04', 'AU05', 'AU06', 'AU06int', 'AU07', 'AU09', 'AU10', 'AU10int',
149
+ 'AU11', 'AU12', 'AU12int', 'AU13', 'AU14', 'AU14int', 'AU15', 'AU16', 'AU17', 'AU17int',
150
+ 'AU18', 'AU19', 'AU20', 'AU22', 'AU23', 'AU24', 'AU27', 'AU28', 'AU29', 'AU30', 'AU31',
151
+ 'AU32', 'AU33', 'AU34', 'AU35', 'AU36', 'AU37', 'AU38', 'AU39',
152
+ 'pos_bvp','pos_env_norm_bvp']
153
+
154
+ used_labels = ['bp_wave', 'AU01', 'AU02', 'AU04', 'AU06', 'AU07', 'AU10', 'AU12',
155
+ 'AU14', 'AU15', 'AU17', 'AU23', 'AU24',
156
+ 'pos_env_norm_bvp', 'resp_wave']
157
+
158
+ # Get indices for labels from npy array
159
+ au_label_list = [label for label in used_labels if 'AU' in label]
160
+ bvp_label_list_train = [label for label in used_labels if 'bvp' in label]
161
+ bvp_label_list_test = [label for label in used_labels if 'bp_wave' in label]
162
+ resp_label_list = [label for label in used_labels if 'resp' in label]
163
+
164
+ self.label_idx_train_au = self.get_label_idxs(label_list, au_label_list)
165
+ self.label_idx_valid_au = self.get_label_idxs(label_list, au_label_list)
166
+ self.label_idx_test_au = self.get_label_idxs(label_list, au_label_list)
167
+
168
+ self.label_idx_train_bvp = self.get_label_idxs(label_list, bvp_label_list_train)
169
+ self.label_idx_valid_bvp = self.get_label_idxs(label_list, bvp_label_list_train)
170
+ self.label_idx_test_bvp = self.get_label_idxs(label_list, bvp_label_list_test)
171
+
172
+ self.label_idx_train_resp = self.get_label_idxs(label_list, resp_label_list)
173
+ self.label_idx_valid_resp = self.get_label_idxs(label_list, resp_label_list)
174
+ self.label_idx_test_resp = self.get_label_idxs(label_list, resp_label_list)
175
+
176
+
177
+ def train(self, data_loader):
178
+ """Model Training"""
179
+
180
+ if data_loader["train"] is None:
181
+ raise ValueError("No data for train")
182
+
183
+ print('Starting Training Routine')
184
+ print('')
185
+
186
+ # Init min validation loss as infinity
187
+ min_valid_loss = np.inf # minimum validation loss
188
+
189
+ # ARRAYS TO SAVE (LOSS ARRAYS)
190
+ train_loss_dict = dict()
191
+ train_au_loss_dict = dict()
192
+ train_bvp_loss_dict = dict()
193
+ train_resp_loss_dict = dict()
194
+
195
+ val_loss_dict = dict()
196
+ val_au_loss_dict = dict()
197
+ val_bvp_loss_dict = dict()
198
+ val_resp_loss_dict = dict()
199
+
200
+ # TODO: Expand tracking and subsequent plotting of these losses for BigSmall
201
+ mean_training_losses = []
202
+ mean_valid_losses = []
203
+ lrs = []
204
+
205
+ # ITERATE THROUGH EPOCHS
206
+ for epoch in range(self.max_epoch_num):
207
+ print(f"====Training Epoch: {epoch}====")
208
+
209
+ # INIT PARAMS FOR TRAINING
210
+ running_loss = 0.0 # tracks avg loss over mini batches of 100
211
+ train_loss = []
212
+ train_au_loss = []
213
+ train_bvp_loss = []
214
+ train_resp_loss = []
215
+ self.model.train() # put model in train mode
216
+
217
+ # MODEL TRAINING
218
+ tbar = tqdm(data_loader["train"], ncols=80)
219
+ for idx, batch in enumerate(tbar):
220
+ tbar.set_description("Train epoch %s" % epoch)
221
+
222
+ # GATHER AND FORMAT BATCH DATA
223
+ data, labels = batch[0], batch[1]
224
+ data, labels = self.format_data_shape(data, labels)
225
+ data, labels = self.send_data_to_device(data, labels)
226
+
227
+ # FORWARD AND BACKPROPAGATE THROUGH MODEL
228
+ self.optimizer.zero_grad()
229
+ au_out, bvp_out, resp_out = self.model(data)
230
+ au_loss = self.criterionAU(au_out, labels[:, self.label_idx_train_au, 0]) # au loss
231
+ bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_train_bvp, 0]) # bvp loss
232
+ resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_train_resp, 0]) # resp loss
233
+ loss = au_loss + bvp_loss + resp_loss # sum losses
234
+ loss.backward()
235
+
236
+ # Append the current learning rate to the list
237
+ lrs.append(self.scheduler.get_last_lr())
238
+
239
+ self.optimizer.step()
240
+ # self.scaler.scale(loss).backward() # Loss scaling
241
+ # self.scaler.step(self.optimizer)
242
+ # self.scaler.update()
243
+
244
+
245
+
246
+
247
+ # UPDATE RUNNING LOSS, TERMINAL OUTPUT, AND SAVED LOSSES
248
+ train_loss.append(loss.item())
249
+ train_au_loss.append(au_loss.item())
250
+ train_bvp_loss.append(bvp_loss.item())
251
+ train_resp_loss.append(resp_loss.item())
252
+
253
+ running_loss += loss.item()
254
+ if idx % 100 == 99: # print every 100 mini-batches
255
+ print(f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
256
+ running_loss = 0.0
257
+
258
+
259
+ tbar.set_postfix({"loss:": loss.item(), "lr:": self.optimizer.param_groups[0]["lr"]})
260
+
261
+ # APPEND EPOCH LOSS LIST TO TRAINING LOSS DICTIONARY
262
+ train_loss_dict[epoch] = train_loss
263
+ train_au_loss_dict[epoch] = train_au_loss
264
+ train_bvp_loss_dict[epoch] = train_bvp_loss
265
+ train_resp_loss_dict[epoch] = train_resp_loss
266
+
267
+ print('')
268
+
269
+ # Append the mean training loss for the epoch
270
+ mean_training_losses.append(np.mean(train_loss))
271
+
272
+ # SAVE MODEL FOR THIS EPOCH
273
+ self.save_model(epoch)
274
+
275
+ # VALIDATION (IF ENABLED)
276
+ if not self.config.TEST.USE_LAST_EPOCH:
277
+
278
+ # Get validation losses
279
+ valid_loss, valid_au_loss, valid_bvp_loss, valid_resp_loss = self.valid(data_loader)
280
+ mean_valid_losses.append(valid_loss)
281
+ val_loss_dict[epoch] = valid_loss
282
+ val_au_loss_dict[epoch] = valid_au_loss
283
+ val_bvp_loss_dict[epoch] = valid_bvp_loss
284
+ val_resp_loss_dict[epoch] = valid_resp_loss
285
+ print('validation loss: ', valid_loss)
286
+
287
+ # Update used model
288
+ if self.model_to_use == 'best_epoch' and (valid_loss < min_valid_loss):
289
+ min_valid_loss = valid_loss
290
+ self.used_epoch = epoch
291
+ print("Update best model! Best epoch: {}".format(self.used_epoch))
292
+ elif self.model_to_use == 'last_epoch':
293
+ self.used_epoch = epoch
294
+
295
+ # VALIDATION (NOT ENABLED)
296
+ else:
297
+ self.used_epoch = epoch
298
+
299
+ print('')
300
+
301
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
302
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
303
+
304
+ # PRINT MODEL TO BE USED FOR TESTING
305
+ print("Used model trained epoch:{}, val_loss:{}".format(self.used_epoch, min_valid_loss))
306
+ print('')
307
+
308
+
309
+
310
+ def valid(self, data_loader):
311
+ """ Model evaluation on the validation dataset."""
312
+
313
+ if data_loader["valid"] is None:
314
+ raise ValueError("No data for valid")
315
+
316
+ print("===Validating===")
317
+
318
+ # INIT PARAMS FOR VALIDATION
319
+ valid_loss = []
320
+ valid_au_loss = []
321
+ valid_bvp_loss = []
322
+ valid_resp_loss = []
323
+ self.model.eval()
324
+
325
+ # MODEL VALIDATION
326
+ with torch.no_grad():
327
+ vbar = tqdm(data_loader["valid"], ncols=80)
328
+ for valid_idx, valid_batch in enumerate(vbar):
329
+ vbar.set_description("Validation")
330
+
331
+ # GATHER AND FORMAT BATCH DATA
332
+ data, labels = valid_batch[0], valid_batch[1]
333
+ data, labels = self.format_data_shape(data, labels)
334
+ data, labels = self.send_data_to_device(data, labels)
335
+
336
+ au_out, bvp_out, resp_out = self.model(data)
337
+ au_loss = self.criterionAU(au_out, labels[:, self.label_idx_valid_au, 0]) # au loss
338
+ bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_valid_bvp, 0]) # bvp loss
339
+ resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_valid_resp, 0]) # resp loss
340
+ loss = au_loss + bvp_loss + resp_loss # sum losses
341
+
342
+ # APPEND VAL LOSS
343
+ valid_loss.append(loss.item())
344
+ valid_au_loss.append(au_loss.item())
345
+ valid_bvp_loss.append(bvp_loss.item())
346
+ valid_resp_loss.append(resp_loss.item())
347
+ vbar.set_postfix(loss=loss.item())
348
+
349
+ valid_loss = np.asarray(valid_loss)
350
+ valid_au_loss = np.asarray(valid_au_loss)
351
+ valid_bvp_loss = np.asarray(valid_bvp_loss)
352
+ valid_resp_loss = np.asarray(valid_resp_loss)
353
+ return np.mean(valid_loss), np.mean(valid_au_loss), np.mean(valid_bvp_loss), np.mean(valid_resp_loss)
354
+
355
+
356
+
357
+ def test(self, data_loader):
358
+ """ Model evaluation on the testing dataset."""
359
+
360
+ print("===Testing===")
361
+ print('')
362
+
363
+ # SETUP
364
+ if data_loader["test"] is None:
365
+ raise ValueError("No data for test")
366
+
367
+ # Change chunk length to be test chunk length
368
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
369
+
370
+ # ARRAYS TO SAVE (PREDICTIONS AND METRICS ARRAYS)
371
+ preds_dict_au = dict()
372
+ labels_dict_au = dict()
373
+ preds_dict_bvp = dict()
374
+ labels_dict_bvp = dict()
375
+ preds_dict_resp = dict()
376
+ labels_dict_resp = dict()
377
+
378
+ # IF ONLY_TEST MODE LOAD PRETRAINED MODEL
379
+ if self.config.TOOLBOX_MODE == "only_test":
380
+ model_path = self.config.INFERENCE.MODEL_PATH
381
+ print("Testing uses pretrained model!")
382
+ print('Model path:', model_path)
383
+ if not os.path.exists(model_path):
384
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
385
+
386
+ # IF USING MODEL FROM TRAINING
387
+ else:
388
+ model_path = os.path.join(self.model_dir,
389
+ self.model_file_name + '_Epoch' + str(self.used_epoch) + '.pth')
390
+ print("Testing uses non-pretrained model!")
391
+ print('Model path:', model_path)
392
+ if not os.path.exists(model_path):
393
+ raise ValueError("Something went wrong... cant find trained model...")
394
+ print('')
395
+
396
+ # LOAD THE MODEL SPECIFIED ABOVE FOR TESTING
397
+ self.model.load_state_dict(torch.load(model_path))
398
+ self.model = self.model.to(self.device)
399
+ self.model.eval()
400
+
401
+ # MODEL TESTING
402
+ print("Running model evaluation on the testing dataset!")
403
+ with torch.no_grad():
404
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
405
+
406
+ # PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA
407
+ batch_size = test_batch[1].shape[0] # get batch size
408
+
409
+ # GATHER AND FORMAT BATCH DATA
410
+ data, labels = test_batch[0], test_batch[1]
411
+ data, labels = self.format_data_shape(data, labels)
412
+ data, labels = self.send_data_to_device(data, labels)
413
+
414
+ # Weird dataloader bug can leave the final batch with size 0, so skip it...
415
+ if labels.shape[0] == 0:
416
+ continue
417
+
418
+ # GET MODEL PREDICTIONS
419
+ au_out, bvp_out, resp_out = self.model(data)
420
+ au_out = torch.sigmoid(au_out)
421
+
422
+ # GATHER AND SLICE LABELS USED FOR TEST DATASET
423
+ TEST_AU = False
424
+ if len(self.label_idx_test_au) > 0: # if test dataset has AU
425
+ TEST_AU = True
426
+ labels_au = labels[:, self.label_idx_test_au]
427
+ else: # if not set whole AU labels array to -1
428
+ labels_au = np.ones((batch_size, len(self.label_idx_train_au)))
429
+ labels_au = -1 * labels_au
430
+ # labels_au = torch.from_numpy(labels_au)
431
+
432
+ TEST_BVP = False
433
+ if len(self.label_idx_test_bvp) > 0: # if test dataset has BVP
434
+ TEST_BVP = True
435
+ labels_bvp = labels[:, self.label_idx_test_bvp]
436
+ else: # if not set whole BVP labels array to -1
437
+ labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp)))
438
+ labels_bvp = -1 * labels_bvp
439
+ # labels_bvp = torch.from_numpy(labels_bvp)
440
+
441
+ TEST_RESP = False
442
+ if len(self.label_idx_test_resp) > 0: # if test dataset has RESP
443
+ TEST_RESP = True
444
+ labels_resp = labels[:, self.label_idx_test_resp]
445
+ else: # if not set whole RESP labels array to -1
446
+ labels_resp = np.ones((batch_size, len(self.label_idx_train_resp)))
447
+ labels_resp = -1 * labels_resp
448
+ # labels_resp = torch.from_numpy(labels_resp)
449
+
450
+ # ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY
451
+ for idx in range(batch_size):
452
+
453
+ # if the labels are cut off due to TSM data formatting
454
+ if idx * self.chunk_len >= labels.shape[0] and self.using_TSM:
455
+ continue
456
+
457
+ subj_index = test_batch[2][idx]
458
+ sort_index = int(test_batch[3][idx])
459
+
460
+ # add subject to prediction / label arrays
461
+ if subj_index not in preds_dict_bvp.keys():
462
+ preds_dict_au[subj_index] = dict()
463
+ labels_dict_au[subj_index] = dict()
464
+ preds_dict_bvp[subj_index] = dict()
465
+ labels_dict_bvp[subj_index] = dict()
466
+ preds_dict_resp[subj_index] = dict()
467
+ labels_dict_resp[subj_index] = dict()
468
+
469
+ # append predictions and labels to subject dict
470
+ preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
471
+ labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len]
472
+ preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
473
+ labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
474
+ preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
475
+ labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
476
+
477
+ # Calculate Eval Metrics
478
+ bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config)
479
+ resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config)
480
+ au_metric_dict = calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config)
481
+
482
+
483
+
484
+
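The __init__ above builds the AU criterion as BCEWithLogitsLoss with per-class pos_weight values, so rarely active AUs contribute more when they are positive, while BVP and respiration use plain MSE. A toy, self-contained example of that weighted multi-label loss pattern (the three-AU values here are illustrative, not the trainer's full twelve-weight vector):

```python
import torch

# Weighted multi-label loss pattern from BigSmallTrainer.__init__ (toy values):
# pos_weight up-weights positive examples of rare AUs; sigmoid is applied only
# at test time, as in BigSmallTrainer.test() above.
logits = torch.tensor([[0.2, -1.3, 2.0]])       # raw model outputs for 3 AUs
targets = torch.tensor([[1.0, 0.0, 1.0]])       # ground-truth AU activations
pos_weight = torch.tensor([9.64, 11.74, 0.75])  # per-AU positive-class weights
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
print(criterion(logits, targets).item())
```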
neural_methods/trainer/DeepPhysTrainer.py ADDED
@@ -0,0 +1,209 @@
1
+ """Trainer for DeepPhys."""
2
+
3
+ import logging
4
+ import os
5
+ from collections import OrderedDict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.optim as optim
10
+ from evaluation.metrics import calculate_metrics
11
+ from neural_methods.loss.NegPearsonLoss import Neg_Pearson
12
+ from neural_methods.model.DeepPhys import DeepPhys
13
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
14
+ from tqdm import tqdm
15
+
16
+
17
+ class DeepPhysTrainer(BaseTrainer):
18
+
19
+ def __init__(self, config, data_loader):
20
+ """Inits parameters from args and the writer for TensorboardX."""
21
+ super().__init__()
22
+ self.device = torch.device(config.DEVICE)
23
+ self.max_epoch_num = config.TRAIN.EPOCHS
24
+ self.model_dir = config.MODEL.MODEL_DIR
25
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
26
+ self.batch_size = config.TRAIN.BATCH_SIZE
27
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
28
+ self.config = config
29
+ self.min_valid_loss = None
30
+ self.best_epoch = 0
31
+
32
+ if config.TOOLBOX_MODE == "train_and_test":
33
+ self.model = DeepPhys(img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(self.device)
34
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
35
+
36
+ self.num_train_batches = len(data_loader["train"])
37
+ self.criterion = torch.nn.MSELoss()
38
+ self.optimizer = optim.AdamW(
39
+ self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
40
+ # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
41
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
42
+ self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
43
+ elif config.TOOLBOX_MODE == "only_test":
44
+ self.model = DeepPhys(img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(self.device)
45
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
46
+ else:
47
+ raise ValueError("DeepPhys trainer initialized in incorrect toolbox mode!")
48
+
49
+ def train(self, data_loader):
50
+ """Training routine for model"""
51
+ if data_loader["train"] is None:
52
+ raise ValueError("No data for train")
53
+
54
+ mean_training_losses = []
55
+ mean_valid_losses = []
56
+ lrs = []
57
+ for epoch in range(self.max_epoch_num):
58
+ print('')
59
+ print(f"====Training Epoch: {epoch}====")
60
+ running_loss = 0.0
61
+ train_loss = []
62
+ self.model.train()
63
+ # Model Training
64
+ tbar = tqdm(data_loader["train"], ncols=80)
65
+ for idx, batch in enumerate(tbar):
66
+ tbar.set_description("Train epoch %s" % epoch)
67
+ data, labels = batch[0].to(
68
+ self.device), batch[1].to(self.device)
69
+ N, D, C, H, W = data.shape
70
+ data = data.view(N * D, C, H, W)
71
+ labels = labels.view(-1, 1)
72
+ self.optimizer.zero_grad()
73
+ pred_ppg = self.model(data)
74
+ loss = self.criterion(pred_ppg, labels)
75
+ loss.backward()
76
+
77
+ # Append the current learning rate to the list
78
+ lrs.append(self.scheduler.get_last_lr())
79
+
80
+ self.optimizer.step()
81
+ self.scheduler.step()
82
+ running_loss += loss.item()
83
+ if idx % 100 == 99: # print every 100 mini-batches
84
+ print(
85
+ f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
86
+ running_loss = 0.0
87
+ train_loss.append(loss.item())
88
+ tbar.set_postfix({"loss": loss.item(), "lr": self.optimizer.param_groups[0]["lr"]})
89
+
90
+ # Append the mean training loss for the epoch
91
+ mean_training_losses.append(np.mean(train_loss))
92
+
93
+ self.save_model(epoch)
94
+ if not self.config.TEST.USE_LAST_EPOCH:
95
+ valid_loss = self.valid(data_loader)
96
+ mean_valid_losses.append(valid_loss)
97
+ print('validation loss: ', valid_loss)
98
+ if self.min_valid_loss is None:
99
+ self.min_valid_loss = valid_loss
100
+ self.best_epoch = epoch
101
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
102
+ elif (valid_loss < self.min_valid_loss):
103
+ self.min_valid_loss = valid_loss
104
+ self.best_epoch = epoch
105
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
106
+ if not self.config.TEST.USE_LAST_EPOCH:
107
+ print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
108
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
109
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
110
+
111
+ def valid(self, data_loader):
112
+ """ Model evaluation on the validation dataset."""
113
+ if data_loader["valid"] is None:
114
+ raise ValueError("No data for valid")
115
+
116
+ print('')
117
+ print("===Validating===")
118
+ valid_loss = []
119
+ self.model.eval()
120
+ valid_step = 0
121
+ with torch.no_grad():
122
+ vbar = tqdm(data_loader["valid"], ncols=80)
123
+ for valid_idx, valid_batch in enumerate(vbar):
124
+ vbar.set_description("Validation")
125
+ data_valid, labels_valid = valid_batch[0].to(
126
+ self.device), valid_batch[1].to(self.device)
127
+ N, D, C, H, W = data_valid.shape
128
+ data_valid = data_valid.view(N * D, C, H, W)
129
+ labels_valid = labels_valid.view(-1, 1)
130
+ pred_ppg_valid = self.model(data_valid)
131
+ loss = self.criterion(pred_ppg_valid, labels_valid)
132
+ valid_loss.append(loss.item())
133
+ valid_step += 1
134
+ vbar.set_postfix(loss=loss.item())
135
+ valid_loss = np.asarray(valid_loss)
136
+ return np.mean(valid_loss)
137
+
138
+ def test(self, data_loader):
139
+ """ Model evaluation on the testing dataset."""
140
+ if data_loader["test"] is None:
141
+ raise ValueError("No data for test")
142
+ config = self.config
143
+
144
+ print('')
145
+ print("===Testing===")
146
+
147
+ # Change chunk length to be test chunk length
148
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
149
+
150
+ predictions = dict()
151
+ labels = dict()
152
+ if self.config.TOOLBOX_MODE == "only_test":
153
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
154
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
155
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")))
156
+ print("Testing uses pretrained model!")
157
+ else:
158
+ if self.config.TEST.USE_LAST_EPOCH:
159
+ last_epoch_model_path = os.path.join(
160
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
161
+ print("Testing uses last epoch as non-pretrained model!")
162
+ print(last_epoch_model_path)
163
+ self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")))
164
+ else:
165
+ best_model_path = os.path.join(
166
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
167
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
168
+ print(best_model_path)
169
+ self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")))
170
+
171
+ self.model = self.model.to(self.config.DEVICE)
172
+ self.model.eval()
173
+ print("Running model evaluation on the testing dataset!")
174
+ with torch.no_grad():
175
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
176
+ batch_size = test_batch[0].shape[0]
177
+ data_test, labels_test = test_batch[0].to(
178
+ self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
179
+ N, D, C, H, W = data_test.shape
180
+ data_test = data_test.view(N * D, C, H, W)
181
+ labels_test = labels_test.view(-1, 1)
182
+ pred_ppg_test = self.model(data_test)
183
+
184
+ if self.config.TEST.OUTPUT_SAVE_DIR:
185
+ labels_test = labels_test.cpu()
186
+ pred_ppg_test = pred_ppg_test.cpu()
187
+
188
+ for idx in range(batch_size):
189
+ subj_index = test_batch[2][idx]
190
+ sort_index = int(test_batch[3][idx])
191
+ if subj_index not in predictions.keys():
192
+ predictions[subj_index] = dict()
193
+ labels[subj_index] = dict()
194
+ predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
195
+ labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
196
+
197
+ print('')
198
+ calculate_metrics(predictions, labels, self.config)
199
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
200
+ self.save_test_outputs(predictions, labels, self.config)
201
+
202
+ def save_model(self, index):
203
+ """Inits parameters from args and the writer for TensorboardX."""
204
+ if not os.path.exists(self.model_dir):
205
+ os.makedirs(self.model_dir)
206
+ model_path = os.path.join(
207
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
208
+ torch.save(self.model.state_dict(), model_path)
209
+
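DeepPhysTrainer.train() above pairs AdamW with a OneCycleLR schedule and steps the scheduler once per mini-batch rather than once per epoch, which is why steps_per_epoch is taken from the training dataloader length. A minimal runnable sketch of that pattern with a toy model and random data (none of these objects are repository code):

```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0)
epochs, steps_per_epoch = 2, 5
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, epochs=epochs, steps_per_epoch=steps_per_epoch)

for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped every mini-batch, mirroring the trainer above
```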
neural_methods/trainer/DeepPhysTrainer.py.backup ADDED
@@ -0,0 +1,209 @@
1
+ """Trainer for DeepPhys."""
2
+
3
+ import logging
4
+ import os
5
+ from collections import OrderedDict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.optim as optim
10
+ from evaluation.metrics import calculate_metrics
11
+ from neural_methods.loss.NegPearsonLoss import Neg_Pearson
12
+ from neural_methods.model.DeepPhys import DeepPhys
13
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
14
+ from tqdm import tqdm
15
+
16
+
17
+ class DeepPhysTrainer(BaseTrainer):
18
+
19
+ def __init__(self, config, data_loader):
20
+ """Inits parameters from args and the writer for TensorboardX."""
21
+ super().__init__()
22
+ self.device = torch.device(config.DEVICE)
23
+ self.max_epoch_num = config.TRAIN.EPOCHS
24
+ self.model_dir = config.MODEL.MODEL_DIR
25
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
26
+ self.batch_size = config.TRAIN.BATCH_SIZE
27
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
28
+ self.config = config
29
+ self.min_valid_loss = None
30
+ self.best_epoch = 0
31
+
32
+ if config.TOOLBOX_MODE == "train_and_test":
33
+ self.model = DeepPhys(img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(self.device)
34
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
35
+
36
+ self.num_train_batches = len(data_loader["train"])
37
+ self.criterion = torch.nn.MSELoss()
38
+ self.optimizer = optim.AdamW(
39
+ self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
40
+ # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
41
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
42
+ self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
43
+ elif config.TOOLBOX_MODE == "only_test":
44
+ self.model = DeepPhys(img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(self.device)
45
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
46
+ else:
47
+ raise ValueError("DeepPhys trainer initialized in incorrect toolbox mode!")
48
+
49
+ def train(self, data_loader):
50
+ """Training routine for model"""
51
+ if data_loader["train"] is None:
52
+ raise ValueError("No data for train")
53
+
54
+ mean_training_losses = []
55
+ mean_valid_losses = []
56
+ lrs = []
57
+ for epoch in range(self.max_epoch_num):
58
+ print('')
59
+ print(f"====Training Epoch: {epoch}====")
60
+ running_loss = 0.0
61
+ train_loss = []
62
+ self.model.train()
63
+ # Model Training
64
+ tbar = tqdm(data_loader["train"], ncols=80)
65
+ for idx, batch in enumerate(tbar):
66
+ tbar.set_description("Train epoch %s" % epoch)
67
+ data, labels = batch[0].to(
68
+ self.device), batch[1].to(self.device)
69
+ N, D, C, H, W = data.shape
70
+ data = data.view(N * D, C, H, W)
71
+ labels = labels.view(-1, 1)
72
+ self.optimizer.zero_grad()
73
+ pred_ppg = self.model(data)
74
+ loss = self.criterion(pred_ppg, labels)
75
+ loss.backward()
76
+
77
+ # Append the current learning rate to the list
78
+ lrs.append(self.scheduler.get_last_lr())
79
+
80
+ self.optimizer.step()
81
+ self.scheduler.step()
82
+ running_loss += loss.item()
83
+ if idx % 100 == 99: # print every 100 mini-batches
84
+ print(
85
+ f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
86
+ running_loss = 0.0
87
+ train_loss.append(loss.item())
88
+ tbar.set_postfix({"loss": loss.item(), "lr": self.optimizer.param_groups[0]["lr"]})
89
+
90
+ # Append the mean training loss for the epoch
91
+ mean_training_losses.append(np.mean(train_loss))
92
+
93
+ self.save_model(epoch)
94
+ if not self.config.TEST.USE_LAST_EPOCH:
95
+ valid_loss = self.valid(data_loader)
96
+ mean_valid_losses.append(valid_loss)
97
+ print('validation loss: ', valid_loss)
98
+ if self.min_valid_loss is None:
99
+ self.min_valid_loss = valid_loss
100
+ self.best_epoch = epoch
101
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
102
+ elif (valid_loss < self.min_valid_loss):
103
+ self.min_valid_loss = valid_loss
104
+ self.best_epoch = epoch
105
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
106
+ if not self.config.TEST.USE_LAST_EPOCH:
107
+ print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
108
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
109
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
110
+
111
+ def valid(self, data_loader):
112
+ """ Model evaluation on the validation dataset."""
113
+ if data_loader["valid"] is None:
114
+ raise ValueError("No data for valid")
115
+
116
+ print('')
117
+ print("===Validating===")
118
+ valid_loss = []
119
+ self.model.eval()
120
+ valid_step = 0
121
+ with torch.no_grad():
122
+ vbar = tqdm(data_loader["valid"], ncols=80)
123
+ for valid_idx, valid_batch in enumerate(vbar):
124
+ vbar.set_description("Validation")
125
+ data_valid, labels_valid = valid_batch[0].to(
126
+ self.device), valid_batch[1].to(self.device)
127
+ N, D, C, H, W = data_valid.shape
128
+ data_valid = data_valid.view(N * D, C, H, W)
129
+ labels_valid = labels_valid.view(-1, 1)
130
+ pred_ppg_valid = self.model(data_valid)
131
+ loss = self.criterion(pred_ppg_valid, labels_valid)
132
+ valid_loss.append(loss.item())
133
+ valid_step += 1
134
+ vbar.set_postfix(loss=loss.item())
135
+ valid_loss = np.asarray(valid_loss)
136
+ return np.mean(valid_loss)
137
+
138
+ def test(self, data_loader):
139
+ """ Model evaluation on the testing dataset."""
140
+ if data_loader["test"] is None:
141
+ raise ValueError("No data for test")
142
+ config = self.config
143
+
144
+ print('')
145
+ print("===Testing===")
146
+
147
+ # Change chunk length to be test chunk length
148
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
149
+
150
+ predictions = dict()
151
+ labels = dict()
152
+ if self.config.TOOLBOX_MODE == "only_test":
153
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
154
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
155
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=self.device))
156
+ print("Testing uses pretrained model!")
157
+ else:
158
+ if self.config.TEST.USE_LAST_EPOCH:
159
+ last_epoch_model_path = os.path.join(
160
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
161
+ print("Testing uses last epoch as non-pretrained model!")
162
+ print(last_epoch_model_path)
163
+ self.model.load_state_dict(torch.load(last_epoch_model_path))
164
+ else:
165
+ best_model_path = os.path.join(
166
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
167
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
168
+ print(best_model_path)
169
+ self.model.load_state_dict(torch.load(best_model_path))
170
+
171
+ self.model = self.model.to(self.config.DEVICE)
172
+ self.model.eval()
173
+ print("Running model evaluation on the testing dataset!")
174
+ with torch.no_grad():
175
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
176
+ batch_size = test_batch[0].shape[0]
177
+ data_test, labels_test = test_batch[0].to(
178
+ self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
179
+ N, D, C, H, W = data_test.shape
180
+ data_test = data_test.view(N * D, C, H, W)
181
+ labels_test = labels_test.view(-1, 1)
182
+ pred_ppg_test = self.model(data_test)
183
+
184
+ if self.config.TEST.OUTPUT_SAVE_DIR:
185
+ labels_test = labels_test.cpu()
186
+ pred_ppg_test = pred_ppg_test.cpu()
187
+
188
+ for idx in range(batch_size):
189
+ subj_index = test_batch[2][idx]
190
+ sort_index = int(test_batch[3][idx])
191
+ if subj_index not in predictions.keys():
192
+ predictions[subj_index] = dict()
193
+ labels[subj_index] = dict()
194
+ predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
195
+ labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
196
+
197
+ print('')
198
+ calculate_metrics(predictions, labels, self.config)
199
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
200
+ self.save_test_outputs(predictions, labels, self.config)
201
+
202
+ def save_model(self, index):
203
+ """Inits parameters from args and the writer for TensorboardX."""
204
+ if not os.path.exists(self.model_dir):
205
+ os.makedirs(self.model_dir)
206
+ model_path = os.path.join(
207
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
208
+ torch.save(self.model.state_dict(), model_path)
209
+
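The visible difference between DeepPhysTrainer.py and this .backup copy is how checkpoints are loaded in test(): the current file passes map_location=torch.device("cpu") so weights saved on a GPU machine can be restored on a CPU-only host before the model is moved to the configured device. A small self-contained demonstration with a toy model and a temporary path (not a repository checkpoint):

```python
import os
import tempfile
import torch

model = torch.nn.Linear(4, 1)
ckpt_path = os.path.join(tempfile.mkdtemp(), "demo_Epoch0.pth")
torch.save(model.state_dict(), ckpt_path)

# map_location="cpu" remaps any CUDA-saved tensors onto the CPU at load time.
state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)  # afterwards, move the model to the target device
```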
neural_methods/trainer/EfficientPhysTrainer.py ADDED
@@ -0,0 +1,228 @@
1
+ """Trainer for EfficientPhys."""
2
+
3
+ import logging
4
+ import os
5
+ from collections import OrderedDict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.optim as optim
10
+ from evaluation.metrics import calculate_metrics
11
+ from neural_methods.loss.NegPearsonLoss import Neg_Pearson
12
+ from neural_methods.model.EfficientPhys import EfficientPhys
13
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
14
+ from tqdm import tqdm
15
+
16
+
17
+ class EfficientPhysTrainer(BaseTrainer):
18
+
19
+ def __init__(self, config, data_loader):
20
+ """Inits parameters from args and the writer for TensorboardX."""
21
+ super().__init__()
22
+ self.device = torch.device(config.DEVICE)
23
+ self.frame_depth = config.MODEL.EFFICIENTPHYS.FRAME_DEPTH
24
+ self.max_epoch_num = config.TRAIN.EPOCHS
25
+ self.model_dir = config.MODEL.MODEL_DIR
26
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
27
+ self.batch_size = config.TRAIN.BATCH_SIZE
28
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN
29
+ self.base_len = self.num_of_gpu * self.frame_depth
30
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
31
+ self.config = config
32
+ self.min_valid_loss = None
33
+ self.best_epoch = 0
34
+
35
+ if config.TOOLBOX_MODE == "train_and_test":
36
+ self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(
37
+ self.device)
38
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
39
+
40
+ self.num_train_batches = len(data_loader["train"])
41
+ self.criterion = torch.nn.MSELoss()
42
+ self.optimizer = optim.AdamW(
43
+ self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
44
+ # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
45
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
46
+ self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
47
+ elif config.TOOLBOX_MODE == "only_test":
48
+ self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(
49
+ self.device)
50
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
51
+ else:
52
+ raise ValueError("EfficientPhys trainer initialized in incorrect toolbox mode!")
53
+
54
+ def train(self, data_loader):
55
+ """Training routine for model"""
56
+ if data_loader["train"] is None:
57
+ raise ValueError("No data for train")
58
+
59
+ mean_training_losses = []
60
+ mean_valid_losses = []
61
+ lrs = []
62
+ for epoch in range(self.max_epoch_num):
63
+ print('')
64
+ print(f"====Training Epoch: {epoch}====")
65
+ running_loss = 0.0
66
+ train_loss = []
67
+ self.model.train()
68
+ # Model Training
69
+ tbar = tqdm(data_loader["train"], ncols=80)
70
+ for idx, batch in enumerate(tbar):
71
+ tbar.set_description("Train epoch %s" % epoch)
72
+ data, labels = batch[0].to(
73
+ self.device), batch[1].to(self.device)
74
+ N, D, C, H, W = data.shape
75
+ data = data.view(N * D, C, H, W)
76
+ labels = labels.view(-1, 1)
77
+ data = data[:(N * D) // self.base_len * self.base_len]
78
+ # Add one more frame for EfficientPhys since it does torch.diff for the input
79
+ last_frame = torch.unsqueeze(data[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
80
+ data = torch.cat((data, last_frame), 0)
81
+ labels = labels[:(N * D) // self.base_len * self.base_len]
82
+ self.optimizer.zero_grad()
83
+ pred_ppg = self.model(data)
84
+ loss = self.criterion(pred_ppg, labels)
85
+ loss.backward()
86
+
87
+ # Append the current learning rate to the list
88
+ lrs.append(self.scheduler.get_last_lr())
89
+
90
+ self.optimizer.step()
91
+ self.scheduler.step()
92
+ running_loss += loss.item()
93
+ if idx % 100 == 99: # print every 100 mini-batches
94
+ print(
95
+ f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
96
+ running_loss = 0.0
97
+ train_loss.append(loss.item())
98
+ tbar.set_postfix(loss=loss.item())
99
+
100
+ # Append the mean training loss for the epoch
101
+ mean_training_losses.append(np.mean(train_loss))
102
+
103
+ self.save_model(epoch)
104
+ if not self.config.TEST.USE_LAST_EPOCH:
105
+ valid_loss = self.valid(data_loader)
106
+ mean_valid_losses.append(valid_loss)
107
+ print('validation loss: ', valid_loss)
108
+ if self.min_valid_loss is None:
109
+ self.min_valid_loss = valid_loss
110
+ self.best_epoch = epoch
111
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
112
+ elif (valid_loss < self.min_valid_loss):
113
+ self.min_valid_loss = valid_loss
114
+ self.best_epoch = epoch
115
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
116
+ if not self.config.TEST.USE_LAST_EPOCH:
117
+ print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
118
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
119
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
120
+
121
+ def valid(self, data_loader):
122
+ """ Model evaluation on the validation dataset."""
123
+ if data_loader["valid"] is None:
124
+ raise ValueError("No data for valid")
125
+
126
+ print('')
127
+ print("===Validating===")
128
+ valid_loss = []
129
+ self.model.eval()
130
+ valid_step = 0
131
+ with torch.no_grad():
132
+ vbar = tqdm(data_loader["valid"], ncols=80)
133
+ for valid_idx, valid_batch in enumerate(vbar):
134
+ vbar.set_description("Validation")
135
+ data_valid, labels_valid = valid_batch[0].to(
136
+ self.device), valid_batch[1].to(self.device)
137
+ N, D, C, H, W = data_valid.shape
138
+ data_valid = data_valid.view(N * D, C, H, W)
139
+ labels_valid = labels_valid.view(-1, 1)
140
+ data_valid = data_valid[:(N * D) // self.base_len * self.base_len]
141
+ # Add one more frame for EfficientPhys since it does torch.diff for the input
142
+ last_frame = torch.unsqueeze(data_valid[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
143
+ data_valid = torch.cat((data_valid, last_frame), 0)
144
+ labels_valid = labels_valid[:(N * D) // self.base_len * self.base_len]
145
+ pred_ppg_valid = self.model(data_valid)
146
+ loss = self.criterion(pred_ppg_valid, labels_valid)
147
+ valid_loss.append(loss.item())
148
+ valid_step += 1
149
+ vbar.set_postfix(loss=loss.item())
150
+ valid_loss = np.asarray(valid_loss)
151
+ return np.mean(valid_loss)
152
+
153
+ def test(self, data_loader):
154
+ """ Model evaluation on the testing dataset."""
155
+ if data_loader["test"] is None:
156
+ raise ValueError("No data for test")
157
+
158
+ print('')
159
+ print("===Testing===")
160
+
161
+ # Change chunk length to be test chunk length
162
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
163
+
164
+ predictions = dict()
165
+ labels = dict()
166
+
167
+ if self.config.TOOLBOX_MODE == "only_test":
168
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
169
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
170
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")))
171
+ print("Testing uses pretrained model!")
172
+ else:
173
+ if self.config.TEST.USE_LAST_EPOCH:
174
+ last_epoch_model_path = os.path.join(
175
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
176
+ print("Testing uses last epoch as non-pretrained model!")
177
+ print(last_epoch_model_path)
178
+ self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")))
179
+ else:
180
+ best_model_path = os.path.join(
181
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
182
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
183
+ print(best_model_path)
184
+ self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")))
185
+
186
+ self.model = self.model.to(self.config.DEVICE)
187
+ self.model.eval()
188
+ print("Running model evaluation on the testing dataset!")
189
+ with torch.no_grad():
190
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
191
+ batch_size = test_batch[0].shape[0]
192
+ data_test, labels_test = test_batch[0].to(
193
+ self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
194
+ N, D, C, H, W = data_test.shape
195
+ data_test = data_test.view(N * D, C, H, W)
196
+ labels_test = labels_test.view(-1, 1)
197
+ data_test = data_test[:(N * D) // self.base_len * self.base_len]
198
+ # Add one more frame for EfficientPhys since it does torch.diff for the input
199
+ last_frame = torch.unsqueeze(data_test[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
200
+ data_test = torch.cat((data_test, last_frame), 0)
201
+ labels_test = labels_test[:(N * D) // self.base_len * self.base_len]
202
+ pred_ppg_test = self.model(data_test)
203
+
204
+ if self.config.TEST.OUTPUT_SAVE_DIR:
205
+ labels_test = labels_test.cpu()
206
+ pred_ppg_test = pred_ppg_test.cpu()
207
+
208
+ for idx in range(batch_size):
209
+ subj_index = test_batch[2][idx]
210
+ sort_index = int(test_batch[3][idx])
211
+ if subj_index not in predictions.keys():
212
+ predictions[subj_index] = dict()
213
+ labels[subj_index] = dict()
214
+ predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
215
+ labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
216
+
217
+ print('')
218
+ calculate_metrics(predictions, labels, self.config)
219
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
220
+ self.save_test_outputs(predictions, labels, self.config)
221
+
222
+ def save_model(self, index):
223
+ if not os.path.exists(self.model_dir):
224
+ os.makedirs(self.model_dir)
225
+ model_path = os.path.join(
226
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
227
+ torch.save(self.model.state_dict(), model_path)
228
+ print('Saved Model Path: ', model_path)
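Before every forward pass, EfficientPhysTrainer above flattens the (N, D) clip dimension into individual frames, truncates to a multiple of base_len = num_of_gpu * frame_depth, and appends one repeated copy of the last frame per GPU, because EfficientPhys differences consecutive frames (torch.diff) inside the model. A minimal sketch of that input preparation with toy shapes (not repository code):

```python
import torch

N, D, C, H, W = 2, 10, 3, 8, 8          # toy clip batch: 2 clips of 10 frames
num_of_gpu, frame_depth = 1, 4
base_len = num_of_gpu * frame_depth

data = torch.randn(N, D, C, H, W).view(N * D, C, H, W)
data = data[:(N * D) // base_len * base_len]                   # 20 -> 20 frames here
last_frame = data[-1].unsqueeze(0).repeat(num_of_gpu, 1, 1, 1)
data = torch.cat((data, last_frame), 0)                        # one extra frame per GPU
print(data.shape)  # torch.Size([21, 3, 8, 8])
```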
neural_methods/trainer/EfficientPhysTrainer.py.backup ADDED
@@ -0,0 +1,228 @@
1
+ """Trainer for EfficientPhys."""
2
+
3
+ import logging
4
+ import os
5
+ from collections import OrderedDict
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.optim as optim
10
+ from evaluation.metrics import calculate_metrics
11
+ from neural_methods.loss.NegPearsonLoss import Neg_Pearson
12
+ from neural_methods.model.EfficientPhys import EfficientPhys
13
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
14
+ from tqdm import tqdm
15
+
16
+
17
+ class EfficientPhysTrainer(BaseTrainer):
18
+
19
+ def __init__(self, config, data_loader):
20
+ """Inits parameters from args and the writer for TensorboardX."""
21
+ super().__init__()
22
+ self.device = torch.device(config.DEVICE)
23
+ self.frame_depth = config.MODEL.EFFICIENTPHYS.FRAME_DEPTH
24
+ self.max_epoch_num = config.TRAIN.EPOCHS
25
+ self.model_dir = config.MODEL.MODEL_DIR
26
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
27
+ self.batch_size = config.TRAIN.BATCH_SIZE
28
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN
29
+ self.base_len = self.num_of_gpu * self.frame_depth
30
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
31
+ self.config = config
32
+ self.min_valid_loss = None
33
+ self.best_epoch = 0
34
+
35
+ if config.TOOLBOX_MODE == "train_and_test":
36
+ self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(
37
+ self.device)
38
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
39
+
40
+ self.num_train_batches = len(data_loader["train"])
41
+ self.criterion = torch.nn.MSELoss()
42
+ self.optimizer = optim.AdamW(
43
+ self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
44
+ # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
45
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
46
+ self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
47
+ elif config.TOOLBOX_MODE == "only_test":
48
+ self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(
49
+ self.device)
50
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
51
+ else:
52
+ raise ValueError("EfficientPhys trainer initialized in incorrect toolbox mode!")
53
+
54
+ def train(self, data_loader):
55
+ """Training routine for model"""
56
+ if data_loader["train"] is None:
57
+ raise ValueError("No data for train")
58
+
59
+ mean_training_losses = []
60
+ mean_valid_losses = []
61
+ lrs = []
62
+ for epoch in range(self.max_epoch_num):
63
+ print('')
64
+ print(f"====Training Epoch: {epoch}====")
65
+ running_loss = 0.0
66
+ train_loss = []
67
+ self.model.train()
68
+ # Model Training
69
+ tbar = tqdm(data_loader["train"], ncols=80)
70
+ for idx, batch in enumerate(tbar):
71
+ tbar.set_description("Train epoch %s" % epoch)
72
+ data, labels = batch[0].to(
73
+ self.device), batch[1].to(self.device)
74
+ N, D, C, H, W = data.shape
75
+ data = data.view(N * D, C, H, W)
76
+ labels = labels.view(-1, 1)
77
+ data = data[:(N * D) // self.base_len * self.base_len]
78
+ # Add one more frame for EfficientPhys since it does torch.diff for the input
79
+ last_frame = torch.unsqueeze(data[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
80
+ data = torch.cat((data, last_frame), 0)
81
+ labels = labels[:(N * D) // self.base_len * self.base_len]
82
+ self.optimizer.zero_grad()
83
+ pred_ppg = self.model(data)
84
+ loss = self.criterion(pred_ppg, labels)
85
+ loss.backward()
86
+
87
+ # Append the current learning rate to the list
88
+ lrs.append(self.scheduler.get_last_lr())
89
+
90
+ self.optimizer.step()
91
+ self.scheduler.step()
92
+ running_loss += loss.item()
93
+ if idx % 100 == 99: # print every 100 mini-batches
94
+ print(
95
+ f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
96
+ running_loss = 0.0
97
+ train_loss.append(loss.item())
98
+ tbar.set_postfix(loss=loss.item())
99
+
100
+ # Append the mean training loss for the epoch
101
+ mean_training_losses.append(np.mean(train_loss))
102
+
103
+ self.save_model(epoch)
104
+ if not self.config.TEST.USE_LAST_EPOCH:
105
+ valid_loss = self.valid(data_loader)
106
+ mean_valid_losses.append(valid_loss)
107
+ print('validation loss: ', valid_loss)
108
+ if self.min_valid_loss is None:
109
+ self.min_valid_loss = valid_loss
110
+ self.best_epoch = epoch
111
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
112
+ elif (valid_loss < self.min_valid_loss):
113
+ self.min_valid_loss = valid_loss
114
+ self.best_epoch = epoch
115
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
116
+ if not self.config.TEST.USE_LAST_EPOCH:
117
+ print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
118
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
119
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
120
+
121
+ def valid(self, data_loader):
122
+ """ Model evaluation on the validation dataset."""
123
+ if data_loader["valid"] is None:
124
+ raise ValueError("No data for valid")
125
+
126
+ print('')
127
+ print("===Validating===")
128
+ valid_loss = []
129
+ self.model.eval()
130
+ valid_step = 0
131
+ with torch.no_grad():
132
+ vbar = tqdm(data_loader["valid"], ncols=80)
133
+ for valid_idx, valid_batch in enumerate(vbar):
134
+ vbar.set_description("Validation")
135
+ data_valid, labels_valid = valid_batch[0].to(
136
+ self.device), valid_batch[1].to(self.device)
137
+ N, D, C, H, W = data_valid.shape
138
+ data_valid = data_valid.view(N * D, C, H, W)
139
+ labels_valid = labels_valid.view(-1, 1)
140
+ data_valid = data_valid[:(N * D) // self.base_len * self.base_len]
141
+ # Add one more frame for EfficientPhys since it does torch.diff for the input
142
+ last_frame = torch.unsqueeze(data_valid[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
143
+ data_valid = torch.cat((data_valid, last_frame), 0)
144
+ labels_valid = labels_valid[:(N * D) // self.base_len * self.base_len]
145
+ pred_ppg_valid = self.model(data_valid)
146
+ loss = self.criterion(pred_ppg_valid, labels_valid)
147
+ valid_loss.append(loss.item())
148
+ valid_step += 1
149
+ vbar.set_postfix(loss=loss.item())
150
+ valid_loss = np.asarray(valid_loss)
151
+ return np.mean(valid_loss)
152
+
153
+ def test(self, data_loader):
154
+ """ Model evaluation on the testing dataset."""
155
+ if data_loader["test"] is None:
156
+ raise ValueError("No data for test")
157
+
158
+ print('')
159
+ print("===Testing===")
160
+
161
+ # Change chunk length to be test chunk length
162
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
163
+
164
+ predictions = dict()
165
+ labels = dict()
166
+
167
+ if self.config.TOOLBOX_MODE == "only_test":
168
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
169
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
170
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH))
171
+ print("Testing uses pretrained model!")
172
+ else:
173
+ if self.config.TEST.USE_LAST_EPOCH:
174
+ last_epoch_model_path = os.path.join(
175
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
176
+ print("Testing uses last epoch as non-pretrained model!")
177
+ print(last_epoch_model_path)
178
+ self.model.load_state_dict(torch.load(last_epoch_model_path))
179
+ else:
180
+ best_model_path = os.path.join(
181
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
182
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
183
+ print(best_model_path)
184
+ self.model.load_state_dict(torch.load(best_model_path))
185
+
186
+ self.model = self.model.to(self.config.DEVICE)
187
+ self.model.eval()
188
+ print("Running model evaluation on the testing dataset!")
189
+ with torch.no_grad():
190
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
191
+ batch_size = test_batch[0].shape[0]
192
+ data_test, labels_test = test_batch[0].to(
193
+ self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
194
+ N, D, C, H, W = data_test.shape
195
+ data_test = data_test.view(N * D, C, H, W)
196
+ labels_test = labels_test.view(-1, 1)
197
+ data_test = data_test[:(N * D) // self.base_len * self.base_len]
198
+ # Add one more frame for EfficientPhys since it does torch.diff for the input
199
+ last_frame = torch.unsqueeze(data_test[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
200
+ data_test = torch.cat((data_test, last_frame), 0)
201
+ labels_test = labels_test[:(N * D) // self.base_len * self.base_len]
202
+ pred_ppg_test = self.model(data_test)
203
+
204
+ if self.config.TEST.OUTPUT_SAVE_DIR:
205
+ labels_test = labels_test.cpu()
206
+ pred_ppg_test = pred_ppg_test.cpu()
207
+
208
+ for idx in range(batch_size):
209
+ subj_index = test_batch[2][idx]
210
+ sort_index = int(test_batch[3][idx])
211
+ if subj_index not in predictions.keys():
212
+ predictions[subj_index] = dict()
213
+ labels[subj_index] = dict()
214
+ predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
215
+ labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
216
+
217
+ print('')
218
+ calculate_metrics(predictions, labels, self.config)
219
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
220
+ self.save_test_outputs(predictions, labels, self.config)
221
+
222
+ def save_model(self, index):
223
+ if not os.path.exists(self.model_dir):
224
+ os.makedirs(self.model_dir)
225
+ model_path = os.path.join(
226
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
227
+ torch.save(self.model.state_dict(), model_path)
228
+ print('Saved Model Path: ', model_path)
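Note: the test loops above key each chunk of predicted PPG by subject id (test_batch[2]) and chunk index (test_batch[3]) so that whole waveforms can be reassembled in order before metrics are computed. A minimal sketch of that regrouping, with group_chunks as a hypothetical helper name:

import torch

def group_chunks(pred, subj_ids, chunk_ids, chunk_len, store):
    """pred: (batch * chunk_len, 1) predictions; store: dict[subject][chunk_index] -> tensor."""
    for i, (subj, chunk) in enumerate(zip(subj_ids, chunk_ids)):
        store.setdefault(subj, {})[int(chunk)] = pred[i * chunk_len:(i + 1) * chunk_len]
    return store

preds = {}
group_chunks(torch.randn(2 * 180, 1), ["subj1", "subj1"], [0, 1], chunk_len=180, store=preds)
print(sorted(preds["subj1"].keys()))  # [0, 1]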
neural_methods/trainer/FactorizePhysTrainer.py ADDED
@@ -0,0 +1,312 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import os
8
+ import numpy as np
9
+ import torch
10
+ import torch.optim as optim
11
+ from evaluation.metrics import calculate_metrics
12
+ from neural_methods.loss.NegPearsonLoss import Neg_Pearson
13
+ from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys
14
+ from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig
15
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
16
+ from tqdm import tqdm
17
+
18
+
19
+ class FactorizePhysTrainer(BaseTrainer):
20
+
21
+ def __init__(self, config, data_loader):
22
+ """Inits parameters from args and the writer for TensorboardX."""
23
+ super().__init__()
24
+ self.max_epoch_num = config.TRAIN.EPOCHS
25
+ self.model_dir = config.MODEL.MODEL_DIR
26
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
27
+ self.batch_size = config.TRAIN.BATCH_SIZE
28
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN
29
+ self.dropout_rate = config.MODEL.DROP_RATE
30
+ self.base_len = self.num_of_gpu
31
+ self.config = config
32
+ self.min_valid_loss = None
33
+ self.best_epoch = 0
34
+
35
+ if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0:
36
+ dev_list = [int(d) for d in config.DEVICE.replace("cuda:", "").split(",")]
37
+ self.device = torch.device(dev_list[0]) #currently toolbox only supports 1 GPU
38
+ self.num_of_gpu = 1 #config.NUM_OF_GPU_TRAIN # set number of used GPUs
39
+ else:
40
+ self.device = torch.device("cpu") # if no GPUs set device is CPU
41
+ self.num_of_gpu = 0 # no GPUs used
42
+
43
+ frames = self.config.MODEL.FactorizePhys.FRAME_NUM
44
+ in_channels = self.config.MODEL.FactorizePhys.CHANNELS
45
+ model_type = self.config.MODEL.FactorizePhys.TYPE
46
+ model_type = model_type.lower()
47
+
48
+ md_config = {}
49
+ md_config["FRAME_NUM"] = self.config.MODEL.FactorizePhys.FRAME_NUM
50
+ md_config["MD_TYPE"] = self.config.MODEL.FactorizePhys.MD_TYPE
51
+ md_config["MD_FSAM"] = self.config.MODEL.FactorizePhys.MD_FSAM
52
+ md_config["MD_TRANSFORM"] = self.config.MODEL.FactorizePhys.MD_TRANSFORM
53
+ md_config["MD_S"] = self.config.MODEL.FactorizePhys.MD_S
54
+ md_config["MD_R"] = self.config.MODEL.FactorizePhys.MD_R
55
+ md_config["MD_STEPS"] = self.config.MODEL.FactorizePhys.MD_STEPS
56
+ md_config["MD_INFERENCE"] = self.config.MODEL.FactorizePhys.MD_INFERENCE
57
+ md_config["MD_RESIDUAL"] = self.config.MODEL.FactorizePhys.MD_RESIDUAL
58
+
59
+ self.md_infer = self.config.MODEL.FactorizePhys.MD_INFERENCE
60
+ self.use_fsam = self.config.MODEL.FactorizePhys.MD_FSAM
61
+
62
+ if model_type == "standard":
63
+ self.model = FactorizePhys(frames=frames, md_config=md_config, in_channels=in_channels,
64
+ dropout=self.dropout_rate, device=self.device) # [3, T, 72,72]
65
+ elif model_type == "big":
66
+ self.model = FactorizePhysBig(frames=frames, md_config=md_config, in_channels=in_channels,
67
+ dropout=self.dropout_rate, device=self.device) # [3, T, 144,144]
68
+ else:
69
+ print("Unexpected model type specified. Should be standard or big, but specified:", model_type)
70
+ exit()
71
+
72
+ if torch.cuda.device_count() > 0 and self.num_of_gpu > 0: # distribute model across GPUs
73
+ self.model = torch.nn.DataParallel(self.model, device_ids=[self.device]) # data parallel model
74
+ else:
75
+ self.model = torch.nn.DataParallel(self.model).to(self.device)
76
+
77
+ if self.config.TOOLBOX_MODE == "train_and_test" or self.config.TOOLBOX_MODE == "only_train":
78
+ self.num_train_batches = len(data_loader["train"])
79
+ self.criterion = Neg_Pearson()
80
+ self.optimizer = optim.Adam(
81
+ self.model.parameters(), lr=self.config.TRAIN.LR)
82
+ # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
83
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
84
+ self.optimizer, max_lr=self.config.TRAIN.LR, epochs=self.config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
85
+ elif self.config.TOOLBOX_MODE == "only_test":
86
+ pass
87
+ else:
88
+ raise ValueError("FactorizePhys trainer initialized in incorrect toolbox mode!")
89
+
90
+ def train(self, data_loader):
91
+ """Training routine for model"""
92
+ if data_loader["train"] is None:
93
+ raise ValueError("No data for train")
94
+
95
+ mean_training_losses = []
96
+ mean_valid_losses = []
97
+ mean_appx_error = []
98
+ lrs = []
99
+ for epoch in range(self.max_epoch_num):
100
+ print('')
101
+ print(f"====Training Epoch: {epoch}====")
102
+ running_loss = 0.0
103
+ train_loss = []
104
+ appx_error_list = []
105
+ self.model.train()
106
+ tbar = tqdm(data_loader["train"], ncols=80)
107
+ for idx, batch in enumerate(tbar):
108
+ tbar.set_description("Train epoch %s" % epoch)
109
+
110
+ data = batch[0].to(self.device)
111
+ labels = batch[1].to(self.device)
112
+
113
+ if len(labels.shape) > 2:
114
+ labels = labels[..., 0] # Compatibility with multi-signal labelled data
115
+ labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize
116
+ last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
117
+ data = torch.cat((data, last_frame), 2)
118
+
119
+ # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
120
+ # labels = torch.cat((labels, last_sample), 0)
121
+ # labels = torch.diff(labels, dim=0)
122
+ # labels = labels/ torch.std(labels) # normalize
123
+ # labels[torch.isnan(labels)] = 0
124
+
125
+ self.optimizer.zero_grad()
126
+ if self.model.training and self.use_fsam:
127
+ pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data)
128
+ else:
129
+ pred_ppg, vox_embed = self.model(data)
130
+
131
+ pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize
132
+
133
+ loss = self.criterion(pred_ppg, labels)
134
+
135
+ loss.backward()
136
+ running_loss += loss.item()
137
+ if idx % 100 == 99: # print every 100 mini-batches
138
+ print(
139
+ f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
140
+ running_loss = 0.0
141
+ train_loss.append(loss.item())
142
+ if self.use_fsam:
143
+ appx_error_list.append(appx_error.item())
144
+
145
+ # Append the current learning rate to the list
146
+ lrs.append(self.scheduler.get_last_lr())
147
+
148
+ self.optimizer.step()
149
+ self.scheduler.step()
150
+
151
+ if self.use_fsam:
152
+ tbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item())
153
+ else:
154
+ tbar.set_postfix(loss=loss.item())
155
+
156
+ # Append the mean training loss for the epoch
157
+ mean_training_losses.append(np.mean(train_loss))
158
+ if self.use_fsam:
159
+ mean_appx_error.append(np.mean(appx_error_list))
160
+ print("Mean train loss: {}, Mean appx error: {}".format(
161
+ np.mean(train_loss), np.mean(appx_error_list)))
162
+ else:
163
+ print("Mean train loss: {}".format(np.mean(train_loss)))
164
+
165
+ self.save_model(epoch)
166
+ if not self.config.TEST.USE_LAST_EPOCH:
167
+ valid_loss = self.valid(data_loader)
168
+ mean_valid_losses.append(valid_loss)
169
+ print('validation loss: ', valid_loss)
170
+ if self.min_valid_loss is None:
171
+ self.min_valid_loss = valid_loss
172
+ self.best_epoch = epoch
173
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
174
+ elif (valid_loss < self.min_valid_loss):
175
+ self.min_valid_loss = valid_loss
176
+ self.best_epoch = epoch
177
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
178
+ if not self.config.TEST.USE_LAST_EPOCH:
179
+ print("best trained epoch: {}, min_val_loss: {}".format(
180
+ self.best_epoch, self.min_valid_loss))
181
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
182
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
183
+
184
+ def valid(self, data_loader):
185
+ """ Runs the model on valid sets."""
186
+ if data_loader["valid"] is None:
187
+ raise ValueError("No data for valid")
188
+
189
+ print('')
190
+ print("===Validating===")
191
+ valid_loss = []
192
+ self.model.eval()
193
+ valid_step = 0
194
+ with torch.no_grad():
195
+ vbar = tqdm(data_loader["valid"], ncols=80)
196
+ for valid_idx, valid_batch in enumerate(vbar):
197
+ vbar.set_description("Validation")
198
+
199
+ data, labels = valid_batch[0].to(self.device), valid_batch[1].to(self.device)
200
+ if len(labels.shape) > 2:
201
+ labels = labels[..., 0] # Compatibility with multi-signal labelled data
202
+ labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize
203
+
204
+ last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
205
+ data = torch.cat((data, last_frame), 2)
206
+
207
+ # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
208
+ # labels = torch.cat((labels, last_sample), 0)
209
+ # labels = torch.diff(labels, dim=0)
210
+ # labels = labels/ torch.std(labels) # normalize
211
+ # labels[torch.isnan(labels)] = 0
212
+
213
+ if self.md_infer and self.use_fsam:
214
+ pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data)
215
+ else:
216
+ pred_ppg, vox_embed = self.model(data)
217
+ pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize
218
+ loss = self.criterion(pred_ppg, labels)
219
+
220
+ valid_loss.append(loss.item())
221
+ valid_step += 1
222
+ # vbar.set_postfix(loss=loss.item())
223
+ if self.md_infer and self.use_fsam:
224
+ vbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item())
225
+ else:
226
+ vbar.set_postfix(loss=loss.item())
227
+ valid_loss = np.asarray(valid_loss)
228
+ return np.mean(valid_loss)
229
+
230
+ def test(self, data_loader):
231
+ """ Runs the model on test sets."""
232
+ if data_loader["test"] is None:
233
+ raise ValueError("No data for test")
234
+
235
+ print('')
236
+ print("===Testing===")
237
+ predictions = dict()
238
+ labels = dict()
239
+
240
+ if self.config.TOOLBOX_MODE == "only_test":
241
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
242
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
243
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")), strict=False)
244
+ print("Testing uses pretrained model!")
245
+ print(self.config.INFERENCE.MODEL_PATH)
246
+ else:
247
+ if self.config.TEST.USE_LAST_EPOCH:
248
+ last_epoch_model_path = os.path.join(
249
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
250
+ print("Testing uses last epoch as non-pretrained model!")
251
+ print(last_epoch_model_path)
252
+ self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")), strict=False)
253
+ else:
254
+ best_model_path = os.path.join(
255
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
256
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
257
+ print(best_model_path)
258
+ self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")), strict=False)
259
+
260
+ self.model = self.model.to(self.device)
261
+ self.model.eval()
262
+ print("Running model evaluation on the testing dataset!")
263
+ with torch.no_grad():
264
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
265
+ batch_size = test_batch[0].shape[0]
266
+ data, labels_test = test_batch[0].to(self.device), test_batch[1].to(self.device)
267
+
268
+ if len(labels_test.shape) > 2:
269
+ labels_test = labels_test[..., 0] # Compatibility with multi-signal labelled data
270
+ labels_test = (labels_test - torch.mean(labels_test)) / torch.std(labels_test) # normalize
271
+
272
+ last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
273
+ data = torch.cat((data, last_frame), 2)
274
+
275
+ # last_sample = torch.unsqueeze(labels_test[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
276
+ # labels_test = torch.cat((labels_test, last_sample), 0)
277
+ # labels_test = torch.diff(labels_test, dim=0)
278
+ # labels_test = labels_test/ torch.std(labels_test) # normalize
279
+ # labels_test[torch.isnan(labels_test)] = 0
280
+
281
+ if self.md_infer and self.use_fsam:
282
+ pred_ppg_test, vox_embed, factorized_embed, appx_error = self.model(data)
283
+ else:
284
+ pred_ppg_test, vox_embed = self.model(data)
285
+ pred_ppg_test = (pred_ppg_test - torch.mean(pred_ppg_test)) / torch.std(pred_ppg_test) # normalize
286
+
287
+ if self.config.TEST.OUTPUT_SAVE_DIR:
288
+ labels_test = labels_test.cpu()
289
+ pred_ppg_test = pred_ppg_test.cpu()
290
+
291
+ for idx in range(batch_size):
292
+ subj_index = test_batch[2][idx]
293
+ sort_index = int(test_batch[3][idx])
294
+ if subj_index not in predictions.keys():
295
+ predictions[subj_index] = dict()
296
+ labels[subj_index] = dict()
297
+ predictions[subj_index][sort_index] = pred_ppg_test[idx]
298
+ labels[subj_index][sort_index] = labels_test[idx]
299
+
300
+
301
+ print('')
302
+ calculate_metrics(predictions, labels, self.config)
303
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
304
+ self.save_test_outputs(predictions, labels, self.config)
305
+
306
+ def save_model(self, index):
307
+ if not os.path.exists(self.model_dir):
308
+ os.makedirs(self.model_dir)
309
+ model_path = os.path.join(
310
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
311
+ torch.save(self.model.state_dict(), model_path)
312
+ print('Saved Model Path: ', model_path)
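Note: FactorizePhysTrainer z-score normalizes both the predicted and the ground-truth waveform and trains with the Neg_Pearson criterion, so the loss depends on waveform shape rather than amplitude. A rough sketch of that criterion under those assumptions (the actual implementation lives in neural_methods/loss/NegPearsonLoss.py and may differ in detail):

import torch

def neg_pearson(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """preds, labels: (batch, time) rPPG waveforms; returns the mean of 1 - Pearson r."""
    preds = preds - preds.mean(dim=1, keepdim=True)
    labels = labels - labels.mean(dim=1, keepdim=True)
    corr = (preds * labels).sum(dim=1) / (preds.norm(dim=1) * labels.norm(dim=1) + 1e-8)
    return (1.0 - corr).mean()

pred = torch.randn(4, 160)
print(float(neg_pearson(pred, pred)))  # identical signals -> correlation 1 -> loss close to 0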
neural_methods/trainer/FactorizePhysTrainer.py.backup ADDED
@@ -0,0 +1,312 @@
1
+ """
2
+ FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
3
+ NeurIPS 2024
4
+ Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
5
+ """
6
+
7
+ import os
8
+ import numpy as np
9
+ import torch
10
+ import torch.optim as optim
11
+ from evaluation.metrics import calculate_metrics
12
+ from neural_methods.loss.NegPearsonLoss import Neg_Pearson
13
+ from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys
14
+ from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig
15
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
16
+ from tqdm import tqdm
17
+
18
+
19
+ class FactorizePhysTrainer(BaseTrainer):
20
+
21
+ def __init__(self, config, data_loader):
22
+ """Inits parameters from args and the writer for TensorboardX."""
23
+ super().__init__()
24
+ self.max_epoch_num = config.TRAIN.EPOCHS
25
+ self.model_dir = config.MODEL.MODEL_DIR
26
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
27
+ self.batch_size = config.TRAIN.BATCH_SIZE
28
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN
29
+ self.dropout_rate = config.MODEL.DROP_RATE
30
+ self.base_len = self.num_of_gpu
31
+ self.config = config
32
+ self.min_valid_loss = None
33
+ self.best_epoch = 0
34
+
35
+ if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0:
36
+ dev_list = [int(d) for d in config.DEVICE.replace("cuda:", "").split(",")]
37
+ self.device = torch.device(dev_list[0]) #currently toolbox only supports 1 GPU
38
+ self.num_of_gpu = 1 #config.NUM_OF_GPU_TRAIN # set number of used GPUs
39
+ else:
40
+ self.device = torch.device("cpu") # if no GPUs set device is CPU
41
+ self.num_of_gpu = 0 # no GPUs used
42
+
43
+ frames = self.config.MODEL.FactorizePhys.FRAME_NUM
44
+ in_channels = self.config.MODEL.FactorizePhys.CHANNELS
45
+ model_type = self.config.MODEL.FactorizePhys.TYPE
46
+ model_type = model_type.lower()
47
+
48
+ md_config = {}
49
+ md_config["FRAME_NUM"] = self.config.MODEL.FactorizePhys.FRAME_NUM
50
+ md_config["MD_TYPE"] = self.config.MODEL.FactorizePhys.MD_TYPE
51
+ md_config["MD_FSAM"] = self.config.MODEL.FactorizePhys.MD_FSAM
52
+ md_config["MD_TRANSFORM"] = self.config.MODEL.FactorizePhys.MD_TRANSFORM
53
+ md_config["MD_S"] = self.config.MODEL.FactorizePhys.MD_S
54
+ md_config["MD_R"] = self.config.MODEL.FactorizePhys.MD_R
55
+ md_config["MD_STEPS"] = self.config.MODEL.FactorizePhys.MD_STEPS
56
+ md_config["MD_INFERENCE"] = self.config.MODEL.FactorizePhys.MD_INFERENCE
57
+ md_config["MD_RESIDUAL"] = self.config.MODEL.FactorizePhys.MD_RESIDUAL
58
+
59
+ self.md_infer = self.config.MODEL.FactorizePhys.MD_INFERENCE
60
+ self.use_fsam = self.config.MODEL.FactorizePhys.MD_FSAM
61
+
62
+ if model_type == "standard":
63
+ self.model = FactorizePhys(frames=frames, md_config=md_config, in_channels=in_channels,
64
+ dropout=self.dropout_rate, device=self.device) # [3, T, 72,72]
65
+ elif model_type == "big":
66
+ self.model = FactorizePhysBig(frames=frames, md_config=md_config, in_channels=in_channels,
67
+ dropout=self.dropout_rate, device=self.device) # [3, T, 144,144]
68
+ else:
69
+ print("Unexpected model type specified. Should be standard or big, but specified:", model_type)
70
+ exit()
71
+
72
+ if torch.cuda.device_count() > 0 and self.num_of_gpu > 0: # distribute model across GPUs
73
+ self.model = torch.nn.DataParallel(self.model, device_ids=[self.device]) # data parallel model
74
+ else:
75
+ self.model = torch.nn.DataParallel(self.model).to(self.device)
76
+
77
+ if self.config.TOOLBOX_MODE == "train_and_test" or self.config.TOOLBOX_MODE == "only_train":
78
+ self.num_train_batches = len(data_loader["train"])
79
+ self.criterion = Neg_Pearson()
80
+ self.optimizer = optim.Adam(
81
+ self.model.parameters(), lr=self.config.TRAIN.LR)
82
+ # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
83
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
84
+ self.optimizer, max_lr=self.config.TRAIN.LR, epochs=self.config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
85
+ elif self.config.TOOLBOX_MODE == "only_test":
86
+ pass
87
+ else:
88
+ raise ValueError("FactorizePhys trainer initialized in incorrect toolbox mode!")
89
+
90
+ def train(self, data_loader):
91
+ """Training routine for model"""
92
+ if data_loader["train"] is None:
93
+ raise ValueError("No data for train")
94
+
95
+ mean_training_losses = []
96
+ mean_valid_losses = []
97
+ mean_appx_error = []
98
+ lrs = []
99
+ for epoch in range(self.max_epoch_num):
100
+ print('')
101
+ print(f"====Training Epoch: {epoch}====")
102
+ running_loss = 0.0
103
+ train_loss = []
104
+ appx_error_list = []
105
+ self.model.train()
106
+ tbar = tqdm(data_loader["train"], ncols=80)
107
+ for idx, batch in enumerate(tbar):
108
+ tbar.set_description("Train epoch %s" % epoch)
109
+
110
+ data = batch[0].to(self.device)
111
+ labels = batch[1].to(self.device)
112
+
113
+ if len(labels.shape) > 2:
114
+ labels = labels[..., 0] # Compatibility with multi-signal labelled data
115
+ labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize
116
+ last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
117
+ data = torch.cat((data, last_frame), 2)
118
+
119
+ # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
120
+ # labels = torch.cat((labels, last_sample), 0)
121
+ # labels = torch.diff(labels, dim=0)
122
+ # labels = labels/ torch.std(labels) # normalize
123
+ # labels[torch.isnan(labels)] = 0
124
+
125
+ self.optimizer.zero_grad()
126
+ if self.model.training and self.use_fsam:
127
+ pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data)
128
+ else:
129
+ pred_ppg, vox_embed = self.model(data)
130
+
131
+ pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize
132
+
133
+ loss = self.criterion(pred_ppg, labels)
134
+
135
+ loss.backward()
136
+ running_loss += loss.item()
137
+ if idx % 100 == 99: # print every 100 mini-batches
138
+ print(
139
+ f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
140
+ running_loss = 0.0
141
+ train_loss.append(loss.item())
142
+ if self.use_fsam:
143
+ appx_error_list.append(appx_error.item())
144
+
145
+ # Append the current learning rate to the list
146
+ lrs.append(self.scheduler.get_last_lr())
147
+
148
+ self.optimizer.step()
149
+ self.scheduler.step()
150
+
151
+ if self.use_fsam:
152
+ tbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item())
153
+ else:
154
+ tbar.set_postfix(loss=loss.item())
155
+
156
+ # Append the mean training loss for the epoch
157
+ mean_training_losses.append(np.mean(train_loss))
158
+ if self.use_fsam:
159
+ mean_appx_error.append(np.mean(appx_error_list))
160
+ print("Mean train loss: {}, Mean appx error: {}".format(
161
+ np.mean(train_loss), np.mean(appx_error_list)))
162
+ else:
163
+ print("Mean train loss: {}".format(np.mean(train_loss)))
164
+
165
+ self.save_model(epoch)
166
+ if not self.config.TEST.USE_LAST_EPOCH:
167
+ valid_loss = self.valid(data_loader)
168
+ mean_valid_losses.append(valid_loss)
169
+ print('validation loss: ', valid_loss)
170
+ if self.min_valid_loss is None:
171
+ self.min_valid_loss = valid_loss
172
+ self.best_epoch = epoch
173
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
174
+ elif (valid_loss < self.min_valid_loss):
175
+ self.min_valid_loss = valid_loss
176
+ self.best_epoch = epoch
177
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
178
+ if not self.config.TEST.USE_LAST_EPOCH:
179
+ print("best trained epoch: {}, min_val_loss: {}".format(
180
+ self.best_epoch, self.min_valid_loss))
181
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
182
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
183
+
184
+ def valid(self, data_loader):
185
+ """ Runs the model on valid sets."""
186
+ if data_loader["valid"] is None:
187
+ raise ValueError("No data for valid")
188
+
189
+ print('')
190
+ print("===Validating===")
191
+ valid_loss = []
192
+ self.model.eval()
193
+ valid_step = 0
194
+ with torch.no_grad():
195
+ vbar = tqdm(data_loader["valid"], ncols=80)
196
+ for valid_idx, valid_batch in enumerate(vbar):
197
+ vbar.set_description("Validation")
198
+
199
+ data, labels = valid_batch[0].to(self.device), valid_batch[1].to(self.device)
200
+ if len(labels.shape) > 2:
201
+ labels = labels[..., 0] # Compatibility with multi-signal labelled data
202
+ labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize
203
+
204
+ last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
205
+ data = torch.cat((data, last_frame), 2)
206
+
207
+ # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
208
+ # labels = torch.cat((labels, last_sample), 0)
209
+ # labels = torch.diff(labels, dim=0)
210
+ # labels = labels/ torch.std(labels) # normalize
211
+ # labels[torch.isnan(labels)] = 0
212
+
213
+ if self.md_infer and self.use_fsam:
214
+ pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data)
215
+ else:
216
+ pred_ppg, vox_embed = self.model(data)
217
+ pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize
218
+ loss = self.criterion(pred_ppg, labels)
219
+
220
+ valid_loss.append(loss.item())
221
+ valid_step += 1
222
+ # vbar.set_postfix(loss=loss.item())
223
+ if self.md_infer and self.use_fsam:
224
+ vbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item())
225
+ else:
226
+ vbar.set_postfix(loss=loss.item())
227
+ valid_loss = np.asarray(valid_loss)
228
+ return np.mean(valid_loss)
229
+
230
+ def test(self, data_loader):
231
+ """ Runs the model on test sets."""
232
+ if data_loader["test"] is None:
233
+ raise ValueError("No data for test")
234
+
235
+ print('')
236
+ print("===Testing===")
237
+ predictions = dict()
238
+ labels = dict()
239
+
240
+ if self.config.TOOLBOX_MODE == "only_test":
241
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
242
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
243
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=self.device), strict=False)
244
+ print("Testing uses pretrained model!")
245
+ print(self.config.INFERENCE.MODEL_PATH)
246
+ else:
247
+ if self.config.TEST.USE_LAST_EPOCH:
248
+ last_epoch_model_path = os.path.join(
249
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
250
+ print("Testing uses last epoch as non-pretrained model!")
251
+ print(last_epoch_model_path)
252
+ self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=self.device), strict=False)
253
+ else:
254
+ best_model_path = os.path.join(
255
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
256
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
257
+ print(best_model_path)
258
+ self.model.load_state_dict(torch.load(best_model_path, map_location=self.device), strict=False)
259
+
260
+ self.model = self.model.to(self.device)
261
+ self.model.eval()
262
+ print("Running model evaluation on the testing dataset!")
263
+ with torch.no_grad():
264
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
265
+ batch_size = test_batch[0].shape[0]
266
+ data, labels_test = test_batch[0].to(self.device), test_batch[1].to(self.device)
267
+
268
+ if len(labels_test.shape) > 2:
269
+ labels_test = labels_test[..., 0] # Compatibility wigth multi-signal labelled data
270
+ labels_test = (labels_test - torch.mean(labels_test)) / torch.std(labels_test) # normalize
271
+
272
+ last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
273
+ data = torch.cat((data, last_frame), 2)
274
+
275
+ # last_sample = torch.unsqueeze(labels_test[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
276
+ # labels_test = torch.cat((labels_test, last_sample), 0)
277
+ # labels_test = torch.diff(labels_test, dim=0)
278
+ # labels_test = labels_test/ torch.std(labels_test) # normalize
279
+ # labels_test[torch.isnan(labels_test)] = 0
280
+
281
+ if self.md_infer and self.use_fsam:
282
+ pred_ppg_test, vox_embed, factorized_embed, appx_error = self.model(data)
283
+ else:
284
+ pred_ppg_test, vox_embed = self.model(data)
285
+ pred_ppg_test = (pred_ppg_test - torch.mean(pred_ppg_test)) / torch.std(pred_ppg_test) # normalize
286
+
287
+ if self.config.TEST.OUTPUT_SAVE_DIR:
288
+ labels_test = labels_test.cpu()
289
+ pred_ppg_test = pred_ppg_test.cpu()
290
+
291
+ for idx in range(batch_size):
292
+ subj_index = test_batch[2][idx]
293
+ sort_index = int(test_batch[3][idx])
294
+ if subj_index not in predictions.keys():
295
+ predictions[subj_index] = dict()
296
+ labels[subj_index] = dict()
297
+ predictions[subj_index][sort_index] = pred_ppg_test[idx]
298
+ labels[subj_index][sort_index] = labels_test[idx]
299
+
300
+
301
+ print('')
302
+ calculate_metrics(predictions, labels, self.config)
303
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
304
+ self.save_test_outputs(predictions, labels, self.config)
305
+
306
+ def save_model(self, index):
307
+ if not os.path.exists(self.model_dir):
308
+ os.makedirs(self.model_dir)
309
+ model_path = os.path.join(
310
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
311
+ torch.save(self.model.state_dict(), model_path)
312
+ print('Saved Model Path: ', model_path)
neural_methods/trainer/PhysFormerTrainer.py ADDED
@@ -0,0 +1,273 @@
1
+ """Trainer for PhysFormer.
2
+
3
+ Based on open-source code from the original PhysFormer authors below:
4
+ https://github.com/ZitongYu/PhysFormer/blob/main/train_Physformer_160_VIPL.py
5
+
6
+ We also thank the PhysBench authors for their open-source code based on the code
7
+ of the original authors. Their code below provided a better reference for tuning loss
8
+ parameters of interest and utilizing RSME as a validation loss:
9
+ https://github.com/KegangWangCCNU/PhysBench/blob/main/benchmark_addition/PhysFormer_pure.ipynb
10
+
11
+ """
12
+
13
+ import os
14
+ import numpy as np
15
+ import math
16
+ import torch
17
+ import torch.optim as optim
18
+ from evaluation.metrics import calculate_metrics
19
+ from neural_methods.loss.PhysNetNegPearsonLoss import Neg_Pearson
20
+ from neural_methods.loss.PhysFormerLossComputer import TorchLossComputer
21
+ from neural_methods.model.PhysFormer import ViT_ST_ST_Compact3_TDC_gra_sharp
22
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
23
+ from tqdm import tqdm
24
+ from scipy.signal import welch
25
+
26
+ class PhysFormerTrainer(BaseTrainer):
27
+
28
+ def __init__(self, config, data_loader):
29
+ """Inits parameters from args and the writer for TensorboardX."""
30
+ super().__init__()
31
+ self.device = torch.device(config.DEVICE)
32
+ self.max_epoch_num = config.TRAIN.EPOCHS
33
+ self.model_dir = config.MODEL.MODEL_DIR
34
+ self.dropout_rate = config.MODEL.DROP_RATE
35
+ self.patch_size = config.MODEL.PHYSFORMER.PATCH_SIZE
36
+ self.dim = config.MODEL.PHYSFORMER.DIM
37
+ self.ff_dim = config.MODEL.PHYSFORMER.FF_DIM
38
+ self.num_heads = config.MODEL.PHYSFORMER.NUM_HEADS
39
+ self.num_layers = config.MODEL.PHYSFORMER.NUM_LAYERS
40
+ self.theta = config.MODEL.PHYSFORMER.THETA
41
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
42
+ self.batch_size = config.TRAIN.BATCH_SIZE
43
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN
44
+ self.frame_rate = config.TRAIN.DATA.FS
45
+ self.config = config
46
+ self.min_valid_loss = None
47
+ self.best_epoch = 0
48
+
49
+ if config.TOOLBOX_MODE == "train_and_test":
50
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
51
+ self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
52
+ image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W),
53
+ patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
54
+ dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
55
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
56
+
57
+ self.num_train_batches = len(data_loader["train"])
58
+ self.criterion_reg = torch.nn.MSELoss()
59
+ self.criterion_L1loss = torch.nn.L1Loss()
60
+ self.criterion_class = torch.nn.CrossEntropyLoss()
61
+ self.criterion_Pearson = Neg_Pearson()
62
+ self.optimizer = optim.Adam(self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0.00005)
63
+ # TODO: In both the PhysFormer repo's training example and other implementations of a PhysFormer trainer,
64
+ # a step_size that doesn't end up changing the LR always seems to be used. This seems to defeat the point
65
+ # of using StepLR in the first place. Consider investigating and using another approach (e.g., OneCycleLR).
66
+ self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5)
67
+ elif config.TOOLBOX_MODE == "only_test":
68
+ self.chunk_len = config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
69
+ self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
70
+ image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W),
71
+ patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
72
+ dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
73
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
74
+ else:
75
+ raise ValueError("Physformer trainer initialized in incorrect toolbox mode!")
76
+
77
+
78
+ def train(self, data_loader):
79
+ """Training routine for model"""
80
+ if data_loader["train"] is None:
81
+ raise ValueError("No data for train")
82
+
83
+ # a --> Pearson loss; b --> frequency loss
84
+ a_start = 1.0
85
+ b_start = 1.0
86
+ exp_a = 0.5 # Unused
87
+ exp_b = 1.0
88
+
89
+ # TODO: Expand tracking and subsequent plotting of these losses for PhysFormer
90
+ mean_training_losses = []
91
+ mean_valid_losses = []
92
+ lrs = []
93
+
94
+ for epoch in range(self.max_epoch_num):
95
+ print('')
96
+ print(f"====Training Epoch: {epoch}====")
97
+ loss_rPPG_avg = []
98
+ loss_peak_avg = []
99
+ loss_kl_avg_test = []
100
+ loss_hr_mae = []
101
+
102
+ self.model.train()
103
+ tbar = tqdm(data_loader["train"], ncols=80)
104
+ for idx, batch in enumerate(tbar):
105
+ hr = torch.tensor([self.get_hr(i) for i in batch[1]]).float().to(self.device)
106
+ data, label = batch[0].float().to(self.device), batch[1].float().to(self.device)
107
+
108
+ self.optimizer.zero_grad()
109
+
110
+ gra_sharp = 2.0
111
+ rPPG, _, _, _ = self.model(data, gra_sharp)
112
+ rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG, axis=-1).view(-1, 1) # normalize
113
+ loss_rPPG = self.criterion_Pearson(rPPG, label)
114
+
115
+ fre_loss = 0.0
116
+ kl_loss = 0.0
117
+ train_mae = 0.0
118
+ for bb in range(data.shape[0]):
119
+ loss_distribution_kl, \
120
+ fre_loss_temp, \
121
+ train_mae_temp = TorchLossComputer.cross_entropy_power_spectrum_DLDL_softmax2(
122
+ rPPG[bb],
123
+ hr[bb],
124
+ self.frame_rate,
125
+ std=1.0
126
+ )
127
+ fre_loss = fre_loss+fre_loss_temp
128
+ kl_loss = kl_loss+loss_distribution_kl
129
+ train_mae = train_mae+train_mae_temp
130
+ fre_loss /= data.shape[0]
131
+ kl_loss /= data.shape[0]
132
+ train_mae /= data.shape[0]
133
+
134
+ if epoch>10:
135
+ a = 0.05
136
+ b = 5.0
137
+ else:
138
+ a = a_start
139
+ # exp ascend
140
+ b = b_start*math.pow(exp_b, epoch/10.0)
141
+
142
+ loss = a*loss_rPPG + b*(fre_loss+kl_loss)
143
+ loss.backward()
144
+ self.optimizer.step()
145
+
146
+ n = data.size(0)
147
+ loss_rPPG_avg.append(float(loss_rPPG.data))
148
+ loss_peak_avg.append(float(fre_loss.data))
149
+ loss_kl_avg_test.append(float(kl_loss.data))
150
+ loss_hr_mae.append(float(train_mae))
151
+ if idx % 100 == 99: # print every 100 mini-batches
152
+ print(f'\nepoch:{epoch}, batch:{idx + 1}, total:{len(data_loader["train"]) // self.batch_size}, '
153
+ f'lr:0.0001, sharp:{gra_sharp:.3f}, a:{a:.3f}, NegPearson:{np.mean(loss_rPPG_avg[-2000:]):.4f}, '
154
+ f'\nb:{b:.3f}, kl:{np.mean(loss_kl_avg_test[-2000:]):.3f}, fre_CEloss:{np.mean(loss_peak_avg[-2000:]):.3f}, '
155
+ f'hr_mae:{np.mean(loss_hr_mae[-2000:]):.3f}')
156
+
157
+ # Append the current learning rate to the list
158
+ lrs.append(self.scheduler.get_last_lr())
159
+ # Append the mean training loss for the epoch
160
+ mean_training_losses.append(np.mean(loss_rPPG_avg))
161
+ self.save_model(epoch)
162
+ self.scheduler.step()
163
+ self.model.eval()
164
+
165
+ if not self.config.TEST.USE_LAST_EPOCH:
166
+ valid_loss = self.valid(data_loader)
167
+ mean_valid_losses.append(valid_loss)
168
+ print(f'Validation RMSE:{valid_loss:.3f}, batch:{idx+1}')
169
+ if self.min_valid_loss is None:
170
+ self.min_valid_loss = valid_loss
171
+ self.best_epoch = epoch
172
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
173
+ elif (valid_loss < self.min_valid_loss):
174
+ self.min_valid_loss = valid_loss
175
+ self.best_epoch = epoch
176
+ print("Update best model! Best epoch: {}".format(self.best_epoch))
177
+ if not self.config.TEST.USE_LAST_EPOCH:
178
+ print("best trained epoch: {}, min_val_loss: {}".format(
179
+ self.best_epoch, self.min_valid_loss))
180
+ if self.config.TRAIN.PLOT_LOSSES_AND_LR:
181
+ self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
182
+
183
+ def valid(self, data_loader):
184
+ """ Runs the model on valid sets."""
185
+ if data_loader["valid"] is None:
186
+ raise ValueError("No data for valid")
187
+
188
+ print('')
189
+ print(" ====Validating===")
190
+ self.optimizer.zero_grad()
191
+ with torch.no_grad():
192
+ hrs = []
193
+ vbar = tqdm(data_loader["valid"], ncols=80)
194
+ for val_idx, val_batch in enumerate(vbar):
195
+ data, label = val_batch[0].float().to(self.device), val_batch[1].float().to(self.device)
196
+ gra_sharp = 2.0
197
+ rPPG, _, _, _ = self.model(data, gra_sharp)
198
+ rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG, axis=-1).view(-1, 1)  # per-sample normalization, matching the training loop
199
+ for _1, _2 in zip(rPPG, label):
200
+ hrs.append((self.get_hr(_1.cpu().detach().numpy()), self.get_hr(_2.cpu().detach().numpy())))
201
+ RMSE = np.mean([(i-j)**2 for i, j in hrs])**0.5
202
+ return RMSE
203
+
204
+ def test(self, data_loader):
205
+ """ Runs the model on test sets."""
206
+ if data_loader["test"] is None:
207
+ raise ValueError("No data for test")
208
+
209
+ print('')
210
+ print("===Testing===")
211
+
212
+ # Change chunk length to be test chunk length
213
+ self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
214
+
215
+ predictions = dict()
216
+ labels = dict()
217
+
218
+ if self.config.TOOLBOX_MODE == "only_test":
219
+ if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
220
+ raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
221
+ self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")))
222
+ print("Testing uses pretrained model!")
223
+ print(self.config.INFERENCE.MODEL_PATH)
224
+ else:
225
+ if self.config.TEST.USE_LAST_EPOCH:
226
+ last_epoch_model_path = os.path.join(
227
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
228
+ print("Testing uses last epoch as non-pretrained model!")
229
+ print(last_epoch_model_path)
230
+ self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")))
231
+ else:
232
+ best_model_path = os.path.join(
233
+ self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
234
+ print("Testing uses best epoch selected using model selection as non-pretrained model!")
235
+ print(best_model_path)
236
+ self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")))
237
+
238
+ self.model = self.model.to(self.config.DEVICE)
239
+ self.model.eval()
240
+ print("Running model evaluation on the testing dataset!")
241
+ with torch.no_grad():
242
+ for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
243
+ batch_size = test_batch[0].shape[0]
244
+ data, label = test_batch[0].to(
245
+ self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
246
+ gra_sharp = 2.0
247
+ pred_ppg_test, _, _, _ = self.model(data, gra_sharp)
248
+ for idx in range(batch_size):
249
+ subj_index = test_batch[2][idx]
250
+ sort_index = int(test_batch[3][idx])
251
+ if subj_index not in predictions.keys():
252
+ predictions[subj_index] = dict()
253
+ labels[subj_index] = dict()
254
+ predictions[subj_index][sort_index] = pred_ppg_test[idx]
255
+ labels[subj_index][sort_index] = label[idx]
256
+
257
+ print('')
258
+ calculate_metrics(predictions, labels, self.config)
259
+ if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
260
+ self.save_test_outputs(predictions, labels, self.config)
261
+
262
+ def save_model(self, index):
263
+ if not os.path.exists(self.model_dir):
264
+ os.makedirs(self.model_dir)
265
+ model_path = os.path.join(
266
+ self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
267
+ torch.save(self.model.state_dict(), model_path)
268
+ print('Saved Model Path: ', model_path)
269
+
270
+ # HR calculation based on ground truth label
271
+ def get_hr(self, y, sr=30, min=30, max=180):
272
+ p, q = welch(y, sr, nfft=int(1e5/sr), nperseg=np.min((len(y)-1, 256)))  # nfft must be an integer
273
+ return p[(p>min/60)&(p<max/60)][np.argmax(q[(p>min/60)&(p<max/60)])]*60
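Note: get_hr above estimates heart rate as the dominant Welch power-spectrum peak of the waveform restricted to 30-180 bpm, converted from Hz to beats per minute. A self-contained sketch of the same idea (estimate_hr is a hypothetical name; the zero-padded nfft only sharpens the peak location):

import numpy as np
from scipy.signal import welch

def estimate_hr(ppg: np.ndarray, fs: float = 30.0, lo: float = 30.0, hi: float = 180.0) -> float:
    freqs, power = welch(ppg, fs, nperseg=min(len(ppg) - 1, 256), nfft=4096)
    band = (freqs > lo / 60.0) & (freqs < hi / 60.0)     # keep only plausible heart-rate frequencies
    return float(freqs[band][np.argmax(power[band])] * 60.0)

t = np.arange(0, 10, 1 / 30.0)                           # 10 s of a 1.2 Hz (72 bpm) sinusoid at 30 Hz
print(round(estimate_hr(np.sin(2 * np.pi * 1.2 * t))))   # 72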
neural_methods/trainer/PhysFormerTrainer.py.backup ADDED
@@ -0,0 +1,273 @@
1
+ """Trainer for PhysFormer.
2
+
3
+ Based on open-source code from the original PhysFormer authors below:
4
+ https://github.com/ZitongYu/PhysFormer/blob/main/train_Physformer_160_VIPL.py
5
+
6
+ We also thank the PhysBench authors for their open-source code based on the code
7
+ of the original authors. Their code below provided a better reference for tuning loss
8
+ parameters of interest and utilizing RSME as a validation loss:
9
+ https://github.com/KegangWangCCNU/PhysBench/blob/main/benchmark_addition/PhysFormer_pure.ipynb
10
+
11
+ """
12
+
13
+ import os
14
+ import numpy as np
15
+ import math
16
+ import torch
17
+ import torch.optim as optim
18
+ from evaluation.metrics import calculate_metrics
19
+ from neural_methods.loss.PhysNetNegPearsonLoss import Neg_Pearson
20
+ from neural_methods.loss.PhysFormerLossComputer import TorchLossComputer
21
+ from neural_methods.model.PhysFormer import ViT_ST_ST_Compact3_TDC_gra_sharp
22
+ from neural_methods.trainer.BaseTrainer import BaseTrainer
23
+ from tqdm import tqdm
24
+ from scipy.signal import welch
25
+
26
+ class PhysFormerTrainer(BaseTrainer):
27
+
28
+ def __init__(self, config, data_loader):
29
+ """Inits parameters from args and the writer for TensorboardX."""
30
+ super().__init__()
31
+ self.device = torch.device(config.DEVICE)
32
+ self.max_epoch_num = config.TRAIN.EPOCHS
33
+ self.model_dir = config.MODEL.MODEL_DIR
34
+ self.dropout_rate = config.MODEL.DROP_RATE
35
+ self.patch_size = config.MODEL.PHYSFORMER.PATCH_SIZE
36
+ self.dim = config.MODEL.PHYSFORMER.DIM
37
+ self.ff_dim = config.MODEL.PHYSFORMER.FF_DIM
38
+ self.num_heads = config.MODEL.PHYSFORMER.NUM_HEADS
39
+ self.num_layers = config.MODEL.PHYSFORMER.NUM_LAYERS
40
+ self.theta = config.MODEL.PHYSFORMER.THETA
41
+ self.model_file_name = config.TRAIN.MODEL_FILE_NAME
42
+ self.batch_size = config.TRAIN.BATCH_SIZE
43
+ self.num_of_gpu = config.NUM_OF_GPU_TRAIN
44
+ self.frame_rate = config.TRAIN.DATA.FS
45
+ self.config = config
46
+ self.min_valid_loss = None
47
+ self.best_epoch = 0
48
+
49
+ if config.TOOLBOX_MODE == "train_and_test":
50
+ self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
51
+ self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
52
+ image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W),
53
+ patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
54
+ dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
55
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
56
+
57
+ self.num_train_batches = len(data_loader["train"])
58
+ self.criterion_reg = torch.nn.MSELoss()
59
+ self.criterion_L1loss = torch.nn.L1Loss()
60
+ self.criterion_class = torch.nn.CrossEntropyLoss()
61
+ self.criterion_Pearson = Neg_Pearson()
62
+ self.optimizer = optim.Adam(self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0.00005)
63
+ # TODO: In both the PhysFormer repo's training example and other implementations of a PhysFormer trainer,
64
+ # a step_size that doesn't end up changing the LR always seems to be used. This seems to defeat the point
65
+ # of using StepLR in the first place. Consider investigating and using another approach (e.g., OneCycleLR).
66
+ self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5)
67
+ elif config.TOOLBOX_MODE == "only_test":
68
+ self.chunk_len = config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
69
+ self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
70
+ image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W),
71
+ patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
72
+ dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
73
+ self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
74
+ else:
75
+ raise ValueError("Physformer trainer initialized in incorrect toolbox mode!")
76
+
77
+
+    def train(self, data_loader):
+        """Training routine for model"""
+        if data_loader["train"] is None:
+            raise ValueError("No data for train")
+
+        # a --> Pearson loss; b --> frequency loss
+        a_start = 1.0
+        b_start = 1.0
+        exp_a = 0.5  # Unused
+        exp_b = 1.0
+
+        # TODO: Expand tracking and subsequent plotting of these losses for PhysFormer
+        mean_training_losses = []
+        mean_valid_losses = []
+        lrs = []
+
+        for epoch in range(self.max_epoch_num):
+            print('')
+            print(f"====Training Epoch: {epoch}====")
+            loss_rPPG_avg = []
+            loss_peak_avg = []
+            loss_kl_avg_test = []
+            loss_hr_mae = []
+
+            self.model.train()
+            tbar = tqdm(data_loader["train"], ncols=80)
+            for idx, batch in enumerate(tbar):
+                hr = torch.tensor([self.get_hr(i) for i in batch[1]]).float().to(self.device)
+                data, label = batch[0].float().to(self.device), batch[1].float().to(self.device)
+
+                self.optimizer.zero_grad()
+
+                gra_sharp = 2.0
+                rPPG, _, _, _ = self.model(data, gra_sharp)
+                rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG, axis=-1).view(-1, 1)  # normalize
+                loss_rPPG = self.criterion_Pearson(rPPG, label)
+
+                fre_loss = 0.0
+                kl_loss = 0.0
+                train_mae = 0.0
+                for bb in range(data.shape[0]):
+                    loss_distribution_kl, \
+                    fre_loss_temp, \
+                    train_mae_temp = TorchLossComputer.cross_entropy_power_spectrum_DLDL_softmax2(
+                        rPPG[bb],
+                        hr[bb],
+                        self.frame_rate,
+                        std=1.0
+                    )
+                    fre_loss = fre_loss+fre_loss_temp
+                    kl_loss = kl_loss+loss_distribution_kl
+                    train_mae = train_mae+train_mae_temp
+                fre_loss /= data.shape[0]
+                kl_loss /= data.shape[0]
+                train_mae /= data.shape[0]
+
+                if epoch>10:
+                    a = 0.05
+                    b = 5.0
+                else:
+                    a = a_start
+                    # exp ascend
+                    b = b_start*math.pow(exp_b, epoch/10.0)
+
+                loss = a*loss_rPPG + b*(fre_loss+kl_loss)
+                loss.backward()
+                self.optimizer.step()
+
+                n = data.size(0)
+                loss_rPPG_avg.append(float(loss_rPPG.data))
+                loss_peak_avg.append(float(fre_loss.data))
+                loss_kl_avg_test.append(float(kl_loss.data))
+                loss_hr_mae.append(float(train_mae))
+                if idx % 100 == 99:  # print every 100 mini-batches
+                    print(f'\nepoch:{epoch}, batch:{idx + 1}, total:{len(data_loader["train"]) // self.batch_size}, '
+                          f'lr:0.0001, sharp:{gra_sharp:.3f}, a:{a:.3f}, NegPearson:{np.mean(loss_rPPG_avg[-2000:]):.4f}, '
+                          f'\nb:{b:.3f}, kl:{np.mean(loss_kl_avg_test[-2000:]):.3f}, fre_CEloss:{np.mean(loss_peak_avg[-2000:]):.3f}, '
+                          f'hr_mae:{np.mean(loss_hr_mae[-2000:]):.3f}')
+
+            # Append the current learning rate to the list
+            lrs.append(self.scheduler.get_last_lr())
+            # Append the mean training loss for the epoch
+            mean_training_losses.append(np.mean(loss_rPPG_avg))
+            self.save_model(epoch)
+            self.scheduler.step()
+            self.model.eval()
+
+            if not self.config.TEST.USE_LAST_EPOCH:
+                valid_loss = self.valid(data_loader)
+                mean_valid_losses.append(valid_loss)
+                print(f'Validation RMSE:{valid_loss:.3f}, batch:{idx+1}')
+                if self.min_valid_loss is None:
+                    self.min_valid_loss = valid_loss
+                    self.best_epoch = epoch
+                    print("Update best model! Best epoch: {}".format(self.best_epoch))
+                elif (valid_loss < self.min_valid_loss):
+                    self.min_valid_loss = valid_loss
+                    self.best_epoch = epoch
+                    print("Update best model! Best epoch: {}".format(self.best_epoch))
+        if not self.config.TEST.USE_LAST_EPOCH:
+            print("best trained epoch: {}, min_val_loss: {}".format(
+                self.best_epoch, self.min_valid_loss))
+        if self.config.TRAIN.PLOT_LOSSES_AND_LR:
+            self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
+
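A note on the loss weighting in train(): with the constants defined at the top of the method (a_start = b_start = exp_b = 1.0), the "exp ascend" branch never actually ascends, so loss = a*loss_rPPG + b*(fre_loss + kl_loss) follows a piecewise-constant schedule of (a, b) = (1.0, 1.0) for epochs 0-10 and (0.05, 5.0) afterwards. A small sketch that just prints the effective schedule:

# Sketch of the effective loss-weight schedule implied by the constants in train().
import math

a_start, b_start, exp_b = 1.0, 1.0, 1.0
for epoch in range(15):
    if epoch > 10:
        a, b = 0.05, 5.0
    else:
        a, b = a_start, b_start * math.pow(exp_b, epoch / 10.0)  # stays 1.0 while exp_b == 1.0
    print(epoch, a, b)  # epochs 0-10 -> (1.0, 1.0); epochs 11+ -> (0.05, 5.0)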
+    def valid(self, data_loader):
+        """Runs the model on valid sets."""
+        if data_loader["valid"] is None:
+            raise ValueError("No data for valid")
+
+        print('')
+        print("====Validating====")
+        self.optimizer.zero_grad()
+        with torch.no_grad():
+            hrs = []
+            vbar = tqdm(data_loader["valid"], ncols=80)
+            for val_idx, val_batch in enumerate(vbar):
+                data, label = val_batch[0].float().to(self.device), val_batch[1].float().to(self.device)
+                gra_sharp = 2.0
+                rPPG, _, _, _ = self.model(data, gra_sharp)
+                rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG, axis=-1).view(-1, 1)  # per-sample normalization, as in train()
+                for _1, _2 in zip(rPPG, label):
+                    hrs.append((self.get_hr(_1.cpu().detach().numpy()), self.get_hr(_2.cpu().detach().numpy())))
+            RMSE = np.mean([(i-j)**2 for i, j in hrs])**0.5
+        return RMSE
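Both train() and valid() z-score each predicted rPPG chunk along its time axis before comparing it to the label. A minimal illustrative sketch (not repo code) of the same per-sample normalization, written with keepdim instead of the .view(-1, 1) pattern:

# Minimal sketch: per-chunk z-scoring of a (batch, time) rPPG tensor.
import torch

rppg = torch.randn(4, 160)                    # e.g. 4 chunks of 160 frames each
mu = rppg.mean(dim=-1, keepdim=True)          # (4, 1) per-chunk mean
sd = rppg.std(dim=-1, keepdim=True)           # (4, 1) per-chunk std
rppg_norm = (rppg - mu) / sd                  # equivalent to the .view(-1, 1) form above
assert torch.allclose(rppg_norm.mean(dim=-1), torch.zeros(4), atol=1e-5)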
+
+    def test(self, data_loader):
+        """Runs the model on test sets."""
+        if data_loader["test"] is None:
+            raise ValueError("No data for test")
+
+        print('')
+        print("===Testing===")
+
+        # Change chunk length to be test chunk length
+        self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
+
+        predictions = dict()
+        labels = dict()
+
+        if self.config.TOOLBOX_MODE == "only_test":
+            if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
+                raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
+            self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH))
+            print("Testing uses pretrained model!")
+            print(self.config.INFERENCE.MODEL_PATH)
+        else:
+            if self.config.TEST.USE_LAST_EPOCH:
+                last_epoch_model_path = os.path.join(
+                    self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
+                print("Testing uses last epoch as non-pretrained model!")
+                print(last_epoch_model_path)
+                self.model.load_state_dict(torch.load(last_epoch_model_path))
+            else:
+                best_model_path = os.path.join(
+                    self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
+                print("Testing uses best epoch selected using model selection as non-pretrained model!")
+                print(best_model_path)
+                self.model.load_state_dict(torch.load(best_model_path))
+
+        self.model = self.model.to(self.config.DEVICE)
+        self.model.eval()
+        print("Running model evaluation on the testing dataset!")
+        with torch.no_grad():
+            for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
+                batch_size = test_batch[0].shape[0]
+                data, label = test_batch[0].to(
+                    self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
+                gra_sharp = 2.0
+                pred_ppg_test, _, _, _ = self.model(data, gra_sharp)
+                for idx in range(batch_size):
+                    subj_index = test_batch[2][idx]
+                    sort_index = int(test_batch[3][idx])
+                    if subj_index not in predictions.keys():
+                        predictions[subj_index] = dict()
+                        labels[subj_index] = dict()
+                    predictions[subj_index][sort_index] = pred_ppg_test[idx]
+                    labels[subj_index][sort_index] = label[idx]
+
+        print('')
+        calculate_metrics(predictions, labels, self.config)
+        if self.config.TEST.OUTPUT_SAVE_DIR:  # saving test outputs
+            self.save_test_outputs(predictions, labels, self.config)
+
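test() stores results chunk by chunk in nested dicts keyed by subject and chunk index, and calculate_metrics consumes that layout. Purely as an illustrative sketch (not toolbox code), this is how such a {subject: {chunk_index: tensor}} structure can be flattened back into one continuous signal per subject:

# Illustrative sketch only: reassemble per-chunk predictions into full signals.
import torch

def concat_chunks(per_subject):
    """per_subject: {subject_id: {chunk_index: tensor of rPPG samples}}"""
    full = {}
    for subj, chunks in per_subject.items():
        ordered = [chunks[i] for i in sorted(chunks)]          # restore temporal order
        full[subj] = torch.cat([c.flatten().cpu() for c in ordered])
    return full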
+    def save_model(self, index):
+        if not os.path.exists(self.model_dir):
+            os.makedirs(self.model_dir)
+        model_path = os.path.join(
+            self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
+        torch.save(self.model.state_dict(), model_path)
+        print('Saved Model Path: ', model_path)
+
+    # HR calculation based on ground truth label
+    def get_hr(self, y, sr=30, min=30, max=180):
+        p, q = welch(y, sr, nfft=1e5/sr, nperseg=np.min((len(y)-1, 256)))
+        return p[(p>min/60)&(p<max/60)][np.argmax(q[(p>min/60)&(p<max/60)])]*60
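get_hr() estimates heart rate by taking the Welch power spectrum of the signal and picking the strongest peak inside the 30-180 bpm band, then converting Hz to bpm. A self-contained sanity check of that idea on a synthetic pulse (illustrative only; a 1.2 Hz sine corresponds to 72 bpm):

# Minimal sketch mirroring get_hr(): Welch-PSD peak picking in the 30-180 bpm band.
import numpy as np
from scipy.signal import welch

sr = 30                                    # frames per second
t = np.arange(0, 20, 1 / sr)               # 20 s of signal
y = np.sin(2 * np.pi * 1.2 * t)            # synthetic 1.2 Hz "pulse"

freqs, psd = welch(y, sr, nfft=int(1e5 / sr), nperseg=min(len(y) - 1, 256))
band = (freqs > 30 / 60) & (freqs < 180 / 60)   # restrict to 30-180 bpm
hr_bpm = freqs[band][np.argmax(psd[band])] * 60
print(round(hr_bpm))                       # ~72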