swetchareddytukkani committed · 1c6711c
1 Parent(s): 47fcc3f
Initial commit with PhysMamba rPPG application
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +16 -0
- app.py +2559 -0
- final_model_release/PURE_PhysMamba_DiffNormalized.pth +3 -0
- final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth +3 -0
- mamba_ssm/__init__.py +20 -0
- mamba_ssm/models/__init__.py +0 -0
- mamba_ssm/models/mixer_seq_simple.py +233 -0
- mamba_ssm/modules/__init__.py +0 -0
- mamba_ssm/modules/mamba_simple.py +418 -0
- mamba_ssm/ops/__init__.py +7 -0
- mamba_ssm/ops/selective_scan_interface.py +17 -0
- mamba_ssm/ops/triton/__init__.py +0 -0
- mamba_ssm/ops/triton/layernorm.py +636 -0
- mamba_ssm/ops/triton/selective_state_update.py +192 -0
- mamba_ssm/utils/__init__.py +0 -0
- mamba_ssm/utils/generation.py +377 -0
- mamba_ssm/utils/hf.py +23 -0
- neural_methods/__init__.py +0 -0
- neural_methods/loss/NegPearsonLoss.py +23 -0
- neural_methods/loss/PhysFormerLossComputer.py +120 -0
- neural_methods/loss/PhysNetNegPearsonLoss.py +43 -0
- neural_methods/loss/RythmFormerLossComputer.py +167 -0
- neural_methods/loss/__init__.py +0 -0
- neural_methods/model/BigSmall.py +177 -0
- neural_methods/model/DeepPhys.py +125 -0
- neural_methods/model/EfficientPhys.py +128 -0
- neural_methods/model/FactorizePhys/FSAM.py +530 -0
- neural_methods/model/FactorizePhys/FactorizePhys.py +251 -0
- neural_methods/model/FactorizePhys/FactorizePhysBig.py +251 -0
- neural_methods/model/FactorizePhys/__init__.py +0 -0
- neural_methods/model/FactorizePhys/test_FactorizePhys.py +286 -0
- neural_methods/model/FactorizePhys/test_FactorizePhysBig.py +292 -0
- neural_methods/model/PhysFormer.py +313 -0
- neural_methods/model/PhysMamba.py +246 -0
- neural_methods/model/PhysNet.py +124 -0
- neural_methods/model/RhythmFormer.py +418 -0
- neural_methods/model/TS_CAN.py +269 -0
- neural_methods/model/__init__.py +0 -0
- neural_methods/model/iBVPNet.py +194 -0
- neural_methods/trainer/BaseTrainer.py +108 -0
- neural_methods/trainer/BigSmallTrainer.py +484 -0
- neural_methods/trainer/BigSmallTrainer.py.backup +484 -0
- neural_methods/trainer/DeepPhysTrainer.py +209 -0
- neural_methods/trainer/DeepPhysTrainer.py.backup +209 -0
- neural_methods/trainer/EfficientPhysTrainer.py +228 -0
- neural_methods/trainer/EfficientPhysTrainer.py.backup +228 -0
- neural_methods/trainer/FactorizePhysTrainer.py +312 -0
- neural_methods/trainer/FactorizePhysTrainer.py.backup +312 -0
- neural_methods/trainer/PhysFormerTrainer.py +273 -0
- neural_methods/trainer/PhysFormerTrainer.py.backup +273 -0
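The release ships two pretrained checkpoints under final_model_release/ (PURE and UBFC-rPPG, both DiffNormalized). As a minimal sketch, not part of the commit, one of these files could be inspected with plain torch.load before it is consumed by the loader defined in app.py below; whether the weights sit at the top level or under a "state_dict" key depends on how the checkpoint was saved.

# Illustrative only: peek inside one of the released checkpoints.
import torch

ckpt = torch.load("final_model_release/PURE_PhysMamba_DiffNormalized.pth",
                  map_location="cpu")
# Checkpoints may store weights directly or under a "state_dict" key;
# app.py's load_physmamba_model tries both.
state = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))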
.gitignore
ADDED
@@ -0,0 +1,16 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
*.egg
*.egg-info/
dist/
build/
.pyenv/
.venv/
logs/
analysis/
*.log
.DS_Store
app.py
ADDED
@@ -0,0 +1,2559 @@
import os
os.environ.setdefault("OPENCV_AVFOUNDATION_SKIP_AUTH", "1")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import warnings
warnings.filterwarnings("ignore")

import re
import io
import json
import tempfile
import time
import threading
import uuid
import importlib.util
from pathlib import Path
from collections import deque
from typing import Optional, Tuple, Dict, List

import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F


from scipy import signal
from scipy.signal import find_peaks, welch, get_window

# .mat support (SCAMPS)
try:
    from scipy.io import loadmat
    MAT_SUPPORT = True
except Exception:
    MAT_SUPPORT = False

# Matplotlib headless backend
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import gradio as gr

class PhysMambaattention_viz:
    """Simplified Grad-CAM for PhysMamba."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.activations = None
        self.gradients = None
        self.hook_handles = []

    def _get_target_layer(self):
        """Auto-detect the best layer for visualization, excluding Mamba/SSM layers."""

        # Strategy 1: Look for temporal difference convolution
        for name, module in reversed(list(self.model.named_modules())):
            if ('tdc' in name.lower() or 'temporal_diff' in name.lower()) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
                if isinstance(module, (nn.Conv2d, nn.Conv3d)):
                    print(f"[attention_viz] Selected TDC layer: {name} ({type(module).__name__})")
                    return name, module

        # Strategy 2: Look for 3D convolutions (temporal-spatial)
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, nn.Conv3d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
                print(f"[attention_viz] Selected Conv3d layer: {name}")
                return name, module

        # Strategy 3: Look for 2D convolutions (spatial)
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, nn.Conv2d) and not ('mamba' in name.lower() or 'ssm' in name.lower()):
                print(f"[attention_viz] Selected Conv2d layer: {name}")
                return name, module

        # Strategy 4: Look for any convolution (last resort)
        for name, module in reversed(list(self.model.named_modules())):
            if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                print(f"[attention_viz] Selected Conv layer (fallback): {name} ({type(module).__name__})")
                return name, module

        # Strategy 5: Print all available layers and pick the last non-Mamba one
        print("\n[attention_viz] Available layers:")
        suitable_layers = []
        for name, module in self.model.named_modules():
            layer_type = type(module).__name__
            if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
                is_mamba = 'mamba' in name.lower() or 'ssm' in name.lower()
                print(f" - {name}: {layer_type} {'(MAMBA - SKIP)' if is_mamba else ''}")
                if not is_mamba:
                    suitable_layers.append((name, module))

        if suitable_layers:
            name, module = suitable_layers[-1]
            print(f"\n[attention_viz] Selected last suitable layer: {name} ({type(module).__name__})")
            return name, module

        raise ValueError("No suitable layer found for Grad-CAM (all layers are Mamba/SSM)")

    def _forward_hook(self, module, input, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def register_hooks(self, target_layer_name: Optional[str] = None):
        """Register hooks on target layer"""
        self._remove_hooks()

        if target_layer_name and target_layer_name.strip():
            # Manual selection
            try:
                target_module = dict(self.model.named_modules())[target_layer_name.strip()]
                print(f"Using manually specified layer: {target_layer_name}")
            except KeyError:
                print(f"⚠ Layer '{target_layer_name}' not found, falling back to auto-detection")
                target_layer_name, target_module = self._get_target_layer()
        else:
            # Auto-detection
            target_layer_name, target_module = self._get_target_layer()

        fwd_handle = target_module.register_forward_hook(self._forward_hook)
        bwd_handle = target_module.register_full_backward_hook(self._backward_hook)

        self.hook_handles = [fwd_handle, bwd_handle]

    def _remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []
        self.activations = None
        self.gradients = None

    def generate(self, input_tensor: torch.Tensor) -> np.ndarray:
        """Generate Grad-CAM heatmap with improved gradient handling."""
        # Set model to eval but keep gradient computation
        self.model.eval()

        # Ensure tensor requires grad
        input_tensor = input_tensor.requires_grad_(True)

        # Forward pass
        output = self.model(input_tensor)

        # Handle different output types
        if isinstance(output, dict):
            # Try common keys
            for key in ('pred', 'output', 'bvp', 'logits', 'out'):
                if key in output and isinstance(output[key], torch.Tensor):
                    output = output[key]
                    break

        # Get scalar for backward
        if output.numel() == 1:
            target = output
        else:
            # Use mean of output
            target = output.mean()

        print(f"[attention_viz] Output shape: {output.shape}, Target: {target.item():.4f}")

        # Backward pass
        self.model.zero_grad()
        target.backward(retain_graph=False)

        # Check if we got gradients
        if self.activations is None:
            print("⚠ No activations captured!")
            return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))

        if self.gradients is None:
            print("⚠ No gradients captured!")
            return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))

        print(f"[attention_viz] Activations shape: {self.activations.shape}")
        print(f"[attention_viz] Gradients shape: {self.gradients.shape}")

        activations = self.activations
        gradients = self.gradients

        # Compute weights (global average pooling of gradients)
        if gradients.dim() == 5:  # [B, C, T, H, W]
            weights = gradients.mean(dim=[2, 3, 4], keepdim=True)
        elif gradients.dim() == 4:  # [B, C, H, W]
            weights = gradients.mean(dim=[2, 3], keepdim=True)
        elif gradients.dim() == 3:  # [B, C, T]
            weights = gradients.mean(dim=2, keepdim=True).unsqueeze(-1).unsqueeze(-1)
        else:
            print(f"⚠ Unexpected gradient dimensions: {gradients.dim()}")
            return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1]))

        # Weighted combination
        cam = (weights * activations).sum(dim=1, keepdim=True)

        # If 5D, average over time
        if cam.dim() == 5:
            cam = cam.mean(dim=2)

        # ReLU
        cam = F.relu(cam)

        # Convert to numpy
        cam = cam.squeeze().cpu().detach().numpy()

        # Normalize
        if cam.max() > cam.min():
            cam = (cam - cam.min()) / (cam.max() - cam.min())

        print(f"[attention_viz] Final heatmap shape: {cam.shape}, range: [{cam.min():.3f}, {cam.max():.3f}]")

        return cam

    def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
        """Overlay heatmap on frame."""
        h, w = frame.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (w, h))
        heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
        return overlay

    def cleanup(self):
        self._remove_hooks()


def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray:
    """Apply DiffNormalized preprocessing from PhysMamba paper."""
    if len(frames) < 2:
        return np.zeros((len(frames), *frames[0].shape), dtype=np.float32)

    diff_frames = []

    for i in range(len(frames)):
        if i == 0:
            diff_frames.append(np.zeros_like(frames[0], dtype=np.float32))
        else:
            curr = frames[i].astype(np.float32)
            prev = frames[i - 1].astype(np.float32)
            denominator = curr + prev + 1e-8
            diff = (curr - prev) / denominator
            diff_frames.append(diff)

    diff_array = np.stack(diff_frames)
    std = diff_array.std()
    if std > 0:
        diff_array = diff_array / std

    return diff_array


def preprocess_for_physmamba(frames: List[np.ndarray],
                             target_frames: int = 128,
                             target_size: int = 128) -> torch.Tensor:
    """Complete preprocessing pipeline for PhysMamba model."""
    if len(frames) < target_frames:
        frames = frames + [frames[-1]] * (target_frames - len(frames))
    elif len(frames) > target_frames:
        indices = np.linspace(0, len(frames) - 1, target_frames).astype(int)
        frames = [frames[i] for i in indices]

    frames_rgb = [f[..., ::-1].copy() for f in frames]
    frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb]
    frames_diff = apply_diff_normalized(frames_resized)
    frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2))
    frames_batched = np.expand_dims(frames_transposed, axis=0)
    tensor = torch.from_numpy(frames_batched.astype(np.float32))

    return tensor

HERE = Path(__file__).parent
MODEL_DIR = HERE / "final_model_release"
LOG_DIR = HERE / "logs"
ANALYSIS_DIR = HERE / "analysis"
for d in [MODEL_DIR, LOG_DIR, ANALYSIS_DIR]:
    d.mkdir(exist_ok=True)

DEVICE = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)

FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")

DEFAULT_SIZE = 128    # input H=W to model
DEFAULT_T = 128       # clip length
DEFAULT_STRIDE = 8    # inference hop
DISPLAY_FPS = 10

VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v']

_HR_SMOOTH = None
REST_HR_TARGET = 72.0
REST_HR_RANGE = (55.0, 95.0)
MAX_JUMP_BPM = 8.0

# Recognized GT files for subject folders
GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"}
GT_EXTS = {".txt", ".csv", ".json"}

def _as_path(maybe) -> Optional[str]:
    """Return a filesystem path from Gradio values (str, dict, tempfile objects)."""
    if maybe is None:
        return None
    if isinstance(maybe, str):
        return maybe
    if isinstance(maybe, dict):
        return maybe.get("name") or maybe.get("path")
    name = getattr(maybe, "name", None)  # tempfile-like object
    if isinstance(name, str):
        return name
    try:
        return str(maybe)
    except Exception:
        return None

def _import_from_file(py_path: Path):
    spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path))
    if not spec or not spec.loader:
        raise ImportError(f"Cannot import module from {py_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

def _looks_like_video(p: Path) -> bool:
    if p.suffix.lower() == ".mat":
        return True
    return p.suffix.lower() in VIDEO_EXTENSIONS

class SimpleActivationAttention:
    """Lightweight attention visualization without gradients."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.activations = None
        self.hook_handle = None

    def _activation_hook(self, module, input, output):
        """Capture activations during forward pass."""
        self.activations = output.detach()

    def register_hook(self):
        """Register hook on a suitable layer."""
        # Find the last convolutional layer before Mamba
        target = None
        target_name = None

        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower():
                target = module
                target_name = name

        if target is None:
            print("⚠ [attention_viz] No suitable conv layer found, attention disabled")
            return

        self.hook_handle = target.register_forward_hook(self._activation_hook)
        print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})")

    def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]:
        """Generate attention map from activations (call after forward pass)."""
        try:
            if self.activations is None:
                return None

            # Process activations to create spatial attention
            act = self.activations

            # Handle different tensor shapes
            if act.dim() == 5:  # [B, C, T, H, W]
                # Average over time and channels
                attention = act.mean(dim=[1, 2])  # -> [B, H, W]
            elif act.dim() == 4:  # [B, C, H, W]
                attention = act.mean(dim=1)  # -> [B, H, W]
            else:
                print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}")
                return None

            # Convert to numpy
            attention = attention.squeeze().cpu().numpy()

            # Normalize to [0, 1]
            if attention.max() > attention.min():
                attention = (attention - attention.min()) / (attention.max() - attention.min())

            return attention

        except Exception as e:
            print(f"⚠ [attention_viz] Generation failed: {e}")
            return None

    def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray:
        """Overlay heatmap on frame."""
        h, w = frame.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (w, h))
        heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(frame, 1 - alpha, heatmap_colored, alpha, 0)
        return overlay

    def cleanup(self):
        if self.hook_handle is not None:
            self.hook_handle.remove()

class VideoReader:
    """
    Unified frame reader:
      • Regular videos via cv2.VideoCapture
      • .mat 'videos' (e.g., SCAMPS): expects array like (T,H,W,3) or (H,W,T,3) or (T,H,W)
    Returns frames as BGR uint8.
    """
    def __init__(self, path: str):
        self.path = str(path)
        self._cap = None
        self._mat = None
        self._idx = 0
        self._len = 0
        self._shape = None

        if self.path.lower().endswith(".mat") and MAT_SUPPORT:
            self._open_mat(self.path)
        else:
            self._open_cv(self.path)

    def _open_cv(self, path: str):
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise RuntimeError("Cannot open video")
        self._cap = cap
        self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)

    def _open_mat(self, path: str):
        try:
            md = loadmat(path)
            # Common keys in SCAMPS-like dumps
            for key in ("video", "frames", "vid", "data"):
                if key in md and isinstance(md[key], np.ndarray):
                    arr = md[key]
                    break
            else:
                arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None)
                if arr is None:
                    raise RuntimeError("No ndarray found in .mat")

            a = np.asarray(arr)
            # Normalize to (T,H,W,3)
            if a.ndim == 4:
                if a.shape[-1] == 3:
                    if a.shape[0] < a.shape[2]:  # (T,H,W,3) heuristic
                        v = a
                    else:  # (H,W,T,3) -> (T,H,W,3)
                        v = np.transpose(a, (2, 0, 1, 3))
                else:
                    v = a[..., :1]  # take first channel
            elif a.ndim == 3:
                if a.shape[0] < a.shape[2]:  # (T,H,W)
                    v = a
                else:  # (H,W,T) -> (T,H,W)
                    v = np.transpose(a, (2, 0, 1))
                v = v[..., None]
            else:
                raise RuntimeError(f"Unsupported .mat video shape: {a.shape}")

            v = np.ascontiguousarray(v)
            if v.shape[-1] == 1:
                v = np.repeat(v, 3, axis=-1)
            v = v.astype(np.uint8)
            self._mat = v
            self._len = v.shape[0]
            self._shape = v.shape[1:3]
        except Exception as e:
            raise RuntimeError(f"Failed to open .mat video: {e}")

    def read(self):
        """Return (ret, frame[BGR]) like cv2.VideoCapture.read()."""
        if self._mat is not None:
            if self._idx >= self._len:
                return False, None
            frame = self._mat[self._idx]
            self._idx += 1
            return True, frame
        else:
            return self._cap.read()

    def fps(self, fallback: int = 30) -> int:
        if self._mat is not None:
            return fallback  # .mat typically lacks FPS; caller can override
        f = self._cap.get(cv2.CAP_PROP_FPS)
        return int(f) if f and f > 0 else fallback

    def length(self) -> int:
        return self._len

    def release(self):
        if self._cap is not None:
            self._cap.release()

def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]:
    x, y, w, h = face
    # forehead
    fh = frame[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)]
    # cheeks
    ck = frame[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)]
    # full face
    ff = frame[y:y + h, x:x + w]
    return {"forehead": fh, "cheeks": ck, "face": ff}

def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float:
    if patch is None or patch.size == 0:
        return -1e9
    g = patch[..., 1].astype(np.float32) / 255.0  # green channel
    g = cv2.resize(g, (64, 64)).mean(axis=1)      # crude spatial pooling
    g = g - g.mean()
    b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band")
    try:
        y = signal.filtfilt(b, a, g, method="gust")
    except Exception:
        y = g
    return float((y ** 2).mean())

def pick_auto_roi(face: Tuple[int, int, int, int],
                  frame: np.ndarray,
                  attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]:
    """Simple ROI selection."""
    cands = roi_candidates(face, frame)
    scores = {k: roi_quality_score(v) for k, v in cands.items()}

    if attn is not None and attn.size:
        H, W = frame.shape[:2]
        try:
            attn_resized = cv2.resize(attn, (W, H))
            x, y, w, h = face
            fh_attn = attn_resized[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)].mean() if attn_resized.size > 0 else 0.0
            ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0
            ff_attn = attn_resized[y:y + h, x:x + w].mean() if attn_resized.size > 0 else 0.0
            scores['forehead'] += fh_attn * 0.2
            scores['cheeks'] += ck_attn * 0.2
            scores['face'] += ff_attn * 0.2
        except Exception:
            pass

    best = max(scores, key=scores.get)
    return cands[best], best

def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]:
    """
    Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None).
    Heuristics:
      - Subject folder: any directory containing at least one video-like file (.mat or common video).
      - If multiple videos, pick the largest by size.
      - GT file: prefer known names in GT_FILENAMES, else any .txt/.csv/.json not named readme.
    """
    pairs: List[Tuple[str, Optional[str]]] = []
    if not root_dir.exists():
        return pairs

    def pick_pair(folder: Path) -> Optional[Tuple[str, Optional[str]]]:
        vids = [p for p in folder.rglob("*") if p.is_file() and _looks_like_video(p)]
        if not vids:
            return None
        vids.sort(key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True)
        video = vids[0]

        gt: Optional[Path] = None
        for p in folder.rglob("*"):
            if p.is_file() and p.name.lower() in GT_FILENAMES:
                gt = p
                break
        if gt is None:
            cands = [
                p for p in folder.rglob("*")
                if p.is_file() and p.suffix.lower() in GT_EXTS and "readme" not in p.name.lower()
            ]
            if cands:
                gt = cands[0]
        return str(video), (str(gt) if gt else None)

    subs = [d for d in root_dir.iterdir() if d.is_dir()]
    if subs:
        for sub in subs:
            pair = pick_pair(sub)
            if pair:
                pairs.append(pair)
    else:
        pair = pick_pair(root_dir)  # the root itself might be a single-subject folder
        if pair:
            pairs.append(pair)

    # Deduplicate
    seen = set()
    uniq: List[Tuple[str, Optional[str]]] = []
    for v, g in pairs:
        key = (v, g or "")
        if key not in seen:
            seen.add(key)
            uniq.append((v, g))
    return uniq

def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"):
    import inspect

    if model_file:
        model_path = (repo_root / model_file).resolve()
        if model_path.exists():
            try:
                mod = _import_from_file(model_path)
                if hasattr(mod, model_class):
                    return getattr(mod, model_class)
            except Exception:
                pass

    search_dirs = [
        repo_root / "neural_methods" / "model",
        repo_root / "neural_methods",
        repo_root
    ]

    name_pattern = re.compile(r"mamba", re.IGNORECASE)

    for base_dir in search_dirs:
        if not base_dir.exists():
            continue

        for py_file in base_dir.rglob("*.py"):
            if "__pycache__" in str(py_file) or "mamba_ssm" in str(py_file):
                continue

            try:
                mod = _import_from_file(py_file)
                for name, obj in inspect.getmembers(mod):
                    if callable(obj) and name_pattern.search(name) and "ssm" not in name.lower():
                        if inspect.isclass(obj) and issubclass(obj, nn.Module):
                            return obj
            except Exception:
                continue

    raise ImportError(f"Could not find PhysMamba model class")

def load_physmamba_model(ckpt_path: Path, device: torch.device,
                         model_file: str = "", model_class: str = "PhysMamba"):

    repo_root = Path(".").resolve()
    Builder = find_physmamba_builder(repo_root, model_file, model_class)

    import inspect
    ctor_trials = [
        {},
        {"d_model": 96},
        {"dim": 96},
        {"d_model": 96, "frames": 128, "img_size": 128, "in_chans": 3},
        {"frames": 128, "img_size": 128, "in_chans": 3},
        {"frame_depth": 3},
    ]

    model = None
    for kwargs in ctor_trials:
        try:
            candidate = Builder(**kwargs) if inspect.isclass(Builder) else Builder(**kwargs)
            if isinstance(candidate, nn.Module):
                model = candidate
                break
        except Exception:
            continue

    if model is None:
        raise RuntimeError("Could not construct PhysMamba model")

    try:
        checkpoint = torch.load(str(ckpt_path), map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)

        try:
            model.load_state_dict(state_dict, strict=False)
        except Exception:
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
    except Exception:
        pass

    model.to(device).eval()

    try:
        with torch.no_grad():
            _ = model(torch.zeros(1, 3, 8, 128, 128, device=device))
    except Exception:
        pass

    # Disable attention visualization since model forward pass is incompatible
    attention_viz = None

    return model, attention_viz

def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray:
    """
    Stable band-pass with edge-safety and parameter clipping.
    """
    x = np.asarray(x, dtype=float)
    n = int(fs * 2)
    if x.size < max(n, 8):  # need at least ~2s
        return x

    nyq = 0.5 * fs
    lo = max(low / nyq, 1e-6)
    hi = min(high / nyq, 0.999_999)
    if not (0.0 < lo < hi < 1.0):
        return x

    try:
        b, a = signal.butter(order, [lo, hi], btype="band")
        # padlen must be < len(x); reduce when short
        padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1))
        return signal.filtfilt(b, a, x, padlen=padlen)
    except Exception:
        return x

def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float:
    """
    HR (BPM) via Welch PSD peak in [lo, hi] Hz.
    """
    x = np.asarray(x, dtype=float)
    if x.size < int(fs * 4.0):  # need ~4s for a usable PSD
        return 0.0
    try:
        # nperseg tuned for short windows while avoiding tiny segments
        nper = int(min(max(64, fs * 2), min(512, x.size)))
        f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant")

        mask = (f >= lo) & (f <= hi)
        if not np.any(mask):
            return 0.0
        f_band = f[mask]
        p_band = pxx[mask]

        if p_band.size == 0 or np.all(~np.isfinite(p_band)):
            return 0.0

        fpk = float(f_band[np.argmax(p_band)])
        bpm = fpk * 60.0
        # clip to plausible range
        return float(np.clip(bpm, 30.0, 220.0))
    except Exception:
        return 0.0

def compute_rmssd(x: np.ndarray, fs: int = 30) -> float:
    """
    HRV RMSSD from peaks; robust to short/flat segments.
    """
    x = np.asarray(x, dtype=float)
    if x.size < int(fs * 5.0):
        return 0.0
    try:
        # peak distance ~ 0.5s minimum (avoid double counting)
        peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs)))
        if len(peaks) < 3:
            return 0.0
        rr = np.diff(peaks) / fs * 1000.0  # ms
        if rr.size < 2:
            return 0.0
        return float(np.sqrt(np.mean(np.diff(rr) ** 2)))
    except Exception:
        return 0.0

def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]:
    """
    Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band.
    Signature unchanged to avoid breaking callers.
    """
    global _HR_SMOOTH

    y = np.asarray(pred, dtype=float)
    if y.size == 0:
        return y, 0.0

    # 1) band-limit
    y_filt = bandpass_filter(y, fs=fs, low=0.7, high=3.5, order=4)

    # 2) HR estimate
    hr = hr_from_welch(y_filt, fs=fs, lo=0.7, hi=3.5)

    # 3) gentle attraction to resting band (if way off)
    if hr > 0:
        lo, hi = REST_HR_RANGE
        if hr < lo or hr > hi:
            dist = abs(hr - REST_HR_TARGET)
            # farther away -> stronger pull
            alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65))
            hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET

    # 4) temporal smoothing to limit frame-to-frame jumps
    if hr > 0:
        if _HR_SMOOTH is None:
            _HR_SMOOTH = hr
        else:
            step = float(np.clip(hr - _HR_SMOOTH, -MAX_JUMP_BPM, MAX_JUMP_BPM))
            _HR_SMOOTH = _HR_SMOOTH + 0.6 * step
        hr = float(_HR_SMOOTH)

    return y_filt, float(hr)

def draw_face_and_roi(frame_bgr: np.ndarray,
                      face_bbox: Optional[Tuple[int, int, int, int]],
                      roi_bbox: Optional[Tuple[int, int, int, int]],
                      label: str = "ROI") -> np.ndarray:
    """
    Draw face (green) and ROI (cyan) rectangles on a copy of the frame.
    """
    vis = frame_bgr.copy()
    if face_bbox is not None:
        x, y, w, h = face_bbox
        cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 230, 0), 2)
        cv2.putText(vis, "FACE", (x, max(20, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 230, 0), 2)
    if roi_bbox is not None:
        rx, ry, rw, rh = roi_bbox
        cv2.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (255, 220, 0), 2)
        cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2)
    return vis

def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int],
                       roi_type: str,
                       frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]:
    """
    Compute the ROI rectangle (x,y,w,h) from a face bbox and your ROI rule.
    Matches your crop_roi geometry.
    """
    x, y, w, h = face_bbox
    H, W = frame_shape[:2]
    if roi_type == "forehead":
        rx = int(x + 0.25 * w); rw = int(0.5 * w)
        ry = int(y + 0.10 * h); rh = int(0.20 * h)
    elif roi_type == "cheeks":
        rx = int(x + 0.15 * w); rw = int(0.70 * w)
        ry = int(y + 0.55 * h); rh = int(0.30 * h)
    else:
        rx, ry, rw, rh = x, y, w, h

    rx2 = min(W, rx + rw)
    ry2 = min(H, ry + rh)
    rx = max(0, rx); ry = max(0, ry)
    if rx2 <= rx or ry2 <= ry:
        return (0, 0, 0, 0)
    return (rx, ry, rx2 - rx, ry2 - ry)

def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray:
    """
    Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR.
    Assumes chw = (3, H, W) with zero-mean, unit-var normalization per-frame.
    """
    if chw is None or chw.ndim != 3 or chw.shape[0] != 3:
        return np.zeros((128, 128, 3), dtype=np.uint8)

    # Undo channel-first & normalization to a viewable image
    img = chw.copy()
    # Re-normalize to 0..1 by min-max of the tensor to "show" contrast
    vmin, vmax = float(img.min()), float(img.max())
    if vmax <= vmin + 1e-6:
        img = np.zeros_like(img)
    else:
        img = (img - vmin) / (vmax - vmin)

    img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8)  # RGB->BGR
    return img

def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
    if gt_len <= 1:
        return None
    if gt_fs and gt_fs > 0:
        return np.arange(gt_len, dtype=float) / float(gt_fs)
    return None  # will fall back to length-matching overlay

def plot_signals_with_gt(time_axis: np.ndarray,
                         raw_signal: np.ndarray,
                         post_signal: np.ndarray,
                         fs: int,
                         out_path: str,
                         gt_time: Optional[np.ndarray] = None,
                         gt_bvp: Optional[np.ndarray] = None,
                         title: str = "rPPG Signals (Pred vs GT)") -> str:
    """
    Save a 3-pane plot: (1) predicted raw, (2) predicted post, (3) overlay Pred vs GT (normalized).
    If GT is provided, it is resampled to the prediction time grid and lag-aligned
    (±5 s) using cross-correlation. The overlay includes Pearson r, lag, and HR stats.
    """
    import numpy as _np
    import matplotlib.pyplot as _plt
    from matplotlib.gridspec import GridSpec as _GridSpec

    def z(x):
        x = _np.asarray(x, dtype=float)
        if x.size == 0:
            return x
        m = float(_np.nanmean(x))
        s = float(_np.nanstd(x)) + 1e-8
        return (x - m) / s

    def _bandpass(x, fs_local, lo=0.7, hi=3.5, order=4):
        try:
            return bandpass_filter(_np.asarray(x, float), fs=fs_local, low=lo, high=hi, order=order)
        except Exception:
            return _np.asarray(x, float)

    def _welch_hr(x, fs_local):
        try:
            return float(hr_from_welch(_np.asarray(x, float), fs=fs_local, lo=0.7, hi=3.5))
        except Exception:
            return 0.0

    def _safe_interp(x_t, y, t_new):
        """Monotonic-safe 1D interpolation with clipping to valid domain."""
        x_t = _np.asarray(x_t, dtype=float).ravel()
        y = _np.asarray(y, dtype=float).ravel()
        t_new = _np.asarray(t_new, dtype=float).ravel()

        if x_t.size < 2 or y.size != x_t.size:
            # Fallback: length-based resize to t_new length
            if y.size == 0 or t_new.size == 0:
                return _np.zeros_like(t_new)
            idx = _np.linspace(0, y.size - 1, num=t_new.size)
            return _np.interp(_np.arange(t_new.size), idx, y)

        # Enforce strictly increasing time (dedup if needed)
        order = _np.argsort(x_t)
        x_t = x_t[order]
        y = y[order]
        mask = _np.concatenate(([True], _np.diff(x_t) > 0))
        x_t = x_t[mask]
        y = y[mask]
        # Clip t_new to the valid domain to avoid edge extrapolation artifacts
        t_clip = _np.clip(t_new, x_t[0], x_t[-1])
        return _np.interp(t_clip, x_t, y)

    def _best_lag(pred, gt, fs_local, max_lag_s=5.0):
        """Return lag (sec) that best aligns GT to Pred using cross-correlation on z-scored signals."""
        x = z(pred); y = z(gt)
        if x.size < 8 or y.size < 8:
            return 0.0
        n = int(min(len(x), len(y)))
        x = x[:n]; y = y[:n]
        max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local))))
        # valid lags: negative means GT should be shifted left (advance) relative to Pred
        lags = _np.arange(-max_lag, max_lag + 1)
        # compute correlation for each lag
        best_corr = -_np.inf
        best_lag = 0
        for L in lags:
            if L < 0:
                xx = x[-L:n]
                yy = y[0:n + L]
            elif L > 0:
                xx = x[0:n - L]
                yy = y[L:n]
            else:
                xx = x
                yy = y
            if xx.size < 8 or yy.size < 8:
                continue
            c = _np.corrcoef(xx, yy)[0, 1]
            if _np.isfinite(c) and c > best_corr:
                best_corr = c
                best_lag = L
        return float(best_lag / float(fs_local))

    def _apply_lag(y, lag_sec, fs_local):
        """Shift y by lag_sec (positive => delay GT) using sample roll; edges set to NaN."""
        y = _np.asarray(y, float)
        if y.size == 0 or fs_local <= 0:
            return y
        shift = int(round(lag_sec * fs_local))
        if shift == 0:
            return y
        out = _np.empty_like(y)
        out[:] = _np.nan
        if shift > 0:
            # delay: move content right
            out[shift:] = y[:-shift]
        else:
            # advance: move content left
            out[:shift] = y[-shift:]
        return out

    t = _np.asarray(time_axis, dtype=float)
    raw = _np.asarray(raw_signal, dtype=float)
    post = _np.asarray(post_signal, dtype=float)

    # guard
    if t.size == 0:
        t = _np.arange(post.size, dtype=float) / max(fs, 1)

    have_gt = gt_bvp is not None and _np.asarray(gt_bvp).size > 0
    gt_on_pred = None
    lag_sec = 0.0
    pearson_r = _np.nan
    hr_pred = _welch_hr(_bandpass(post, fs), fs)
    hr_gt = 0.0

    if have_gt:
        gt = _np.asarray(gt_bvp, dtype=float).ravel()
        if gt_time is not None and _np.asarray(gt_time).size == gt.size:
            gt_t = _np.asarray(gt_time, dtype=float).ravel()
            gt_on_pred = _safe_interp(gt_t, gt, t)
        else:
            # No time vector: try length-based mapping to pred time grid
            gt_on_pred = _safe_interp(_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size),
                                      gt, t)

        # Band-limit both before correlation/HR
        pred_bp = _bandpass(post, fs)
        gt_bp = _bandpass(gt_on_pred, fs)

        # Estimate best lag (sec) of GT relative to Pred
        lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0)

        # Apply lag to GT for visualization and correlation
        gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs)

        # Compute Pearson r on overlapping valid samples
        valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp)
        if valid.sum() >= 16:
            pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1])
|
| 1022 |
+
else:
|
| 1023 |
+
pearson_r = _np.nan
|
| 1024 |
+
|
| 1025 |
+
hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs)
|
| 1026 |
+
|
| 1027 |
+
|
| 1028 |
+
_plt.figure(figsize=(13, 6), dpi=110)
|
| 1029 |
+
gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35)
|
| 1030 |
+
|
| 1031 |
+
# (1) Raw Pred
|
| 1032 |
+
ax1 = _plt.subplot(gs[0, 0])
|
| 1033 |
+
ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5)
|
| 1034 |
+
ax1.set_title(f"Predicted (Raw) — fs={fs} Hz")
|
| 1035 |
+
ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude")
|
| 1036 |
+
ax1.grid(True, alpha=0.3)
|
| 1037 |
+
|
| 1038 |
+
# (2) Post Pred
|
| 1039 |
+
ax2 = _plt.subplot(gs[0, 1])
|
| 1040 |
+
ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5)
|
| 1041 |
+
ax2.set_title("Predicted (Post-processed)")
|
| 1042 |
+
ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude")
|
| 1043 |
+
ax2.grid(True, alpha=0.3)
|
| 1044 |
+
|
| 1045 |
+
# (3) Overlay Pred vs GT (z-scored) OR just post
|
| 1046 |
+
ax3 = _plt.subplot(gs[1, :])
|
| 1047 |
+
ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6)
|
| 1048 |
+
|
| 1049 |
+
if have_gt and gt_on_pred is not None:
|
| 1050 |
+
gt_bp = _bandpass(gt_on_pred, fs)
|
| 1051 |
+
gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs)
|
| 1052 |
+
ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9)
|
| 1053 |
+
|
| 1054 |
+
# metrics box
|
| 1055 |
+
txt = [
|
| 1056 |
+
f"HR_pred: {hr_pred:.1f} BPM",
|
| 1057 |
+
f"HR_gt: {hr_gt:.1f} BPM",
|
| 1058 |
+
f"Pearson r: {pearson_r:.3f}" if _np.isfinite(pearson_r) else "Pearson r: --",
|
| 1059 |
+
f"Lag: {lag_sec:+.2f} s"
|
| 1060 |
+
]
|
| 1061 |
+
ax3.text(0.01, 0.98, "\n".join(txt), transform=ax3.transAxes,
|
| 1062 |
+
va="top", ha="left", fontsize=9,
|
| 1063 |
+
bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="0.8", alpha=0.9))
|
| 1064 |
+
ax3.set_title("Pred vs GT (z-score overlay, lag-aligned)")
|
| 1065 |
+
else:
|
| 1066 |
+
ax3.set_title("Pred vs GT (no GT provided)")
|
| 1067 |
+
|
| 1068 |
+
ax3.set_xlabel("Time (s)"); ax3.set_ylabel("z")
|
| 1069 |
+
ax3.grid(True, alpha=0.3)
|
| 1070 |
+
ax3.legend(loc="upper right")
|
| 1071 |
+
|
| 1072 |
+
_plt.suptitle(title, fontweight="bold")
|
| 1073 |
+
_plt.tight_layout(rect=[0, 0.02, 1, 0.97])
|
| 1074 |
+
_plt.savefig(out_path, bbox_inches="tight")
|
| 1075 |
+
_plt.close()
|
| 1076 |
+
return out_path
|
| 1077 |
+
|
| 1078 |
+
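# Illustrative sketch (not called anywhere in the app): how plot_signals_with_gt is
# meant to be driven. The signals here are synthetic and `_demo_plot_pred_vs_gt` is a
# hypothetical helper added for documentation only; real callers pass the model's
# post-processed BVP plus the parsed ground-truth BVP and its time axis.
def _demo_plot_pred_vs_gt(out_png: str = "/tmp/demo_pred_vs_gt.png") -> str:
    fs = 30
    t = np.arange(0, 20.0, 1.0 / fs)                 # 20 s prediction grid
    pred = np.sin(2 * np.pi * 1.2 * t)               # ~72 BPM "prediction"
    gt_fs = 60.0
    gt_t = np.arange(0, 20.0, 1.0 / gt_fs)
    gt = np.sin(2 * np.pi * 1.2 * (gt_t - 0.4))      # same rate, 0.4 s lag
    # The helper resamples GT onto `t`, searches the best lag within ±5 s, and
    # annotates Pearson r plus Welch HR for both traces.
    return plot_signals_with_gt(
        time_axis=t, raw_signal=pred, post_signal=pred, fs=fs,
        out_path=out_png, gt_time=gt_t, gt_bvp=gt,
        title="Demo: Pred vs GT",
    )
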
def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
|
| 1079 |
+
"""
|
| 1080 |
+
Robust single-face detector with a few practical guards:
|
| 1081 |
+
- converts to gray safely
|
| 1082 |
+
- equalizes histogram (helps underexposure)
|
| 1083 |
+
- tries multiple scales / minNeighbors
|
| 1084 |
+
- returns the largest face bbox (x,y,w,h) or None
|
| 1085 |
+
"""
|
| 1086 |
+
if frame is None or frame.size == 0:
|
| 1087 |
+
return None
|
| 1088 |
+
|
| 1089 |
+
try:
|
| 1090 |
+
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
| 1091 |
+
except Exception:
|
| 1092 |
+
# If color conversion fails, assume already gray
|
| 1093 |
+
gray = frame.copy() if frame.ndim == 2 else cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY)
|
| 1094 |
+
|
| 1095 |
+
# Light preproc to improve Haar performance
|
| 1096 |
+
gray = cv2.equalizeHist(gray)
|
| 1097 |
+
|
| 1098 |
+
faces_all = []
|
| 1099 |
+
# Try a couple of parameter combos to be more forgiving
|
| 1100 |
+
params = [
|
| 1101 |
+
dict(scaleFactor=1.05, minNeighbors=3),
|
| 1102 |
+
dict(scaleFactor=1.10, minNeighbors=4),
|
| 1103 |
+
dict(scaleFactor=1.20, minNeighbors=5),
|
| 1104 |
+
]
|
| 1105 |
+
for p in params:
|
| 1106 |
+
try:
|
| 1107 |
+
faces = FACE_CASCADE.detectMultiScale(gray, **p)
|
| 1108 |
+
if faces is not None and len(faces) > 0:
|
| 1109 |
+
faces_all.extend([tuple(map(int, f)) for f in faces])
|
| 1110 |
+
except Exception:
|
| 1111 |
+
continue
|
| 1112 |
+
|
| 1113 |
+
if not faces_all:
|
| 1114 |
+
return None
|
| 1115 |
+
|
| 1116 |
+
# Return the largest (by area)
|
| 1117 |
+
return max(faces_all, key=lambda f: f[2] * f[3])
|
| 1118 |
+
|
| 1119 |
+
def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]:
|
| 1120 |
+
"""
|
| 1121 |
+
Crop ROI from the frame based on a face bbox and the selected roi_type.
|
| 1122 |
+
Returns the cropped BGR ROI or None if invalid.
|
| 1123 |
+
"""
|
| 1124 |
+
if face_bbox is None or frame is None or frame.size == 0:
|
| 1125 |
+
return None
|
| 1126 |
+
|
| 1127 |
+
x, y, w, h = map(int, face_bbox)
|
| 1128 |
+
H, W = frame.shape[:2]
|
| 1129 |
+
|
| 1130 |
+
if roi_type == "forehead":
|
| 1131 |
+
rx = int(x + 0.25 * w); rw = int(0.50 * w)
|
| 1132 |
+
ry = int(y + 0.10 * h); rh = int(0.20 * h)
|
| 1133 |
+
elif roi_type == "cheeks":
|
| 1134 |
+
rx = int(x + 0.15 * w); rw = int(0.70 * w)
|
| 1135 |
+
ry = int(y + 0.55 * h); rh = int(0.30 * h)
|
| 1136 |
+
else:
|
| 1137 |
+
rx, ry, rw, rh = x, y, w, h
|
| 1138 |
+
|
| 1139 |
+
# clamp in-bounds
|
| 1140 |
+
rx = max(0, rx); ry = max(0, ry)
|
| 1141 |
+
rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
|
| 1142 |
+
|
| 1143 |
+
if rx2 <= rx or ry2 <= ry:
|
| 1144 |
+
return None
|
| 1145 |
+
|
| 1146 |
+
roi = frame[ry:ry2, rx:rx2]
|
| 1147 |
+
# Avoid empty or 1-pixel slivers
|
| 1148 |
+
if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
|
| 1149 |
+
return None
|
| 1150 |
+
return roi
|
| 1151 |
+
|
| 1152 |
+
def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int],
|
| 1153 |
+
roi_type: str,
|
| 1154 |
+
frame: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]:
|
| 1155 |
+
if face_bbox is None or frame is None or frame.size == 0:
|
| 1156 |
+
return None, None
|
| 1157 |
+
|
| 1158 |
+
x, y, w, h = map(int, face_bbox)
|
| 1159 |
+
H, W = frame.shape[:2]
|
| 1160 |
+
|
| 1161 |
+
if roi_type == "forehead":
|
| 1162 |
+
rx = int(x + 0.25 * w); rw = int(0.50 * w)
|
| 1163 |
+
ry = int(y + 0.10 * h); rh = int(0.20 * h)
|
| 1164 |
+
elif roi_type == "cheeks":
|
| 1165 |
+
rx = int(x + 0.15 * w); rw = int(0.70 * w)
|
| 1166 |
+
ry = int(y + 0.55 * h); rh = int(0.30 * h)
|
| 1167 |
+
else:
|
| 1168 |
+
rx, ry, rw, rh = x, y, w, h
|
| 1169 |
+
|
| 1170 |
+
rx = max(0, rx); ry = max(0, ry)
|
| 1171 |
+
rx2 = min(W, rx + rw); ry2 = min(H, ry + rh)
|
| 1172 |
+
if rx2 <= rx or ry2 <= ry:
|
| 1173 |
+
return None, None
|
| 1174 |
+
|
| 1175 |
+
roi = frame[ry:ry2, rx:rx2]
|
| 1176 |
+
if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4:
|
| 1177 |
+
return None, None
|
| 1178 |
+
|
| 1179 |
+
return roi, (rx, ry, rx2 - rx, ry2 - ry)
|
| 1180 |
+
|
| 1181 |
+
def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray:
    """
    Frame normalization for PhysMamba input: resize, scale to 0..1,
    per-frame standardization, then BGR->RGB and HWC->CHW.
    Returns (3, size, size). (No frame differencing is done here.)
    """
    if face_bgr is None or face_bgr.size == 0:
        return np.zeros((3, size, size), dtype=np.float32)

    try:
        face = cv2.resize(face_bgr, (size, size), interpolation=cv2.INTER_AREA)
    except Exception:
        face = cv2.resize(face_bgr, (size, size))

    face = face.astype(np.float32) / 255.0

    # Per-frame standardization
    mean = face.mean(axis=(0, 1), keepdims=True)
    std = face.std(axis=(0, 1), keepdims=True) + 1e-6
    face = (face - mean) / std

    # BGR -> RGB and HWC -> CHW
    chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False)
    return chw
|
| 1205 |
+
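# Documentation sketch (hypothetical `_frames_to_clip`, not used by the app): the
# processing loops below stack DEFAULT_T per-frame outputs of normalize_frame along
# a new time axis to get [C, T, H, W], then add a batch dim so the model receives
# [B, C, T, H, W] (forward_bvp retries [B, T, C, H, W] if the first layout fails).
def _frames_to_clip(rois_bgr: List[np.ndarray], size: int = DEFAULT_SIZE) -> torch.Tensor:
    chw_frames = [normalize_frame(roi, size) for roi in rois_bgr]   # each (3, size, size)
    clip = np.stack(chw_frames, axis=1).astype(np.float32)          # (3, T, size, size)
    return torch.from_numpy(clip).unsqueeze(0)                      # (1, 3, T, size, size)
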
def extract_attention_map(model, clip_tensor: torch.Tensor,
|
| 1206 |
+
attention_viz) -> Optional[np.ndarray]:
|
| 1207 |
+
"""Attention visualization disabled - model architecture incompatible."""
|
| 1208 |
+
return None
|
| 1209 |
+
|
| 1210 |
+
def create_attention_overlay(frame: np.ndarray, attention: Optional[np.ndarray],
|
| 1211 |
+
attention_viz: Optional[SimpleActivationAttention] = None) -> np.ndarray:
|
| 1212 |
+
"""Create attention heatmap overlay."""
|
| 1213 |
+
# Attention disabled - return original frame
|
| 1214 |
+
return frame
|
| 1215 |
+
|
| 1216 |
+
def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12):
|
| 1217 |
+
H, W = roi_bgr.shape[:2]
|
| 1218 |
+
base_bvp = forward_bvp(model, torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE))
|
| 1219 |
+
.unsqueeze(0).unsqueeze(2).to(DEVICE)) # fake T=1 path if needed
|
| 1220 |
+
base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs)
|
| 1221 |
+
|
| 1222 |
+
heat = np.zeros((H, W), np.float32)
|
| 1223 |
+
for y in range(0, H - patch + 1, stride):
|
| 1224 |
+
for x in range(0, W - patch + 1, stride):
|
| 1225 |
+
tmp = roi_bgr.copy()
|
| 1226 |
+
tmp[y:y+patch, x:x+patch] = 127 # occlude
|
| 1227 |
+
bvp = forward_bvp(model, torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE))
|
| 1228 |
+
.unsqueeze(0).unsqueeze(2).to(DEVICE))
|
| 1229 |
+
power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs)
|
| 1230 |
+
drop = max(0.0, base_power - power)
|
| 1231 |
+
heat[y:y+patch, x:x+patch] += drop
|
| 1232 |
+
heat -= heat.min()
|
| 1233 |
+
if heat.max() > 1e-8: heat /= heat.max()
|
| 1234 |
+
return heat
|
| 1235 |
+
|
| 1236 |
+
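# Sketch only (hypothetical `_demo_occlusion_overlay`, never called): one way a caller
# could visualize occlusion_saliency. The heatmap marks patches whose occlusion most
# reduces the Welch HR estimate; cv2.applyColorMap is used here purely for display.
# Note occlusion_saliency runs one forward pass per patch, so it is meant for offline
# inspection rather than the live loop.
def _demo_occlusion_overlay(roi_bgr: np.ndarray, model: nn.Module, fs: int = 30) -> np.ndarray:
    heat = occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12)   # HxW in [0, 1]
    heat_u8 = (heat * 255).astype(np.uint8)
    heat_color = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
    return cv2.addWeighted(roi_bgr, 0.6, heat_color, 0.4, 0)
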
def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor):
    """
    Try common 5D layouts:
    [B, C, T, H, W] then [B, T, C, H, W].
    """
    last_err = None
    try:
        return model(clip_tensor)
    except Exception as e:
        last_err = e
    try:
        return model(clip_tensor.permute(0, 2, 1, 3, 4).contiguous())
    except Exception as e:
        last_err = e
    raise last_err

def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray:
|
| 1253 |
+
"""
|
| 1254 |
+
Forward and extract a 1D time-like BVP vector with length T_clip.
|
| 1255 |
+
Tolerant to dict/list/tuple heads and odd shapes.
|
| 1256 |
+
"""
|
| 1257 |
+
T_clip = int(clip_tensor.shape[2]) # intended time length for [B,C,T,H,W]
|
| 1258 |
+
with torch.no_grad():
|
| 1259 |
+
out = _call_model_try_orders(model, clip_tensor)
|
| 1260 |
+
|
| 1261 |
+
# unwrap common containers
|
| 1262 |
+
if isinstance(out, dict):
|
| 1263 |
+
for key in ("bvp", "ppg", "signal", "pred", "y", "out", "logits"):
|
| 1264 |
+
if key in out and isinstance(out[key], torch.Tensor):
|
| 1265 |
+
out = out[key]
|
| 1266 |
+
break
|
| 1267 |
+
|
| 1268 |
+
if isinstance(out, (list, tuple)):
|
| 1269 |
+
tensors = [t for t in out if isinstance(t, torch.Tensor)]
|
| 1270 |
+
if not tensors:
|
| 1271 |
+
return np.zeros(T_clip, dtype=np.float32)
|
| 1272 |
+
|
| 1273 |
+
def score(t: torch.Tensor):
|
| 1274 |
+
has_T = 1 if T_clip in t.shape else 0
|
| 1275 |
+
return (has_T, t.numel())
|
| 1276 |
+
|
| 1277 |
+
out = max(tensors, key=score)
|
| 1278 |
+
|
| 1279 |
+
if not isinstance(out, torch.Tensor):
|
| 1280 |
+
return np.zeros(T_clip, dtype=np.float32)
|
| 1281 |
+
|
| 1282 |
+
out = out.detach().cpu().float()
|
| 1283 |
+
|
| 1284 |
+
# ---- 1D
|
| 1285 |
+
if out.ndim == 1:
|
| 1286 |
+
v = out
|
| 1287 |
+
if v.shape[0] == T_clip:
|
| 1288 |
+
return v.numpy()
|
| 1289 |
+
if v.numel() == 1:
|
| 1290 |
+
return np.full(T_clip, float(v.item()), dtype=np.float32)
|
| 1291 |
+
return np.resize(v.numpy(), T_clip).astype(np.float32)
|
| 1292 |
+
|
| 1293 |
+
# ---- 2D
|
| 1294 |
+
if out.ndim == 2:
|
| 1295 |
+
B, K = out.shape
|
| 1296 |
+
if B == 1:
|
| 1297 |
+
v = out[0]
|
| 1298 |
+
return (v.numpy() if v.shape[0] == T_clip else np.resize(v.numpy(), T_clip).astype(np.float32))
|
| 1299 |
+
if B == T_clip:
|
| 1300 |
+
return out[:, 0].numpy()
|
| 1301 |
+
if K == T_clip:
|
| 1302 |
+
return out[0, :].numpy()
|
| 1303 |
+
return np.resize(out.flatten().numpy(), T_clip).astype(np.float32)
|
| 1304 |
+
|
| 1305 |
+
# ---- 3D
|
| 1306 |
+
if out.ndim == 3:
|
| 1307 |
+
B, D1, D2 = out.shape[0], out.shape[1], out.shape[2]
|
| 1308 |
+
if D1 == T_clip: # [B, T, C]
|
| 1309 |
+
return out[0, :, 0].numpy()
|
| 1310 |
+
if D2 == T_clip: # [B, C, T]
|
| 1311 |
+
return out[0, 0, :].numpy()
|
| 1312 |
+
v = out.mean(dim=tuple(range(1, out.ndim))).squeeze(0)
|
| 1313 |
+
return np.resize(v.numpy(), T_clip).astype(np.float32)
|
| 1314 |
+
|
| 1315 |
+
# ---- 4D
|
| 1316 |
+
if out.ndim == 4:
|
| 1317 |
+
B, A, H, W = out.shape
|
| 1318 |
+
if A == T_clip: # [B, T, H, W]
|
| 1319 |
+
return out[0].mean(dim=(1, 2)).numpy()
|
| 1320 |
+
v = out[0].mean(dim=(1, 2))
|
| 1321 |
+
return np.resize(v.numpy(), T_clip).astype(np.float32)
|
| 1322 |
+
|
| 1323 |
+
# ---- 5D+
|
| 1324 |
+
if out.ndim >= 5:
|
| 1325 |
+
shape = list(out.shape)
|
| 1326 |
+
try:
|
| 1327 |
+
t_idx = next(i for i, s in enumerate(shape) if (i != 0 and s == T_clip))
|
| 1328 |
+
except StopIteration:
|
| 1329 |
+
pooled = out[0].mean(dim=tuple(i for i in range(1, out.ndim) if i not in (-1,)))
|
| 1330 |
+
v = pooled.flatten()
|
| 1331 |
+
return np.resize(v.numpy(), T_clip).astype(np.float32)
|
| 1332 |
+
|
| 1333 |
+
axes = list(range(out.ndim))
|
| 1334 |
+
perm = [0, t_idx] + [i for i in axes[1:] if i != t_idx]
|
| 1335 |
+
o2 = out.permute(*perm) # [B, T, ...]
|
| 1336 |
+
pooled = o2.mean(dim=tuple(range(2, o2.ndim))) # -> [B, T]
|
| 1337 |
+
return pooled[0].numpy()
|
| 1338 |
+
|
| 1339 |
+
# fallback: constant vector with mean value
|
| 1340 |
+
val = float(out.mean().item()) if out.numel() else 0.0
|
| 1341 |
+
return np.full(T_clip, val, dtype=np.float32)
|
| 1342 |
+
|
| 1343 |
+
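# Sketch of the inference contract (hypothetical `_demo_clip_to_hr`, not called):
# forward_bvp returns a 1-D float vector whose length matches the clip's time
# dimension, whatever head shape the checkpoint produces; HR then comes from
# band-passing that vector and taking the Welch peak, as the loops below do.
def _demo_clip_to_hr(model: nn.Module, clip_bchw_t: torch.Tensor, fs: int = 30) -> float:
    bvp = forward_bvp(model, clip_bchw_t.to(DEVICE))            # shape (T,)
    bvp = bandpass_filter(bvp, fs=fs, low=0.7, high=3.5, order=4)
    return hr_from_welch(bvp, fs=fs, lo=0.7, hi=3.5)
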
def _fallback_bvp_from_means(means, fs: int) -> np.ndarray:
|
| 1344 |
+
"""
|
| 1345 |
+
Classical rPPG from green-channel means when the model yields nothing.
|
| 1346 |
+
Detrend -> bandpass -> z-normalize.
|
| 1347 |
+
"""
|
| 1348 |
+
if means is None:
|
| 1349 |
+
return np.array([], dtype=np.float32)
|
| 1350 |
+
|
| 1351 |
+
x = np.asarray(means, dtype=np.float32)
|
| 1352 |
+
if x.size == 0:
|
| 1353 |
+
return np.array([], dtype=np.float32)
|
| 1354 |
+
|
| 1355 |
+
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
| 1356 |
+
|
| 1357 |
+
try:
|
| 1358 |
+
x = signal.detrend(x, type="linear")
|
| 1359 |
+
except Exception:
|
| 1360 |
+
pass
|
| 1361 |
+
|
| 1362 |
+
y = bandpass_filter(x, fs=fs, low=0.7, high=3.5, order=4)
|
| 1363 |
+
|
| 1364 |
+
std = float(np.std(y)) + 1e-6
|
| 1365 |
+
return (y / std).astype(np.float32)
|
| 1366 |
+
|
| 1367 |
+
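# Minimal sketch of the classical fallback path (hypothetical `_demo_green_fallback`,
# not called): when the model yields nothing, the loops below feed the buffered
# green-channel ROI means through _fallback_bvp_from_means and estimate HR from the
# result exactly as they would for a model-predicted BVP.
def _demo_green_fallback(green_means: List[float], fs: int = 30) -> float:
    bvp = _fallback_bvp_from_means(np.asarray(green_means, dtype=np.float32), fs=fs)
    if bvp.size == 0:
        return 0.0
    return hr_from_welch(bvp, fs=fs, lo=0.7, hi=3.5)
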
def _to_floats(s: str) -> List[float]:
    """
    Extract all real numbers from free-form text, including scientific notation.
    Gracefully ignores 'nan', 'inf', units, and comments.
    """
    if not isinstance(s, str) or not s:
        return []

    s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE)

    s = s.replace(",", " ").replace(";", " ")

    toks = re.findall(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", s)
    out: List[float] = []
    for t in toks:
        try:
            v = float(t)
            if np.isfinite(v):
                out.append(v)
        except Exception:
            continue
    return out

+
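# Tiny illustrative example (hypothetical `_demo_to_floats`, not called): units and
# 'nan' tokens are dropped and scientific notation is accepted, which is why the
# ground-truth parser below can hand this helper raw lines from text files.
def _demo_to_floats() -> List[float]:
    # returns [72.0, 71.5, 73.0]
    return _to_floats("HR: 72 bpm, 71.5, nan, 7.3e1")
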
def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]:
|
| 1391 |
+
"""
|
| 1392 |
+
Parse ground-truth files from:
|
| 1393 |
+
• UBFC/UBFC-rPPG style TXT: 3 lines => PPG, HR, timestep(s)
|
| 1394 |
+
• Generic TXT: one column (or free-form) numeric sequence
|
| 1395 |
+
• JSON: keys like 'ppg' / 'bvp' (optionally nested), 'hr', 'fs'
|
| 1396 |
+
• CSV: columns named BVP/PPG/Signal + optional HR + optional Time
|
| 1397 |
+
• MAT: array under keys ['ppg','bvp','signal','wave']; optionally 'fs'/'hr'
|
| 1398 |
+
• NPY: np.load() 1-D array; optionally sidecar .json with fs/hr (same stem)
|
| 1399 |
+
|
| 1400 |
+
Returns:
|
| 1401 |
+
bvp : np.ndarray (float) — may be empty
|
| 1402 |
+
hr : float (mean HR if available or estimated)
|
| 1403 |
+
fs_hint : float — sampling rate if derivable (0.0 if unknown)
|
| 1404 |
+
"""
|
| 1405 |
+
if not gt_path or not os.path.exists(gt_path):
|
| 1406 |
+
return np.array([]), 0.0, 0.0
|
| 1407 |
+
|
| 1408 |
+
p = Path(gt_path)
|
| 1409 |
+
ext = p.suffix.lower()
|
| 1410 |
+
|
| 1411 |
+
def _fs_from_time_vector(tv: np.ndarray) -> float:
|
| 1412 |
+
tv = np.asarray(tv, dtype=float)
|
| 1413 |
+
if tv.ndim != 1 or tv.size < 2:
|
| 1414 |
+
return 0.0
|
| 1415 |
+
diffs = np.diff(tv)
|
| 1416 |
+
diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
|
| 1417 |
+
return (1.0 / float(np.median(diffs))) if diffs.size else 0.0
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float:
|
| 1421 |
+
if bvp is None or bvp.size == 0:
|
| 1422 |
+
return 0.0
|
| 1423 |
+
fs_use = fs_hint if (fs_hint and fs_hint > 0) else 30.0
|
| 1424 |
+
bp = bandpass_filter(bvp.astype(float), fs=fs_use)
|
| 1425 |
+
return hr_from_welch(bp, fs=fs_use)
|
| 1426 |
+
|
| 1427 |
+
if p.name.lower() == "ground_truth.txt" or (ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2):
|
| 1428 |
+
try:
|
| 1429 |
+
lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines() if ln.strip()]
|
| 1430 |
+
ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else []
|
| 1431 |
+
hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else []
|
| 1432 |
+
t_vals = _to_floats(lines[2]) if len(lines) >= 3 else []
|
| 1433 |
+
|
| 1434 |
+
bvp = np.asarray(ppg_vals, dtype=float) if ppg_vals else np.array([], dtype=float)
|
| 1435 |
+
|
| 1436 |
+
hr = float(np.nanmean(hr_vals)) if hr_vals else 0.0
|
| 1437 |
+
fs_hint = 0.0
|
| 1438 |
+
if t_vals:
|
| 1439 |
+
if len(t_vals) == 1:
|
| 1440 |
+
dt = float(t_vals[0])
|
| 1441 |
+
if dt > 0:
|
| 1442 |
+
fs_hint = 1.0 / dt
|
| 1443 |
+
else:
|
| 1444 |
+
fs_hint = _fs_from_time_vector(np.asarray(t_vals, dtype=float))
|
| 1445 |
+
|
| 1446 |
+
if hr == 0.0 and bvp.size:
|
| 1447 |
+
hr = _hr_from_bvp(bvp, fs_hint)
|
| 1448 |
+
return bvp, hr, fs_hint
|
| 1449 |
+
except Exception:
|
| 1450 |
+
# Fall through to generic handlers
|
| 1451 |
+
pass
|
| 1452 |
+
|
| 1453 |
+
|
| 1454 |
+
if ext == ".txt":
|
| 1455 |
+
try:
|
| 1456 |
+
nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore"))
|
| 1457 |
+
bvp = np.asarray(nums, dtype=float) if nums else np.array([], dtype=float)
|
| 1458 |
+
hr = _hr_from_bvp(bvp, fs_hint=0.0) if bvp.size else 0.0
|
| 1459 |
+
return bvp, hr, 0.0
|
| 1460 |
+
except Exception:
|
| 1461 |
+
return np.array([]), 0.0, 0.0
|
| 1462 |
+
|
| 1463 |
+
|
| 1464 |
+
if ext == ".json":
|
| 1465 |
+
try:
|
| 1466 |
+
data = json.loads(p.read_text(encoding="utf-8", errors="ignore"))
|
| 1467 |
+
# Try several paths for BVP array
|
| 1468 |
+
bvp = None
|
| 1469 |
+
|
| 1470 |
+
def _seek(obj, keys):
|
| 1471 |
+
for k in keys:
|
| 1472 |
+
if isinstance(obj, dict) and k in obj:
|
| 1473 |
+
return obj[k]
|
| 1474 |
+
return None
|
| 1475 |
+
|
| 1476 |
+
# Direct top-level
|
| 1477 |
+
bvp = _seek(data, ("ppg", "bvp", "signal", "wave"))
|
| 1478 |
+
# Common nested containers
|
| 1479 |
+
if bvp is None:
|
| 1480 |
+
for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"):
|
| 1481 |
+
if container_key in data:
|
| 1482 |
+
cand = _seek(data[container_key], ("ppg", "bvp", "signal", "wave"))
|
| 1483 |
+
if cand is not None:
|
| 1484 |
+
bvp = cand
|
| 1485 |
+
break
|
| 1486 |
+
|
| 1487 |
+
if bvp is not None:
|
| 1488 |
+
bvp = np.asarray(bvp, dtype=float).ravel()
|
| 1489 |
+
else:
|
| 1490 |
+
bvp = np.array([], dtype=float)
|
| 1491 |
+
|
| 1492 |
+
# fs / hr (accept scalar or array)
|
| 1493 |
+
fs_hint = 0.0
|
| 1494 |
+
if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0:
|
| 1495 |
+
fs_hint = float(data["fs"])
|
| 1496 |
+
|
| 1497 |
+
hr = 0.0
|
| 1498 |
+
if "hr" in data:
|
| 1499 |
+
v = data["hr"]
|
| 1500 |
+
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
|
| 1501 |
+
|
| 1502 |
+
if hr == 0.0 and bvp.size:
|
| 1503 |
+
hr = _hr_from_bvp(bvp, fs_hint)
|
| 1504 |
+
return bvp, hr, fs_hint
|
| 1505 |
+
except Exception:
|
| 1506 |
+
return np.array([]), 0.0, 0.0
|
| 1507 |
+
|
| 1508 |
+
|
| 1509 |
+
if ext == ".csv":
|
| 1510 |
+
try:
|
| 1511 |
+
df = pd.read_csv(p)
|
| 1512 |
+
# Normalize column names
|
| 1513 |
+
cols = {str(c).strip().lower(): c for c in df.columns}
|
| 1514 |
+
|
| 1515 |
+
def _first_match(names):
|
| 1516 |
+
for nm in names:
|
| 1517 |
+
if nm in cols:
|
| 1518 |
+
return cols[nm]
|
| 1519 |
+
return None
|
| 1520 |
+
|
| 1521 |
+
bvp_col = _first_match(["bvp", "ppg", "wave", "signal", "bvp_signal", "ppg_signal"])
|
| 1522 |
+
hr_col = _first_match(["hr", "heart_rate", "hr_bpm", "bpm"])
|
| 1523 |
+
t_col = _first_match(["time", "t", "timestamp", "sec", "seconds", "time_s"])
|
| 1524 |
+
|
| 1525 |
+
bvp = np.asarray(df[bvp_col].values, dtype=float) if bvp_col else np.array([], dtype=float)
|
| 1526 |
+
|
| 1527 |
+
fs_hint = 0.0
|
| 1528 |
+
if t_col is not None and len(df[t_col].values) >= 2:
|
| 1529 |
+
fs_hint = _fs_from_time_vector(np.asarray(df[t_col].values, dtype=float))
|
| 1530 |
+
|
| 1531 |
+
hr = float(np.nanmean(df[hr_col].values)) if (hr_col and df[hr_col].notna().any()) else 0.0
|
| 1532 |
+
if hr == 0.0 and bvp.size:
|
| 1533 |
+
hr = _hr_from_bvp(bvp, fs_hint)
|
| 1534 |
+
return bvp, hr, fs_hint
|
| 1535 |
+
except Exception:
|
| 1536 |
+
return np.array([]), 0.0, 0.0
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
if ext == ".mat":
|
| 1540 |
+
try:
|
| 1541 |
+
md = loadmat(str(p))
|
| 1542 |
+
# look for most likely array
|
| 1543 |
+
arr = None
|
| 1544 |
+
for key in ("ppg", "bvp", "signal", "wave"):
|
| 1545 |
+
if key in md and isinstance(md[key], np.ndarray):
|
| 1546 |
+
arr = md[key]
|
| 1547 |
+
break
|
| 1548 |
+
if arr is None:
|
| 1549 |
+
# fallback: first 1-D array
|
| 1550 |
+
for v in md.values():
|
| 1551 |
+
if isinstance(v, np.ndarray) and v.ndim == 1:
|
| 1552 |
+
arr = v
|
| 1553 |
+
break
|
| 1554 |
+
bvp = np.asarray(arr, dtype=float).ravel() if arr is not None else np.array([], dtype=float)
|
| 1555 |
+
|
| 1556 |
+
fs_hint = 0.0
|
| 1557 |
+
for k in ("fs", "Fs", "sampling_rate", "sr"):
|
| 1558 |
+
if k in md:
|
| 1559 |
+
try:
|
| 1560 |
+
fs_hint = float(np.ravel(md[k])[0])
|
| 1561 |
+
break
|
| 1562 |
+
except Exception:
|
| 1563 |
+
pass
|
| 1564 |
+
|
| 1565 |
+
hr = 0.0
|
| 1566 |
+
if "hr" in md:
|
| 1567 |
+
try:
|
| 1568 |
+
hr = float(np.nanmean(np.ravel(md["hr"])))
|
| 1569 |
+
except Exception:
|
| 1570 |
+
hr = 0.0
|
| 1571 |
+
|
| 1572 |
+
if hr == 0.0 and bvp.size:
|
| 1573 |
+
hr = _hr_from_bvp(bvp, fs_hint)
|
| 1574 |
+
return bvp, hr, fs_hint
|
| 1575 |
+
except Exception:
|
| 1576 |
+
return np.array([]), 0.0, 0.0
|
| 1577 |
+
|
| 1578 |
+
# ================= NPY =================
|
| 1579 |
+
if ext == ".npy":
|
| 1580 |
+
try:
|
| 1581 |
+
bvp = np.asarray(np.load(str(p)), dtype=float).ravel()
|
| 1582 |
+
fs_hint, hr = 0.0, 0.0
|
| 1583 |
+
# optional sidecar JSON (same stem) with fs/hr
|
| 1584 |
+
sidecar = p.with_suffix(".json")
|
| 1585 |
+
if sidecar.exists():
|
| 1586 |
+
try:
|
| 1587 |
+
meta = json.loads(sidecar.read_text(encoding="utf-8", errors="ignore"))
|
| 1588 |
+
if isinstance(meta.get("fs", None), (int, float)) and meta["fs"] > 0:
|
| 1589 |
+
fs_hint = float(meta["fs"])
|
| 1590 |
+
if "hr" in meta:
|
| 1591 |
+
v = meta["hr"]
|
| 1592 |
+
hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v)
|
| 1593 |
+
except Exception:
|
| 1594 |
+
pass
|
| 1595 |
+
if hr == 0.0 and bvp.size:
|
| 1596 |
+
hr = _hr_from_bvp(bvp, fs_hint)
|
| 1597 |
+
return bvp, hr, fs_hint
|
| 1598 |
+
except Exception:
|
| 1599 |
+
return np.array([]), 0.0, 0.0
|
| 1600 |
+
|
| 1601 |
+
# Fallback (unsupported extension)
|
| 1602 |
+
return np.array([]), 0.0, 0.0
|
| 1603 |
+
|
| 1604 |
+
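# Sketch of how the returned triple is consumed (hypothetical `_demo_load_gt`, not
# called): the video pipeline below only needs the BVP samples, a mean HR for display,
# and an fs hint so the GT can be placed on a time axis before the lag-aligned overlay
# is drawn.
def _demo_load_gt(gt_path: str):
    gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_path)
    if gt_bvp.size and gt_fs > 0:
        gt_time = np.arange(gt_bvp.size, dtype=float) / gt_fs
    else:
        gt_time = None   # the plot falls back to length-based resampling
    return gt_bvp, gt_hr, gt_time
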
def scan_models() -> List[str]:
    if not MODEL_DIR.exists():
        return []

    models = []
    for f in sorted(MODEL_DIR.iterdir()):
        if f.suffix.lower() == '.pth':
            models.append(f.name)

    return models

_GLOBAL_CONTROLS: Dict[str, Dict] = {}

def ensure_controls(control_id: str) -> Tuple[str, Dict]:
    # Use a stable default so Pause/Resume/Stop work for the current run
    if not control_id:
        control_id = "default-session"
    if control_id not in _GLOBAL_CONTROLS:
        _GLOBAL_CONTROLS[control_id] = {
            'pause': threading.Event(),
            'stop': threading.Event()
        }
    return control_id, _GLOBAL_CONTROLS[control_id]

+
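# Threading sketch (hypothetical `_demo_control_roundtrip`, not called): the UI
# callbacks near the end of this file set these events from the Gradio thread while
# the generators in process_video_file / process_live_webcam poll them between
# frames, which is why both sides must resolve the same control_id via ensure_controls.
def _demo_control_roundtrip(control_id: str = "default-session") -> bool:
    cid, controls = ensure_controls(control_id)
    controls['pause'].set()       # worker loop spins in its pause-wait
    controls['pause'].clear()     # worker resumes
    controls['stop'].set()        # worker exits at the next loop check
    return controls['stop'].is_set()
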
def process_video_file(
|
| 1629 |
+
video_path: str,
|
| 1630 |
+
gt_file: Optional[str],
|
| 1631 |
+
model_name: str,
|
| 1632 |
+
fps_input: int,
|
| 1633 |
+
max_seconds: int,
|
| 1634 |
+
roi_type: str,
|
| 1635 |
+
control_id: str
|
| 1636 |
+
):
|
| 1637 |
+
"""
|
| 1638 |
+
Enhanced video processing with Grad-CAM attention visualization.
|
| 1639 |
+
"""
|
| 1640 |
+
global _HR_SMOOTH
|
| 1641 |
+
_HR_SMOOTH = None
|
| 1642 |
+
|
| 1643 |
+
def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]:
|
| 1644 |
+
if gt_len <= 1:
|
| 1645 |
+
return None
|
| 1646 |
+
if gt_fs and gt_fs > 0:
|
| 1647 |
+
return np.arange(gt_len, dtype=float) / float(gt_fs)
|
| 1648 |
+
return None
|
| 1649 |
+
|
| 1650 |
+
control_id, controls = ensure_controls(control_id)
|
| 1651 |
+
controls['stop'].clear()
|
| 1652 |
+
|
| 1653 |
+
if not model_name:
|
| 1654 |
+
yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
|
| 1655 |
+
return
|
| 1656 |
+
|
| 1657 |
+
if isinstance(model_name, int):
|
| 1658 |
+
model_name = str(model_name)
|
| 1659 |
+
|
| 1660 |
+
model_path = MODEL_DIR / model_name
|
| 1661 |
+
if not model_path.exists():
|
| 1662 |
+
yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
|
| 1663 |
+
return
|
| 1664 |
+
|
| 1665 |
+
try:
|
| 1666 |
+
model, attention_viz = load_physmamba_model(model_path, DEVICE)
|
| 1667 |
+
except Exception as e:
|
| 1668 |
+
yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
|
| 1669 |
+
return
|
| 1670 |
+
|
| 1671 |
+
|
| 1672 |
+
gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0)
|
| 1673 |
+
|
| 1674 |
+
if not video_path or not os.path.exists(video_path):
|
| 1675 |
+
yield ("ERROR: Video not found", None, None, None, None, None, None, None, None, None)
|
| 1676 |
+
return
|
| 1677 |
+
|
| 1678 |
+
cap = cv2.VideoCapture(video_path)
|
| 1679 |
+
if not cap.isOpened():
|
| 1680 |
+
yield ("ERROR: Cannot open video", None, None, None, None, None, None, None, None, None)
|
| 1681 |
+
return
|
| 1682 |
+
|
| 1683 |
+
fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30)
|
| 1684 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
|
| 1685 |
+
max_frames = int(max_seconds * fps) if max_seconds and max_seconds > 0 else total_frames
|
| 1686 |
+
MAX_SIGNAL_LENGTH = max_frames if max_frames > 0 else fps * 600
|
| 1687 |
+
|
| 1688 |
+
frames_chw = deque(maxlen=DEFAULT_T)
|
| 1689 |
+
raw_g_means = deque(maxlen=DEFAULT_T)
|
| 1690 |
+
bvp_stream: List[float] = []
|
| 1691 |
+
frame_idx = 0
|
| 1692 |
+
last_infer = -1
|
| 1693 |
+
last_bpm = 0.0
|
| 1694 |
+
last_rmssd = 0.0
|
| 1695 |
+
last_attention = None
|
| 1696 |
+
|
| 1697 |
+
start_time = time.time()
|
| 1698 |
+
next_display = time.time()
|
| 1699 |
+
|
| 1700 |
+
tmpdir = Path(tempfile.gettempdir())
|
| 1701 |
+
frame_path = tmpdir / "frame.jpg"
|
| 1702 |
+
attention_path = tmpdir / "attention.jpg"
|
| 1703 |
+
signal_path = tmpdir / "signal.png"
|
| 1704 |
+
raw_path = tmpdir / "raw_signal.png"
|
| 1705 |
+
post_path = tmpdir / "post_signal.png"
|
| 1706 |
+
|
| 1707 |
+
yield ("Starting… reading video frames", None, f"{gt_hr:.1f}" if gt_hr > 0 else "--",
|
| 1708 |
+
None, None, None, None, None, None, None)
|
| 1709 |
+
|
| 1710 |
+
while True:
|
| 1711 |
+
if controls['stop'].is_set():
|
| 1712 |
+
break
|
| 1713 |
+
|
| 1714 |
+
while controls['pause'].is_set():
|
| 1715 |
+
time.sleep(0.2)
|
| 1716 |
+
if controls['stop'].is_set():
|
| 1717 |
+
break
|
| 1718 |
+
|
| 1719 |
+
ret, frame = cap.read()
|
| 1720 |
+
if not ret or (max_frames > 0 and frame_idx >= max_frames):
|
| 1721 |
+
break
|
| 1722 |
+
|
| 1723 |
+
frame_idx += 1
|
| 1724 |
+
|
| 1725 |
+
face = detect_face(frame)
|
| 1726 |
+
vis_frame = frame.copy()
|
| 1727 |
+
|
| 1728 |
+
if face is not None:
|
| 1729 |
+
x, y, w, h = face
|
| 1730 |
+
cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
| 1731 |
+
|
| 1732 |
+
if roi_type == "auto":
|
| 1733 |
+
roi, _ = pick_auto_roi(face, frame, attn=last_attention)
|
| 1734 |
+
else:
|
| 1735 |
+
roi = crop_roi(face, roi_type, frame)
|
| 1736 |
+
|
| 1737 |
+
if roi is not None and roi.size > 0:
|
| 1738 |
+
try:
|
| 1739 |
+
g_mean = float(roi[..., 1].astype(np.float32).mean())
|
| 1740 |
+
raw_g_means.append(g_mean)
|
| 1741 |
+
except Exception:
|
| 1742 |
+
pass
|
| 1743 |
+
|
| 1744 |
+
face_norm = normalize_frame(roi, DEFAULT_SIZE)
|
| 1745 |
+
frames_chw.append(face_norm)
|
| 1746 |
+
|
| 1747 |
+
if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
|
| 1748 |
+
try:
|
| 1749 |
+
clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
|
| 1750 |
+
except Exception as e:
|
| 1751 |
+
print(f"[infer] clip stack failed: {e}")
|
| 1752 |
+
clip = None
|
| 1753 |
+
|
| 1754 |
+
bvp_out = None
|
| 1755 |
+
if clip is not None:
|
| 1756 |
+
clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
|
| 1757 |
+
|
| 1758 |
+
try:
|
| 1759 |
+
raw = forward_bvp(model, clip_t)
|
| 1760 |
+
if isinstance(raw, np.ndarray):
|
| 1761 |
+
raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
|
| 1762 |
+
bvp_out = raw if raw.size > 0 else None
|
| 1763 |
+
else:
|
| 1764 |
+
bvp_out = None
|
| 1765 |
+
except Exception as e:
|
| 1766 |
+
print(f"[infer] forward_bvp error: {e}")
|
| 1767 |
+
bvp_out = None
|
| 1768 |
+
|
| 1769 |
+
# UPDATED: Generate attention with Grad-CAM
|
| 1770 |
+
try:
|
| 1771 |
+
last_attention = extract_attention_map(model, clip_t, attention_viz)
|
| 1772 |
+
except Exception as e:
|
| 1773 |
+
print(f"⚠ Attention generation error: {e}")
|
| 1774 |
+
last_attention = None
|
| 1775 |
+
|
| 1776 |
+
if bvp_out is None or bvp_out.size == 0:
|
| 1777 |
+
gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
|
| 1778 |
+
fb = _fallback_bvp_from_means(gbuf, fs=fps)
|
| 1779 |
+
if isinstance(fb, np.ndarray) and fb.size > 0:
|
| 1780 |
+
bvp_out = fb
|
| 1781 |
+
else:
|
| 1782 |
+
print("[infer] fallback produced empty output")
|
| 1783 |
+
|
| 1784 |
+
if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
|
| 1785 |
+
tail = min(DEFAULT_STRIDE, bvp_out.size)
|
| 1786 |
+
bvp_stream.extend(bvp_out[-tail:].tolist())
|
| 1787 |
+
if len(bvp_stream) > MAX_SIGNAL_LENGTH:
|
| 1788 |
+
bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
|
| 1789 |
+
|
| 1790 |
+
if len(bvp_stream) >= int(5 * fps):
|
| 1791 |
+
seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
|
| 1792 |
+
_, last_bpm = postprocess_bvp(seg, fs=fps)
|
| 1793 |
+
last_rmssd = compute_rmssd(seg, fs=fps)
|
| 1794 |
+
|
| 1795 |
+
if frame_idx % (DEFAULT_STRIDE * 2) == 0:
|
| 1796 |
+
print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
|
| 1797 |
+
else:
|
| 1798 |
+
print("[infer] no usable bvp_out after fallback")
|
| 1799 |
+
|
| 1800 |
+
last_infer = frame_idx
|
| 1801 |
+
|
| 1802 |
+
else:
|
| 1803 |
+
cv2.putText(vis_frame, "No face detected", (20, 40),
|
| 1804 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
|
| 1805 |
+
|
| 1806 |
+
if last_bpm > 0:
|
| 1807 |
+
color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255)
|
| 1808 |
+
cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1)
|
| 1809 |
+
cv2.putText(vis_frame, f"HR: {last_bpm:.1f} BPM", (20, 48),
|
| 1810 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
|
| 1811 |
+
|
| 1812 |
+
vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
|
| 1813 |
+
|
| 1814 |
+
cv2.imwrite(str(frame_path), vis_frame)
|
| 1815 |
+
cv2.imwrite(str(attention_path), vis_attention)
|
| 1816 |
+
|
| 1817 |
+
now = time.time()
|
| 1818 |
+
if now >= next_display:
|
| 1819 |
+
if len(bvp_stream) >= max(10, int(1 * fps)):
|
| 1820 |
+
try:
|
| 1821 |
+
signal_array = np.array(bvp_stream, dtype=float)
|
| 1822 |
+
time_axis = np.arange(len(signal_array)) / fps
|
| 1823 |
+
|
| 1824 |
+
raw_signal = bandpass_filter(signal_array, fs=fps)
|
| 1825 |
+
post_signal, _ = postprocess_bvp(signal_array, fs=fps)
|
| 1826 |
+
|
| 1827 |
+
raw_vis = raw_signal - np.mean(raw_signal)
|
| 1828 |
+
post_vis = post_signal - np.mean(post_signal)
|
| 1829 |
+
|
| 1830 |
+
plt.figure(figsize=(10, 4), dpi=100)
|
| 1831 |
+
plt.plot(time_axis, raw_vis, linewidth=1.6)
|
| 1832 |
+
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
|
| 1833 |
+
plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
|
| 1834 |
+
plt.grid(True, alpha=0.3)
|
| 1835 |
+
plt.tight_layout(); plt.savefig(str(raw_path), dpi=100, bbox_inches='tight'); plt.close()
|
| 1836 |
+
|
| 1837 |
+
plt.figure(figsize=(10, 4), dpi=100)
|
| 1838 |
+
plt.plot(time_axis, post_vis, linewidth=1.6)
|
| 1839 |
+
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
|
| 1840 |
+
plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
|
| 1841 |
+
plt.grid(True, alpha=0.3)
|
| 1842 |
+
plt.tight_layout(); plt.savefig(str(post_path), dpi=100, bbox_inches='tight'); plt.close()
|
| 1843 |
+
|
| 1844 |
+
if gt_bvp is not None and gt_bvp.size > 0:
|
| 1845 |
+
try:
|
| 1846 |
+
gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
|
| 1847 |
+
plot_signals_with_gt(
|
| 1848 |
+
time_axis=time_axis,
|
| 1849 |
+
raw_signal=raw_signal,
|
| 1850 |
+
post_signal=post_signal,
|
| 1851 |
+
fs=fps,
|
| 1852 |
+
out_path=str(signal_path),
|
| 1853 |
+
gt_time=gt_time,
|
| 1854 |
+
gt_bvp=gt_bvp,
|
| 1855 |
+
title=f"Pred vs GT — HR: {last_bpm:.1f} BPM"
|
| 1856 |
+
)
|
| 1857 |
+
except Exception:
|
| 1858 |
+
fig = plt.figure(figsize=(12, 5), dpi=100)
|
| 1859 |
+
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
|
| 1860 |
+
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
|
| 1861 |
+
ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
|
| 1862 |
+
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
|
| 1863 |
+
ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
|
| 1864 |
+
plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
|
| 1865 |
+
plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
|
| 1866 |
+
else:
|
| 1867 |
+
fig = plt.figure(figsize=(12, 5), dpi=100)
|
| 1868 |
+
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
|
| 1869 |
+
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6)
|
| 1870 |
+
ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3)
|
| 1871 |
+
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6)
|
| 1872 |
+
ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3)
|
| 1873 |
+
plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold')
|
| 1874 |
+
plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all')
|
| 1875 |
+
except Exception:
|
| 1876 |
+
pass
|
| 1877 |
+
|
| 1878 |
+
elapsed = now - start_time
|
| 1879 |
+
status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
|
| 1880 |
+
|
| 1881 |
+
yield (
|
| 1882 |
+
status,
|
| 1883 |
+
f"{last_bpm:.1f}" if last_bpm > 0 else None,
|
| 1884 |
+
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
|
| 1885 |
+
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
|
| 1886 |
+
str(frame_path),
|
| 1887 |
+
str(attention_path),
|
| 1888 |
+
str(signal_path) if signal_path.exists() else None,
|
| 1889 |
+
str(raw_path) if raw_path.exists() else None,
|
| 1890 |
+
str(post_path) if post_path.exists() else None,
|
| 1891 |
+
None
|
| 1892 |
+
)
|
| 1893 |
+
|
| 1894 |
+
next_display = now + (1.0 / DISPLAY_FPS)
|
| 1895 |
+
|
| 1896 |
+
cap.release()
|
| 1897 |
+
|
| 1898 |
+
# Cleanup Grad-CAM
|
| 1899 |
+
if attention_viz: attention_viz.cleanup()
|
| 1900 |
+
|
| 1901 |
+
csv_path = None
|
| 1902 |
+
if bvp_stream:
|
| 1903 |
+
csv_path = Path(tempfile.gettempdir()) / "bvp_output.csv"
|
| 1904 |
+
time_array = np.arange(len(bvp_stream)) / fps
|
| 1905 |
+
signal_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
|
| 1906 |
+
try:
|
| 1907 |
+
pd.DataFrame({'time_s': time_array, 'bvp': signal_final}).to_csv(csv_path, index=False)
|
| 1908 |
+
except Exception:
|
| 1909 |
+
csv_path = None
|
| 1910 |
+
|
| 1911 |
+
try:
|
| 1912 |
+
final_overlay = Path(tempfile.gettempdir()) / "signal_final_overlay.png"
|
| 1913 |
+
if gt_bvp is not None and gt_bvp.size > 0:
|
| 1914 |
+
gt_time = _gt_time_axis(len(gt_bvp), gt_fs)
|
| 1915 |
+
plot_signals_with_gt(
|
| 1916 |
+
time_axis=time_array,
|
| 1917 |
+
raw_signal=bandpass_filter(np.array(bvp_stream, dtype=float), fs=fps),
|
| 1918 |
+
post_signal=signal_final,
|
| 1919 |
+
fs=fps,
|
| 1920 |
+
out_path=str(final_overlay),
|
| 1921 |
+
gt_time=gt_time,
|
| 1922 |
+
gt_bvp=gt_bvp,
|
| 1923 |
+
title=f"Final Pred vs GT — HR: {last_bpm:.1f} BPM"
|
| 1924 |
+
)
|
| 1925 |
+
if final_overlay.exists():
|
| 1926 |
+
signal_path = final_overlay
|
| 1927 |
+
except Exception:
|
| 1928 |
+
pass
|
| 1929 |
+
|
| 1930 |
+
elapsed = time.time() - start_time
|
| 1931 |
+
final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
|
| 1932 |
+
|
| 1933 |
+
yield (
|
| 1934 |
+
final_status,
|
| 1935 |
+
f"{last_bpm:.1f}" if last_bpm > 0 else None,
|
| 1936 |
+
f"{gt_hr:.1f}" if gt_hr > 0 else "--",
|
| 1937 |
+
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
|
| 1938 |
+
str(frame_path),
|
| 1939 |
+
str(attention_path),
|
| 1940 |
+
str(signal_path) if signal_path.exists() else None,
|
| 1941 |
+
str(raw_path) if raw_path.exists() else None,
|
| 1942 |
+
str(post_path) if post_path.exists() else None,
|
| 1943 |
+
str(csv_path) if csv_path else None
|
| 1944 |
+
)
|
| 1945 |
+
|
| 1946 |
+
def process_live_webcam(
|
| 1947 |
+
model_name: str,
|
| 1948 |
+
fps_input: int,
|
| 1949 |
+
roi_type: str,
|
| 1950 |
+
control_id: str
|
| 1951 |
+
):
|
| 1952 |
+
"""Stream live webcam with Grad-CAM attention visualization."""
|
| 1953 |
+
global _HR_SMOOTH
|
| 1954 |
+
_HR_SMOOTH = None
|
| 1955 |
+
|
| 1956 |
+
def _perf_heartbeat(frame_idx, t0, bvp_len, frames_chw_len, fps):
|
| 1957 |
+
if frame_idx == 1:
|
| 1958 |
+
print(f"[run] device={DEVICE} target_fps={fps}")
|
| 1959 |
+
if frame_idx % 60 == 0:
|
| 1960 |
+
elapsed = time.time() - t0
|
| 1961 |
+
cur_fps = frame_idx / max(elapsed, 1e-6)
|
| 1962 |
+
print(f"[perf] frames={frame_idx} ({cur_fps:.1f} FPS) "
|
| 1963 |
+
f"clip={frames_chw_len}/{DEFAULT_T} bvp_len={bvp_len}")
|
| 1964 |
+
|
| 1965 |
+
control_id, controls = ensure_controls(control_id)
|
| 1966 |
+
controls['stop'].clear()
|
| 1967 |
+
|
| 1968 |
+
if not model_name:
|
| 1969 |
+
yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
|
| 1970 |
+
return
|
| 1971 |
+
|
| 1972 |
+
model_path = MODEL_DIR / model_name
|
| 1973 |
+
if not model_path.exists():
|
| 1974 |
+
yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None)
|
| 1975 |
+
return
|
| 1976 |
+
|
| 1977 |
+
try:
|
| 1978 |
+
model, attention_viz = load_physmamba_model(model_path, DEVICE)
|
| 1979 |
+
except Exception as e:
|
| 1980 |
+
yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None)
|
| 1981 |
+
return
|
| 1982 |
+
|
| 1983 |
+
|
| 1984 |
+
cap = None
|
| 1985 |
+
for camera_id in [0, 1]:
|
| 1986 |
+
for backend in [cv2.CAP_AVFOUNDATION, cv2.CAP_ANY]:
|
| 1987 |
+
try:
|
| 1988 |
+
test_cap = cv2.VideoCapture(camera_id, backend)
|
| 1989 |
+
if test_cap.isOpened():
|
| 1990 |
+
test_cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
|
| 1991 |
+
test_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
|
| 1992 |
+
ret, frame = test_cap.read()
|
| 1993 |
+
if ret and frame is not None:
|
| 1994 |
+
cap = test_cap
|
| 1995 |
+
break
|
| 1996 |
+
test_cap.release()
|
| 1997 |
+
except Exception:
|
| 1998 |
+
pass
|
| 1999 |
+
if cap is not None:
|
| 2000 |
+
break
|
| 2001 |
+
|
| 2002 |
+
if cap is None:
|
| 2003 |
+
yield ("ERROR: Cannot access webcam", None, None, None, None, None, None, None, None, None)
|
| 2004 |
+
return
|
| 2005 |
+
|
| 2006 |
+
fps = int(fps_input) if fps_input else 30
|
| 2007 |
+
|
| 2008 |
+
frames_chw = deque(maxlen=DEFAULT_T)
|
| 2009 |
+
raw_g_means = deque(maxlen=DEFAULT_T)
|
| 2010 |
+
bvp_stream: List[float] = []
|
| 2011 |
+
frame_idx = 0
|
| 2012 |
+
last_infer = -1
|
| 2013 |
+
last_bpm = 0.0
|
| 2014 |
+
last_rmssd = 0.0
|
| 2015 |
+
last_attention = None
|
| 2016 |
+
|
| 2017 |
+
t0 = time.time()
|
| 2018 |
+
next_display = time.time()
|
| 2019 |
+
DISPLAY_INTERVAL = 0.5
|
| 2020 |
+
|
| 2021 |
+
frame_path = Path(tempfile.gettempdir()) / "live_frame.jpg"
|
| 2022 |
+
attention_path = Path(tempfile.gettempdir()) / "live_attention.jpg"
|
| 2023 |
+
signal_path = Path(tempfile.gettempdir()) / "live_signal.png"
|
| 2024 |
+
raw_path = Path(tempfile.gettempdir()) / "live_raw.png"
|
| 2025 |
+
post_path = Path(tempfile.gettempdir()) / "live_post.png"
|
| 2026 |
+
|
| 2027 |
+
MAX_SIGNAL_LENGTH = fps * 60
|
| 2028 |
+
|
| 2029 |
+
yield ("Starting… waiting for frames", None, "--", None,
|
| 2030 |
+
None, None, None, None, None, None)
|
| 2031 |
+
|
| 2032 |
+
while True:
|
| 2033 |
+
if controls['stop'].is_set():
|
| 2034 |
+
break
|
| 2035 |
+
|
| 2036 |
+
while controls['pause'].is_set():
|
| 2037 |
+
time.sleep(0.2)
|
| 2038 |
+
if controls['stop'].is_set():
|
| 2039 |
+
break
|
| 2040 |
+
|
| 2041 |
+
ret, frame = cap.read()
|
| 2042 |
+
if not ret:
|
| 2043 |
+
time.sleep(0.05)
|
| 2044 |
+
continue
|
| 2045 |
+
|
| 2046 |
+
frame_idx += 1
|
| 2047 |
+
_perf_heartbeat(frame_idx, t0, len(bvp_stream), len(frames_chw), fps)
|
| 2048 |
+
|
| 2049 |
+
face = detect_face(frame)
|
| 2050 |
+
vis_frame = frame.copy()
|
| 2051 |
+
|
| 2052 |
+
if face is not None:
|
| 2053 |
+
x, y, w, h = face
|
| 2054 |
+
cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 3)
|
| 2055 |
+
cv2.putText(vis_frame, "FACE", (x, max(20, y - 10)),
|
| 2056 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
| 2057 |
+
|
| 2058 |
+
if roi_type == "auto":
|
| 2059 |
+
roi, _ = pick_auto_roi(face, frame, attn=last_attention)
|
| 2060 |
+
else:
|
| 2061 |
+
roi = crop_roi(face, roi_type, frame)
|
| 2062 |
+
|
| 2063 |
+
if roi is not None and roi.size > 0:
|
| 2064 |
+
try:
|
| 2065 |
+
g_mean = float(roi[..., 1].astype(np.float32).mean())
|
| 2066 |
+
raw_g_means.append(g_mean)
|
| 2067 |
+
except Exception:
|
| 2068 |
+
pass
|
| 2069 |
+
|
| 2070 |
+
chw = normalize_frame(roi, DEFAULT_SIZE)
|
| 2071 |
+
frames_chw.append(chw)
|
| 2072 |
+
|
| 2073 |
+
if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE:
|
| 2074 |
+
try:
|
| 2075 |
+
clip = np.stack(list(frames_chw), axis=1).astype(np.float32)
|
| 2076 |
+
except Exception as e:
|
| 2077 |
+
print(f"[infer] clip stack failed: {e}")
|
| 2078 |
+
clip = None
|
| 2079 |
+
|
| 2080 |
+
bvp_out = None
|
| 2081 |
+
if clip is not None:
|
| 2082 |
+
clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE)
|
| 2083 |
+
try:
|
| 2084 |
+
raw = forward_bvp(model, clip_t)
|
| 2085 |
+
if isinstance(raw, np.ndarray):
|
| 2086 |
+
raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
|
| 2087 |
+
bvp_out = raw if raw.size > 0 else None
|
| 2088 |
+
else:
|
| 2089 |
+
bvp_out = None
|
| 2090 |
+
except Exception as e:
|
| 2091 |
+
print(f"[infer] forward_bvp error: {e}")
|
| 2092 |
+
bvp_out = None
|
| 2093 |
+
|
| 2094 |
+
try:
|
| 2095 |
+
last_attention = extract_attention_map(model, clip_t, attention_viz)
|
| 2096 |
+
except Exception as e:
|
| 2097 |
+
print(f" Attention generation error: {e}")
|
| 2098 |
+
last_attention = None
|
| 2099 |
+
|
| 2100 |
+
if bvp_out is None or bvp_out.size == 0:
|
| 2101 |
+
gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0)
|
| 2102 |
+
fb = _fallback_bvp_from_means(gbuf, fs=fps)
|
| 2103 |
+
if isinstance(fb, np.ndarray) and fb.size > 0:
|
| 2104 |
+
bvp_out = fb
|
| 2105 |
+
else:
|
| 2106 |
+
print("[infer] fallback produced empty output")
|
| 2107 |
+
|
| 2108 |
+
if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0:
|
| 2109 |
+
tail = min(DEFAULT_STRIDE, bvp_out.size)
|
| 2110 |
+
bvp_stream.extend(bvp_out[-tail:].tolist())
|
| 2111 |
+
if len(bvp_stream) > MAX_SIGNAL_LENGTH:
|
| 2112 |
+
bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:]
|
| 2113 |
+
|
| 2114 |
+
if len(bvp_stream) >= int(5 * fps):
|
| 2115 |
+
seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32)
|
| 2116 |
+
_, last_bpm = postprocess_bvp(seg, fs=fps)
|
| 2117 |
+
last_rmssd = compute_rmssd(seg, fs=fps)
|
| 2118 |
+
|
| 2119 |
+
if frame_idx % (DEFAULT_STRIDE * 2) == 0:
|
| 2120 |
+
print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}")
|
| 2121 |
+
else:
|
| 2122 |
+
print("[infer] no usable bvp_out after fallback")
|
| 2123 |
+
|
| 2124 |
+
last_infer = frame_idx
|
| 2125 |
+
|
| 2126 |
+
else:
|
| 2127 |
+
cv2.putText(vis_frame, "No face detected", (20, 40),
|
| 2128 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2)
|
| 2129 |
+
|
| 2130 |
+
cv2.putText(vis_frame, f"Fill: {len(frames_chw)}/{DEFAULT_T}", (20, 25),
|
| 2131 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 2132 |
+
cv2.putText(vis_frame, f"BVP: {len(bvp_stream)}", (20, 45),
|
| 2133 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 2134 |
+
|
| 2135 |
+
if last_bpm > 0:
|
| 2136 |
+
color = (50, 255, 50) if 55 <= last_bpm <= 100 else (50, 50, 255) if last_bpm > 100 else (255, 200, 50)
|
| 2137 |
+
overlay = vis_frame.copy()
|
| 2138 |
+
cv2.rectangle(overlay, (10, 10), (450, 100), (0, 0, 0), -1)
|
| 2139 |
+
vis_frame = cv2.addWeighted(vis_frame, 0.6, overlay, 0.4, 0)
|
| 2140 |
+
cv2.putText(vis_frame, f"HR: {last_bpm:.0f} BPM", (20, 80),
|
| 2141 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 3)
|
| 2142 |
+
cv2.circle(vis_frame, (30, 30), 10, (0, 255, 0), -1)
|
| 2143 |
+
cv2.putText(vis_frame, "LIVE", (50, 38),
|
| 2144 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
| 2145 |
+
else:
|
| 2146 |
+
cv2.putText(vis_frame, "Collecting…", (20, 80),
|
| 2147 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
|
| 2148 |
+
|
| 2149 |
+
vis_attention = create_attention_overlay(frame, last_attention, attention_viz)
|
| 2150 |
+
|
| 2151 |
+
cv2.imwrite(str(frame_path), vis_frame)
|
| 2152 |
+
cv2.imwrite(str(attention_path), vis_attention)
|
| 2153 |
+
|
| 2154 |
+
now = time.time()
|
| 2155 |
+
if now >= next_display:
|
| 2156 |
+
if len(bvp_stream) >= max(10, int(1 * fps)):
|
| 2157 |
+
try:
|
| 2158 |
+
sig = np.array(bvp_stream, dtype=float)
|
| 2159 |
+
t_axis = np.arange(len(sig)) / fps
|
| 2160 |
+
raw_sig = bandpass_filter(sig, fs=fps)
|
| 2161 |
+
post_sig, _ = postprocess_bvp(sig, fs=fps)
|
| 2162 |
+
|
| 2163 |
+
rv = raw_sig - np.mean(raw_sig)
|
| 2164 |
+
pv = post_sig - np.mean(post_sig)
|
| 2165 |
+
|
| 2166 |
+
plt.figure(figsize=(10, 4), dpi=100)
|
| 2167 |
+
plt.plot(t_axis, rv, linewidth=2)
|
| 2168 |
+
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
|
| 2169 |
+
plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM')
|
| 2170 |
+
plt.grid(True, alpha=0.3)
|
| 2171 |
+
plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
|
| 2172 |
+
plt.tight_layout(); plt.savefig(str(raw_path)); plt.close()
|
| 2173 |
+
|
| 2174 |
+
plt.figure(figsize=(10, 4), dpi=100)
|
| 2175 |
+
plt.plot(t_axis, pv, linewidth=2)
|
| 2176 |
+
plt.xlabel('Time (s)'); plt.ylabel('Amplitude')
|
| 2177 |
+
plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM')
|
| 2178 |
+
plt.grid(True, alpha=0.3)
|
| 2179 |
+
plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
|
| 2180 |
+
plt.tight_layout(); plt.savefig(str(post_path)); plt.close()
|
| 2181 |
+
|
| 2182 |
+
fig = plt.figure(figsize=(12, 5), dpi=100)
|
| 2183 |
+
from matplotlib.gridspec import GridSpec
|
| 2184 |
+
gs = GridSpec(1, 2, figure=fig, wspace=0.3)
|
| 2185 |
+
ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(t_axis, rv, linewidth=2)
|
| 2186 |
+
ax1.set_title('Raw Signal'); ax1.grid(True, alpha=0.3)
|
| 2187 |
+
ax1.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
|
| 2188 |
+
ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(t_axis, pv, linewidth=2)
|
| 2189 |
+
ax2.set_title('Post-processed'); ax2.grid(True, alpha=0.3)
|
| 2190 |
+
ax2.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]])
|
| 2191 |
+
plt.suptitle(f'LIVE rPPG - HR: {last_bpm:.1f} BPM')
|
| 2192 |
+
plt.savefig(str(signal_path)); plt.close('all')
|
| 2193 |
+
except Exception:
|
| 2194 |
+
pass
|
| 2195 |
+
|
| 2196 |
+
elapsed = now - t0
|
| 2197 |
+
status = (f"Frame {frame_idx} | Fill {len(frames_chw)}/{DEFAULT_T} | "
|
| 2198 |
+
f"BVP {len(bvp_stream)} | HR {last_bpm:.1f} BPM | "
|
| 2199 |
+
f"Time {int(elapsed)}s")
|
| 2200 |
+
|
| 2201 |
+
yield (
|
| 2202 |
+
status,
|
| 2203 |
+
f"{last_bpm:.1f}" if last_bpm > 0 else None,
|
| 2204 |
+
"--",
|
| 2205 |
+
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
|
| 2206 |
+
str(frame_path),
|
| 2207 |
+
str(attention_path),
|
| 2208 |
+
str(signal_path) if signal_path.exists() else None,
|
| 2209 |
+
str(raw_path) if raw_path.exists() else None,
|
| 2210 |
+
str(post_path) if post_path.exists() else None,
|
| 2211 |
+
None
|
| 2212 |
+
)
|
| 2213 |
+
next_display = now + DISPLAY_INTERVAL
|
| 2214 |
+
|
| 2215 |
+
time.sleep(0.01)
|
| 2216 |
+
|
| 2217 |
+
cap.release()
|
| 2218 |
+
|
| 2219 |
+
# Cleanup Grad-CAM
|
| 2220 |
+
if attention_viz: attention_viz.cleanup()
|
| 2221 |
+
|
| 2222 |
+
csv_path = None
|
| 2223 |
+
if bvp_stream:
|
| 2224 |
+
csv_path = Path(tempfile.gettempdir()) / "live_bvp.csv"
|
| 2225 |
+
t = np.arange(len(bvp_stream)) / fps
|
| 2226 |
+
sig_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps)
|
| 2227 |
+
pd.DataFrame({"time_s": t, "bvp": sig_final}).to_csv(csv_path, index=False)
|
| 2228 |
+
|
| 2229 |
+
elapsed = time.time() - t0
|
| 2230 |
+
final_status = f"Session ended | Frames {frame_idx} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM"
|
| 2231 |
+
yield (
|
| 2232 |
+
final_status,
|
| 2233 |
+
f"{last_bpm:.1f}" if last_bpm > 0 else None,
|
| 2234 |
+
"--",
|
| 2235 |
+
f"{last_rmssd:.1f}" if last_rmssd > 0 else None,
|
| 2236 |
+
str(frame_path),
|
| 2237 |
+
str(attention_path),
|
| 2238 |
+
str(signal_path) if signal_path.exists() else None,
|
| 2239 |
+
str(raw_path) if raw_path.exists() else None,
|
| 2240 |
+
str(post_path) if post_path.exists() else None,
|
| 2241 |
+
str(csv_path) if csv_path else None
|
| 2242 |
+
)
|
| 2243 |
+
|
| 2244 |
+
def process_stream(
|
| 2245 |
+
input_source: str,
|
| 2246 |
+
video_path: Optional[str],
|
| 2247 |
+
gt_file: Optional[str],
|
| 2248 |
+
model_name: str,
|
| 2249 |
+
fps_input: int,
|
| 2250 |
+
max_seconds: int,
|
| 2251 |
+
roi_type: str,
|
| 2252 |
+
control_id: str
|
| 2253 |
+
):
|
| 2254 |
+
if input_source == "Live Webcam":
|
| 2255 |
+
yield from process_live_webcam(model_name, fps_input, roi_type, control_id)
|
| 2256 |
+
else:
|
| 2257 |
+
yield from process_video_file(video_path, gt_file, model_name, fps_input,
|
| 2258 |
+
max_seconds, roi_type, control_id)
|
| 2259 |
+
|
| 2260 |
+
def pause_processing(control_id: str) -> str:
    _, controls = ensure_controls(control_id)
    controls['pause'].set()
    return "Paused"

def resume_processing(control_id: str) -> str:
    _, controls = ensure_controls(control_id)
    controls['pause'].clear()
    return "Resumed"

def stop_processing(control_id: str) -> str:
    _, controls = ensure_controls(control_id)
    controls['stop'].set()
    controls['pause'].clear()
    return "Stopped"

def reset_ui():
    return ("Ready", None, None, None, None, None, None, None, None, None)

def handle_folder_upload(files):
|
| 2280 |
+
if not files:
|
| 2281 |
+
return None, None, "No files uploaded"
|
| 2282 |
+
|
| 2283 |
+
if not isinstance(files, list):
|
| 2284 |
+
files = [files]
|
| 2285 |
+
|
| 2286 |
+
video_path = None
|
| 2287 |
+
for file_obj in files:
|
| 2288 |
+
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
|
| 2289 |
+
if file_path.suffix.lower() in VIDEO_EXTENSIONS:
|
| 2290 |
+
video_path = str(file_path)
|
| 2291 |
+
break
|
| 2292 |
+
|
| 2293 |
+
gt_path = None
|
| 2294 |
+
gt_patterns = ['gtdump.txt', 'ground_truth.txt', 'gt.txt']
|
| 2295 |
+
|
| 2296 |
+
for file_obj in files:
|
| 2297 |
+
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
|
| 2298 |
+
if file_path.name.lower() in [p.lower() for p in gt_patterns]:
|
| 2299 |
+
gt_path = str(file_path)
|
| 2300 |
+
break
|
| 2301 |
+
|
| 2302 |
+
if not gt_path:
|
| 2303 |
+
for file_obj in files:
|
| 2304 |
+
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name)
|
| 2305 |
+
if file_path.suffix.lower() in ['.txt', '.json', '.csv']:
|
| 2306 |
+
if 'readme' not in file_path.name.lower():
|
| 2307 |
+
gt_path = str(file_path)
|
| 2308 |
+
break
|
| 2309 |
+
|
| 2310 |
+
status = []
|
| 2311 |
+
if video_path:
|
| 2312 |
+
status.append(f"Video: {Path(video_path).name}")
|
| 2313 |
+
else:
|
| 2314 |
+
status.append("No video found")
|
| 2315 |
+
|
| 2316 |
+
if gt_path:
|
| 2317 |
+
status.append(f"GT: {Path(gt_path).name}")
|
| 2318 |
+
else:
|
| 2319 |
+
status.append("No ground truth")
|
| 2320 |
+
|
| 2321 |
+
return video_path, gt_path, " | ".join(status)
|
| 2322 |
+
|
| 2323 |
+
with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as demo:
|
| 2324 |
+
gr.Markdown("# rPPG Analysis Tool with Attention Visualization")
|
| 2325 |
+
|
| 2326 |
+
with gr.Row():
|
| 2327 |
+
input_source = gr.Radio(
|
| 2328 |
+
choices=["Video File", "Subject Folder", "Live Webcam"],
|
| 2329 |
+
value="Live Webcam",
|
| 2330 |
+
label="Input Source"
|
| 2331 |
+
)
|
| 2332 |
+
|
| 2333 |
+
with gr.Row():
|
| 2334 |
+
target_layer_input = gr.Textbox(
|
| 2335 |
+
label="Grad-CAM Target Layer (optional - leave empty for auto)",
|
| 2336 |
+
placeholder="e.g., backbone.conv3, encoder.layer2",
|
| 2337 |
+
value=""
|
| 2338 |
+
)
|
| 2339 |
+
|
| 2340 |
+
with gr.Row(visible=False) as video_inputs:
|
| 2341 |
+
with gr.Column():
|
| 2342 |
+
video_upload = gr.Video(label="Upload Video", sources=["upload"])
|
| 2343 |
+
with gr.Column():
|
| 2344 |
+
gt_upload = gr.File(label="Ground Truth (optional)",
|
| 2345 |
+
file_types=[".txt", ".csv", ".json"])
|
| 2346 |
+
|
| 2347 |
+
with gr.Row(visible=False) as folder_inputs:
|
| 2348 |
+
with gr.Column():
|
| 2349 |
+
folder_upload = gr.File(
|
| 2350 |
+
label="Upload Subject Folder",
|
| 2351 |
+
file_count="directory",
|
| 2352 |
+
file_types=None,
|
| 2353 |
+
type="filepath"
|
| 2354 |
+
)
|
| 2355 |
+
folder_status = gr.Textbox(label="Folder Status", interactive=False)
|
| 2356 |
+
|
| 2357 |
+
with gr.Row(visible=True) as webcam_inputs:
|
| 2358 |
+
gr.Markdown("### Live Webcam - Click Run to start")
|
| 2359 |
+
|
| 2360 |
+
def toggle_input_source(source):
|
| 2361 |
+
return [
|
| 2362 |
+
gr.update(visible=(source == "Video File")),
|
| 2363 |
+
gr.update(visible=(source == "Subject Folder")),
|
| 2364 |
+
gr.update(visible=(source == "Live Webcam"))
|
| 2365 |
+
]
|
| 2366 |
+
|
| 2367 |
+
input_source.change(
|
| 2368 |
+
toggle_input_source,
|
| 2369 |
+
inputs=[input_source],
|
| 2370 |
+
outputs=[video_inputs, folder_inputs, webcam_inputs]
|
| 2371 |
+
)
|
| 2372 |
+
|
| 2373 |
+
folder_video = gr.State(None)
|
| 2374 |
+
folder_gt = gr.State(None)
|
| 2375 |
+
|
| 2376 |
+
folder_upload.upload(
|
| 2377 |
+
handle_folder_upload,
|
| 2378 |
+
inputs=[folder_upload],
|
| 2379 |
+
outputs=[folder_video, folder_gt, folder_status]
|
| 2380 |
+
)
|
| 2381 |
+
|
| 2382 |
+
with gr.Row():
|
| 2383 |
+
with gr.Column(scale=3):
|
| 2384 |
+
model_dropdown = gr.Dropdown(
|
| 2385 |
+
choices=scan_models(),
|
| 2386 |
+
value=scan_models()[0] if scan_models() else None,
|
| 2387 |
+
label="PhysMamba Model",
|
| 2388 |
+
interactive=True
|
| 2389 |
+
)
|
| 2390 |
+
with gr.Column(scale=1):
|
| 2391 |
+
refresh_models_btn = gr.Button("Refresh", variant="secondary")
|
| 2392 |
+
|
| 2393 |
+
refresh_models_btn.click(
|
| 2394 |
+
lambda: gr.update(choices=scan_models()),
|
| 2395 |
+
inputs=None,
|
| 2396 |
+
outputs=[model_dropdown]
|
| 2397 |
+
)
|
| 2398 |
+
|
| 2399 |
+
with gr.Row():
|
| 2400 |
+
fps_slider = gr.Slider(
|
| 2401 |
+
minimum=10, maximum=120, value=30, step=5,
|
| 2402 |
+
label="FPS"
|
| 2403 |
+
)
|
| 2404 |
+
max_seconds_slider = gr.Slider(
|
| 2405 |
+
minimum=10, maximum=600, value=180, step=10,
|
| 2406 |
+
label="Max Duration (s)"
|
| 2407 |
+
)
|
| 2408 |
+
|
| 2409 |
+
with gr.Row():
|
| 2410 |
+
roi_dropdown = gr.Dropdown(choices=["auto","forehead","cheeks","face"], value="auto", label="ROI")
|
| 2411 |
+
|
| 2412 |
+
control_state = gr.State(value="")
|
| 2413 |
+
placeholder_state = gr.State(value=None)
|
| 2414 |
+
|
| 2415 |
+
with gr.Row():
|
| 2416 |
+
run_btn = gr.Button("Run", variant="primary")
|
| 2417 |
+
pause_btn = gr.Button("Pause", variant="secondary")
|
| 2418 |
+
resume_btn = gr.Button("Resume", variant="secondary")
|
| 2419 |
+
stop_btn = gr.Button("Stop", variant="stop")
|
| 2420 |
+
|
| 2421 |
+
status_text = gr.Textbox(label="Status", lines=2, value="Ready")
|
| 2422 |
+
|
| 2423 |
+
with gr.Row():
|
| 2424 |
+
hr_output = gr.Textbox(label="HR (BPM)", interactive=False)
|
| 2425 |
+
gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False)
|
| 2426 |
+
rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False)
|
| 2427 |
+
|
| 2428 |
+
with gr.Row():
|
| 2429 |
+
with gr.Column():
|
| 2430 |
+
frame_output = gr.Image(label="Video Feed", type="filepath")
|
| 2431 |
+
with gr.Column():
|
| 2432 |
+
attention_output = gr.Image(label="Attention Map", type="filepath")
|
| 2433 |
+
|
| 2434 |
+
with gr.Row():
|
| 2435 |
+
signal_output = gr.Image(label="Signal Comparison", type="filepath")
|
| 2436 |
+
|
| 2437 |
+
with gr.Row():
|
| 2438 |
+
with gr.Column():
|
| 2439 |
+
raw_signal_output = gr.Image(label="Raw Signal", type="filepath")
|
| 2440 |
+
with gr.Column():
|
| 2441 |
+
post_signal_output = gr.Image(label="Post-processed Signal", type="filepath")
|
| 2442 |
+
|
| 2443 |
+
with gr.Row():
|
| 2444 |
+
csv_output = gr.File(label="Download CSV")
|
| 2445 |
+
|
| 2446 |
+
pause_btn.click(
|
| 2447 |
+
pause_processing,
|
| 2448 |
+
inputs=[control_state],
|
| 2449 |
+
outputs=[status_text]
|
| 2450 |
+
)
|
| 2451 |
+
|
| 2452 |
+
resume_btn.click(
|
| 2453 |
+
resume_processing,
|
| 2454 |
+
inputs=[control_state],
|
| 2455 |
+
outputs=[status_text]
|
| 2456 |
+
)
|
| 2457 |
+
|
| 2458 |
+
stop_btn.click(
|
| 2459 |
+
stop_processing,
|
| 2460 |
+
inputs=[control_state],
|
| 2461 |
+
outputs=[status_text]
|
| 2462 |
+
).then(
|
| 2463 |
+
reset_ui,
|
| 2464 |
+
inputs=None,
|
| 2465 |
+
outputs=[status_text, hr_output, gt_hr_output, rmssd_output,
|
| 2466 |
+
frame_output, attention_output, signal_output,
|
| 2467 |
+
raw_signal_output, post_signal_output, csv_output]
|
| 2468 |
+
)
|
| 2469 |
+
|
| 2470 |
+
def run_processing(input_source, video_upload, gt_upload, folder_video, folder_gt,
|
| 2471 |
+
model_name, fps, max_sec, roi, ctrl_id):
|
| 2472 |
+
"""Fixed version that handles model_name type conversion."""
|
| 2473 |
+
|
| 2474 |
+
if isinstance(model_name, int):
|
| 2475 |
+
model_name = str(model_name)
|
| 2476 |
+
|
| 2477 |
+
if not model_name:
|
| 2478 |
+
yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None)
|
| 2479 |
+
return
|
| 2480 |
+
|
| 2481 |
+
if input_source == "Video File":
|
| 2482 |
+
video_path = _as_path(video_upload)
|
| 2483 |
+
gt_file = _as_path(gt_upload)
|
| 2484 |
+
elif input_source == "Subject Folder":
|
| 2485 |
+
video_path = _as_path(folder_video)
|
| 2486 |
+
gt_file = _as_path(folder_gt)
|
| 2487 |
+
else: # Live Webcam
|
| 2488 |
+
video_path, gt_file = None, None
|
| 2489 |
+
|
| 2490 |
+
yield from process_stream(
|
| 2491 |
+
input_source, video_path, gt_file,
|
| 2492 |
+
model_name, fps, max_sec, roi, ctrl_id
|
| 2493 |
+
)
|
| 2494 |
+
|
| 2495 |
+
|
| 2496 |
+
run_btn.click(
|
| 2497 |
+
fn=run_processing,
|
| 2498 |
+
inputs=[
|
| 2499 |
+
input_source,
|
| 2500 |
+
video_upload,
|
| 2501 |
+
gt_upload,
|
| 2502 |
+
folder_video,
|
| 2503 |
+
folder_gt,
|
| 2504 |
+
model_dropdown,
|
| 2505 |
+
fps_slider,
|
| 2506 |
+
max_seconds_slider,
|
| 2507 |
+
roi_dropdown,
|
| 2508 |
+
control_state
|
| 2509 |
+
],
|
| 2510 |
+
outputs=[
|
| 2511 |
+
status_text,
|
| 2512 |
+
hr_output,
|
| 2513 |
+
gt_hr_output,
|
| 2514 |
+
rmssd_output,
|
| 2515 |
+
frame_output,
|
| 2516 |
+
attention_output,
|
| 2517 |
+
signal_output,
|
| 2518 |
+
raw_signal_output,
|
| 2519 |
+
post_signal_output,
|
| 2520 |
+
csv_output
|
| 2521 |
+
]
|
| 2522 |
+
)
|

if __name__ == "__main__":
    demo.queue(max_size=10).launch(
        server_name="127.0.0.1",
        server_port=7861,
        share=False,
        show_error=True,
        inbrowser=True
    )
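For readers following the live-webcam path above: the app writes the post-processed BVP trace to live_bvp.csv and reports a heart rate in BPM. Below is a minimal, standalone sketch of how such a rate could be derived from that CSV with a band-limited FFT peak; it mirrors the idea behind the app's postprocess_bvp/HR pipeline but is not taken from it, and the frequency band and defaults are illustrative assumptions.

import numpy as np

def estimate_hr_bpm(bvp: np.ndarray, fs: float = 30.0) -> float:
    # Restrict the spectrum to plausible heart rates (0.7-3.0 Hz = 42-180 BPM)
    bvp = bvp - np.mean(bvp)
    freqs = np.fft.rfftfreq(len(bvp), d=1.0 / fs)
    power = np.abs(np.fft.rfft(bvp)) ** 2
    band = (freqs >= 0.7) & (freqs <= 3.0)
    if not np.any(band):
        return 0.0
    peak_freq = freqs[band][np.argmax(power[band])]
    return float(peak_freq * 60.0)

# e.g. applied to the CSV exported above:
# df = pd.read_csv("live_bvp.csv"); print(estimate_hr_bpm(df["bvp"].to_numpy(), fs=30.0))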
final_model_release/PURE_PhysMamba_DiffNormalized.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2da6b9bc7e8728c20743c8f359eb43eefd2a96665098a29302fed2b09bc6b7d7
size 3013798

final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd18f227ff18f5140f471dc78e61426d9344ecc43677b35ec16b3199f1dccd19
size 3013798
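The two .pth entries above are Git LFS pointer files; the actual checkpoints are downloaded by LFS. A hedged sketch of loading one of them on CPU follows; the PhysMamba class name and its constructor are assumptions based on the neural_methods/model/PhysMamba.py module in this commit, not verified here.

import torch

state_dict = torch.load(
    "final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth",
    map_location="cpu",   # keep the weights on CPU, matching the mamba_ssm CPU shim below
)
# Hypothetical model construction (class name assumed):
# from neural_methods.model.PhysMamba import PhysMamba
# model = PhysMamba()
# model.load_state_dict(state_dict, strict=False)
# model.eval()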
mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

print('[mamba_ssm shim] Loaded CPU shim')


class Mamba(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, conv_bias=True, **kwargs):
        super().__init__()
        h = d_model * expand
        self.in_proj = nn.Linear(d_model, 2 * h)
        self.dw = nn.Conv1d(h, h, d_conv, padding=d_conv - 1, groups=h, bias=conv_bias)
        self.mix = nn.Conv1d(h, h, 1)
        self.out = nn.Linear(h, d_model)
        self.d_model = d_model

    def forward(self, x):
        B, L, C = x.shape
        u, v = self.in_proj(x).chunk(2, dim=-1)
        y = F.silu(u) * torch.sigmoid(v)
        y = self.dw(y.transpose(1, 2))[..., :L]
        y = F.silu(self.mix(y)).transpose(1, 2)
        return self.out(y)
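A small, hypothetical sanity check for the shim above (not part of the commit): the fallback Mamba block should accept a (batch, length, d_model) tensor and return one of the same shape.

import torch
from mamba_ssm import Mamba

block = Mamba(d_model=64)
x = torch.randn(2, 128, 64)     # (batch, sequence length, d_model)
y = block(x)
assert y.shape == x.shape       # the shim preserves the (B, L, D) shape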
mamba_ssm/models/__init__.py
ADDED
File without changes
mamba_ssm/models/mixer_seq_simple.py
ADDED
|
@@ -0,0 +1,233 @@
| 1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from mamba_ssm.modules.mamba_simple import Mamba, Block
|
| 12 |
+
from mamba_ssm.utils.generation import GenerationMixin
|
| 13 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 17 |
+
except ImportError:
|
| 18 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def create_block(
|
| 22 |
+
d_model,
|
| 23 |
+
ssm_cfg=None,
|
| 24 |
+
norm_epsilon=1e-5,
|
| 25 |
+
rms_norm=False,
|
| 26 |
+
residual_in_fp32=False,
|
| 27 |
+
fused_add_norm=False,
|
| 28 |
+
layer_idx=None,
|
| 29 |
+
device=None,
|
| 30 |
+
dtype=None,
|
| 31 |
+
):
|
| 32 |
+
if ssm_cfg is None:
|
| 33 |
+
ssm_cfg = {}
|
| 34 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 35 |
+
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
| 36 |
+
norm_cls = partial(
|
| 37 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 38 |
+
)
|
| 39 |
+
block = Block(
|
| 40 |
+
d_model,
|
| 41 |
+
mixer_cls,
|
| 42 |
+
norm_cls=norm_cls,
|
| 43 |
+
fused_add_norm=fused_add_norm,
|
| 44 |
+
residual_in_fp32=residual_in_fp32,
|
| 45 |
+
)
|
| 46 |
+
block.layer_idx = layer_idx
|
| 47 |
+
return block
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
| 51 |
+
def _init_weights(
|
| 52 |
+
module,
|
| 53 |
+
n_layer,
|
| 54 |
+
initializer_range=0.02, # Now only used for embedding layer.
|
| 55 |
+
rescale_prenorm_residual=True,
|
| 56 |
+
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
| 57 |
+
):
|
| 58 |
+
if isinstance(module, nn.Linear):
|
| 59 |
+
if module.bias is not None:
|
| 60 |
+
if not getattr(module.bias, "_no_reinit", False):
|
| 61 |
+
nn.init.zeros_(module.bias)
|
| 62 |
+
elif isinstance(module, nn.Embedding):
|
| 63 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 64 |
+
|
| 65 |
+
if rescale_prenorm_residual:
|
| 66 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
| 67 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
| 68 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
| 69 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
| 70 |
+
#
|
| 71 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
| 72 |
+
for name, p in module.named_parameters():
|
| 73 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
| 74 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 75 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
| 76 |
+
# We need to reinit p since this code could be called multiple times
|
| 77 |
+
# Having just p *= scale would repeatedly scale it down
|
| 78 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class MixerModel(nn.Module):
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
d_model: int,
|
| 87 |
+
n_layer: int,
|
| 88 |
+
vocab_size: int,
|
| 89 |
+
ssm_cfg=None,
|
| 90 |
+
norm_epsilon: float = 1e-5,
|
| 91 |
+
rms_norm: bool = False,
|
| 92 |
+
initializer_cfg=None,
|
| 93 |
+
fused_add_norm=False,
|
| 94 |
+
residual_in_fp32=False,
|
| 95 |
+
device=None,
|
| 96 |
+
dtype=None,
|
| 97 |
+
) -> None:
|
| 98 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 101 |
+
|
| 102 |
+
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
| 103 |
+
|
| 104 |
+
# We change the order of residual and layer norm:
|
| 105 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
| 106 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
| 107 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
| 108 |
+
# This is for performance reason: we can fuse add + layer_norm.
|
| 109 |
+
self.fused_add_norm = fused_add_norm
|
| 110 |
+
if self.fused_add_norm:
|
| 111 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
| 112 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
| 113 |
+
|
| 114 |
+
self.layers = nn.ModuleList(
|
| 115 |
+
[
|
| 116 |
+
create_block(
|
| 117 |
+
d_model,
|
| 118 |
+
ssm_cfg=ssm_cfg,
|
| 119 |
+
norm_epsilon=norm_epsilon,
|
| 120 |
+
rms_norm=rms_norm,
|
| 121 |
+
residual_in_fp32=residual_in_fp32,
|
| 122 |
+
fused_add_norm=fused_add_norm,
|
| 123 |
+
layer_idx=i,
|
| 124 |
+
**factory_kwargs,
|
| 125 |
+
)
|
| 126 |
+
for i in range(n_layer)
|
| 127 |
+
]
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
| 131 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
self.apply(
|
| 135 |
+
partial(
|
| 136 |
+
_init_weights,
|
| 137 |
+
n_layer=n_layer,
|
| 138 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 143 |
+
return {
|
| 144 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 145 |
+
for i, layer in enumerate(self.layers)
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
def forward(self, input_ids, inference_params=None):
|
| 149 |
+
hidden_states = self.embedding(input_ids)
|
| 150 |
+
residual = None
|
| 151 |
+
for layer in self.layers:
|
| 152 |
+
hidden_states, residual = layer(
|
| 153 |
+
hidden_states, residual, inference_params=inference_params
|
| 154 |
+
)
|
| 155 |
+
if not self.fused_add_norm:
|
| 156 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
| 157 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
| 158 |
+
else:
|
| 159 |
+
# Set prenorm=False here since we don't need the residual
|
| 160 |
+
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
|
| 161 |
+
hidden_states = fused_add_norm_fn(
|
| 162 |
+
hidden_states,
|
| 163 |
+
self.norm_f.weight,
|
| 164 |
+
self.norm_f.bias,
|
| 165 |
+
eps=self.norm_f.eps,
|
| 166 |
+
residual=residual,
|
| 167 |
+
prenorm=False,
|
| 168 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 169 |
+
)
|
| 170 |
+
return hidden_states
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
| 174 |
+
|
| 175 |
+
def __init__(
|
| 176 |
+
self,
|
| 177 |
+
d_model: int,
|
| 178 |
+
n_layer: int,
|
| 179 |
+
vocab_size: int,
|
| 180 |
+
initializer_cfg=None,
|
| 181 |
+
pad_vocab_size_multiple: int = 1,
|
| 182 |
+
device=None,
|
| 183 |
+
dtype=None,
|
| 184 |
+
**backbone_kwargs,
|
| 185 |
+
) -> None:
|
| 186 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 187 |
+
super().__init__()
|
| 188 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
| 189 |
+
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
| 190 |
+
self.backbone = MixerModel(
|
| 191 |
+
d_model=d_model,
|
| 192 |
+
n_layer=n_layer,
|
| 193 |
+
vocab_size=vocab_size,
|
| 194 |
+
initializer_cfg=initializer_cfg,
|
| 195 |
+
**backbone_kwargs,
|
| 196 |
+
**factory_kwargs,
|
| 197 |
+
)
|
| 198 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
| 199 |
+
|
| 200 |
+
# Initialize weights and apply final processing
|
| 201 |
+
self.apply(
|
| 202 |
+
partial(
|
| 203 |
+
_init_weights,
|
| 204 |
+
n_layer=n_layer,
|
| 205 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 206 |
+
)
|
| 207 |
+
)
|
| 208 |
+
self.tie_weights()
|
| 209 |
+
|
| 210 |
+
def tie_weights(self):
|
| 211 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
| 212 |
+
|
| 213 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 214 |
+
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 215 |
+
|
| 216 |
+
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
|
| 217 |
+
"""
|
| 218 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 219 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 220 |
+
"""
|
| 221 |
+
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
| 222 |
+
if num_last_tokens > 0:
|
| 223 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 224 |
+
lm_logits = self.lm_head(hidden_states)
|
| 225 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
| 226 |
+
return CausalLMOutput(logits=lm_logits)
|
| 227 |
+
|
| 228 |
+
@classmethod
|
| 229 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 230 |
+
config = load_config_hf(pretrained_model_name)
|
| 231 |
+
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
| 232 |
+
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
| 233 |
+
return model
|
mamba_ssm/modules/__init__.py
ADDED
|
File without changes
|
mamba_ssm/modules/mamba_simple.py
ADDED
|
@@ -0,0 +1,418 @@
|
| 1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 15 |
+
except ImportError:
|
| 16 |
+
causal_conv1d_fn, causal_conv1d_update = None
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
|
| 20 |
+
except ImportError:
|
| 21 |
+
selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 25 |
+
except ImportError:
|
| 26 |
+
selective_state_update = None
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 30 |
+
except ImportError:
|
| 31 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Mamba(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
d_model,
|
| 38 |
+
d_state=16,
|
| 39 |
+
d_conv=4,
|
| 40 |
+
expand=2,
|
| 41 |
+
dt_rank="auto",
|
| 42 |
+
dt_min=0.001,
|
| 43 |
+
dt_max=0.1,
|
| 44 |
+
dt_init="random",
|
| 45 |
+
dt_scale=1.0,
|
| 46 |
+
dt_init_floor=1e-4,
|
| 47 |
+
conv_bias=True,
|
| 48 |
+
bias=False,
|
| 49 |
+
use_fast_path=True, # Fused kernel options
|
| 50 |
+
layer_idx=None,
|
| 51 |
+
device=None,
|
| 52 |
+
dtype=None,
|
| 53 |
+
bimamba=True,
|
| 54 |
+
):
|
| 55 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.d_model = d_model
|
| 58 |
+
self.d_state = d_state
|
| 59 |
+
self.d_conv = d_conv
|
| 60 |
+
self.expand = expand
|
| 61 |
+
self.d_inner = int(self.expand * self.d_model)
|
| 62 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
| 63 |
+
self.use_fast_path = use_fast_path
|
| 64 |
+
self.layer_idx = layer_idx
|
| 65 |
+
self.bimamba = bimamba
|
| 66 |
+
|
| 67 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
| 68 |
+
|
| 69 |
+
self.conv1d = nn.Conv1d(
|
| 70 |
+
in_channels=self.d_inner,
|
| 71 |
+
out_channels=self.d_inner,
|
| 72 |
+
bias=conv_bias,
|
| 73 |
+
kernel_size=d_conv,
|
| 74 |
+
groups=self.d_inner,
|
| 75 |
+
padding=d_conv - 1,
|
| 76 |
+
**factory_kwargs,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
self.activation = "silu"
|
| 80 |
+
self.act = nn.SiLU()
|
| 81 |
+
|
| 82 |
+
self.x_proj = nn.Linear(
|
| 83 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
| 84 |
+
)
|
| 85 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
| 86 |
+
|
| 87 |
+
# Initialize special dt projection to preserve variance at initialization
|
| 88 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
| 89 |
+
if dt_init == "constant":
|
| 90 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 91 |
+
elif dt_init == "random":
|
| 92 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 93 |
+
else:
|
| 94 |
+
raise NotImplementedError
|
| 95 |
+
|
| 96 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
| 97 |
+
dt = torch.exp(
|
| 98 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
| 99 |
+
+ math.log(dt_min)
|
| 100 |
+
).clamp(min=dt_init_floor)
|
| 101 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 102 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 105 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 106 |
+
self.dt_proj.bias._no_reinit = True
|
| 107 |
+
|
| 108 |
+
# S4D real initialization
|
| 109 |
+
# NOTE: why plus 1?
|
| 110 |
+
A = repeat(
|
| 111 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 112 |
+
"n -> d n",
|
| 113 |
+
d=self.d_inner,
|
| 114 |
+
).contiguous()
|
| 115 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
| 116 |
+
self.A_log = nn.Parameter(A_log)
|
| 117 |
+
self.A_log._no_weight_decay = True
|
| 118 |
+
|
| 119 |
+
# D "skip" parameter
|
| 120 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 121 |
+
self.D._no_weight_decay = True
|
| 122 |
+
|
| 123 |
+
# bidirectional
|
| 124 |
+
# forked from https://github.com/hustvl/Vim
|
| 125 |
+
if self.bimamba:
|
| 126 |
+
A_b = repeat(
|
| 127 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 128 |
+
"n -> d n",
|
| 129 |
+
d=self.d_inner,
|
| 130 |
+
).contiguous()
|
| 131 |
+
A_b_log = torch.log(A_b) # Keep A_b_log in fp32
|
| 132 |
+
self.A_b_log = nn.Parameter(A_b_log)
|
| 133 |
+
self.A_b_log._no_weight_decay = True
|
| 134 |
+
|
| 135 |
+
self.conv1d_b = nn.Conv1d(
|
| 136 |
+
in_channels=self.d_inner,
|
| 137 |
+
out_channels=self.d_inner,
|
| 138 |
+
bias=conv_bias,
|
| 139 |
+
kernel_size=d_conv,
|
| 140 |
+
groups=self.d_inner,
|
| 141 |
+
padding=d_conv - 1,
|
| 142 |
+
**factory_kwargs,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.x_proj_b = nn.Linear(
|
| 146 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
| 147 |
+
)
|
| 148 |
+
self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
| 149 |
+
|
| 150 |
+
self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 151 |
+
self.D_b._no_weight_decay = True
|
| 152 |
+
|
| 153 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
| 154 |
+
|
| 155 |
+
def forward(self, hidden_states, inference_params=None, T=1):
|
| 156 |
+
"""
|
| 157 |
+
hidden_states: (B, L, D)
|
| 158 |
+
Returns: same shape as hidden_states
|
| 159 |
+
"""
|
| 160 |
+
batch, seqlen, dim = hidden_states.shape
|
| 161 |
+
|
| 162 |
+
conv_state, ssm_state = None, None
|
| 163 |
+
if inference_params is not None:
|
| 164 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 165 |
+
if inference_params.seqlen_offset > 0:
|
| 166 |
+
# The states are updated inplace
|
| 167 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
| 168 |
+
return out
|
| 169 |
+
|
| 170 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
| 171 |
+
# NOTE: same as in_proj(hidden_states) but memory-efficient with the following operations
|
| 172 |
+
xz = rearrange(
|
| 173 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
| 174 |
+
"d (b l) -> b d l",
|
| 175 |
+
l=seqlen,
|
| 176 |
+
)
|
| 177 |
+
if self.in_proj.bias is not None:
|
| 178 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
| 179 |
+
|
| 180 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 181 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
| 182 |
+
if self.use_fast_path and inference_params is None: # Doesn't support outputting the states
|
| 183 |
+
if self.bimamba:
|
| 184 |
+
A_b = -torch.exp(self.A_b_log.float())
|
| 185 |
+
out = mamba_inner_fn_no_out_proj(
|
| 186 |
+
xz,
|
| 187 |
+
self.conv1d.weight,
|
| 188 |
+
self.conv1d.bias,
|
| 189 |
+
self.x_proj.weight,
|
| 190 |
+
self.dt_proj.weight,
|
| 191 |
+
A,
|
| 192 |
+
None, # input-dependent B
|
| 193 |
+
None, # input-dependent C
|
| 194 |
+
self.D.float(),
|
| 195 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 196 |
+
delta_softplus=True,
|
| 197 |
+
)
|
| 198 |
+
out_b = mamba_inner_fn_no_out_proj(
|
| 199 |
+
xz.flip([-1]),
|
| 200 |
+
self.conv1d_b.weight,
|
| 201 |
+
self.conv1d_b.bias,
|
| 202 |
+
self.x_proj_b.weight,
|
| 203 |
+
self.dt_proj_b.weight,
|
| 204 |
+
A_b,
|
| 205 |
+
None,
|
| 206 |
+
None,
|
| 207 |
+
self.D_b.float(),
|
| 208 |
+
delta_bias=self.dt_proj_b.bias.float(),
|
| 209 |
+
delta_softplus=True,
|
| 210 |
+
)
|
| 211 |
+
out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
|
| 212 |
+
else:
|
| 213 |
+
out = mamba_inner_fn(
|
| 214 |
+
xz,
|
| 215 |
+
self.conv1d.weight,
|
| 216 |
+
self.conv1d.bias,
|
| 217 |
+
self.x_proj.weight,
|
| 218 |
+
self.dt_proj.weight,
|
| 219 |
+
self.out_proj.weight,
|
| 220 |
+
self.out_proj.bias,
|
| 221 |
+
A,
|
| 222 |
+
None, # input-dependent B
|
| 223 |
+
None, # input-dependent C
|
| 224 |
+
self.D.float(),
|
| 225 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 226 |
+
delta_softplus=True,
|
| 227 |
+
)
|
| 228 |
+
else:
|
| 229 |
+
x, z = xz.chunk(2, dim=1)
|
| 230 |
+
# Compute short convolution
|
| 231 |
+
if conv_state is not None:
|
| 232 |
+
conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W)
|
| 233 |
+
if causal_conv1d_fn is None:
|
| 234 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
| 235 |
+
else:
|
| 236 |
+
assert self.activation in ["silu", "swish"]
|
| 237 |
+
x = causal_conv1d_fn(
|
| 238 |
+
x,
|
| 239 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 240 |
+
self.conv1d.bias,
|
| 241 |
+
self.activation,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# We're careful here about the layout, to avoid extra transposes.
|
| 245 |
+
# We want dt to have d as the slowest moving dimension
|
| 246 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
| 247 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
| 248 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 249 |
+
dt = self.dt_proj.weight @ dt.t()
|
| 250 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
| 251 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 252 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
| 253 |
+
assert self.activation in ["silu", "swish"]
|
| 254 |
+
y = selective_scan_fn(
|
| 255 |
+
x,
|
| 256 |
+
dt,
|
| 257 |
+
A,
|
| 258 |
+
B,
|
| 259 |
+
C,
|
| 260 |
+
self.D.float(),
|
| 261 |
+
z=z,
|
| 262 |
+
delta_bias=self.dt_proj.bias.float(),
|
| 263 |
+
delta_softplus=True,
|
| 264 |
+
return_last_state=ssm_state is not None,
|
| 265 |
+
)
|
| 266 |
+
if ssm_state is not None:
|
| 267 |
+
y, last_state = y
|
| 268 |
+
ssm_state.copy_(last_state)
|
| 269 |
+
y = rearrange(y, "b d l -> b l d")
|
| 270 |
+
out = self.out_proj(y)
|
| 271 |
+
return out
|
| 272 |
+
|
| 273 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
| 274 |
+
dtype = hidden_states.dtype
|
| 275 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
| 276 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
| 277 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
| 278 |
+
|
| 279 |
+
# Conv step
|
| 280 |
+
if causal_conv1d_update is None:
|
| 281 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
| 282 |
+
conv_state[:, :, -1] = x
|
| 283 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
| 284 |
+
if self.conv1d.bias is not None:
|
| 285 |
+
x = x + self.conv1d.bias
|
| 286 |
+
x = self.act(x).to(dtype=dtype)
|
| 287 |
+
else:
|
| 288 |
+
x = causal_conv1d_update(
|
| 289 |
+
x,
|
| 290 |
+
conv_state,
|
| 291 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
| 292 |
+
self.conv1d.bias,
|
| 293 |
+
self.activation,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
| 297 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 298 |
+
# Don't add dt_bias here
|
| 299 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
| 300 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 301 |
+
|
| 302 |
+
# SSM step
|
| 303 |
+
if selective_state_update is None:
|
| 304 |
+
# Discretize A and B
|
| 305 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
| 306 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
| 307 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
| 308 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
| 309 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
| 310 |
+
y = y + self.D.to(dtype) * x
|
| 311 |
+
y = y * self.act(z) # (B D)
|
| 312 |
+
else:
|
| 313 |
+
y = selective_state_update(
|
| 314 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
out = self.out_proj(y)
|
| 318 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
| 319 |
+
|
| 320 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 321 |
+
device = self.out_proj.weight.device
|
| 322 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
| 323 |
+
conv_state = torch.zeros(
|
| 324 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
| 325 |
+
)
|
| 326 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
| 327 |
+
# ssm_dtype = torch.float32
|
| 328 |
+
ssm_state = torch.zeros(
|
| 329 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
| 330 |
+
)
|
| 331 |
+
return conv_state, ssm_state
|
| 332 |
+
|
| 333 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
| 334 |
+
assert self.layer_idx is not None
|
| 335 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
| 336 |
+
batch_shape = (batch_size,)
|
| 337 |
+
conv_state = torch.zeros(
|
| 338 |
+
batch_size,
|
| 339 |
+
self.d_model * self.expand,
|
| 340 |
+
self.d_conv,
|
| 341 |
+
device=self.conv1d.weight.device,
|
| 342 |
+
dtype=self.conv1d.weight.dtype,
|
| 343 |
+
)
|
| 344 |
+
ssm_state = torch.zeros(
|
| 345 |
+
batch_size,
|
| 346 |
+
self.d_model * self.expand,
|
| 347 |
+
self.d_state,
|
| 348 |
+
device=self.dt_proj.weight.device,
|
| 349 |
+
dtype=self.dt_proj.weight.dtype,
|
| 350 |
+
# dtype=torch.float32,
|
| 351 |
+
)
|
| 352 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
| 353 |
+
else:
|
| 354 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
| 355 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
| 356 |
+
if initialize_states:
|
| 357 |
+
conv_state.zero_()
|
| 358 |
+
ssm_state.zero_()
|
| 359 |
+
return conv_state, ssm_state
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class Block(nn.Module):
|
| 363 |
+
def __init__(
|
| 364 |
+
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
| 365 |
+
):
|
| 366 |
+
"""
|
| 367 |
+
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
| 368 |
+
|
| 369 |
+
This Block has a slightly different structure compared to a regular
|
| 370 |
+
prenorm Transformer block.
|
| 371 |
+
The standard block is: LN -> MHA/MLP -> Add.
|
| 372 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
| 373 |
+
Here we have: Add -> LN -> Mixer, returning both
|
| 374 |
+
the hidden_states (output of the mixer) and the residual.
|
| 375 |
+
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
| 376 |
+
The residual needs to be provided (except for the very first block).
|
| 377 |
+
"""
|
| 378 |
+
super().__init__()
|
| 379 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 380 |
+
self.fused_add_norm = fused_add_norm
|
| 381 |
+
self.mixer = mixer_cls(dim)
|
| 382 |
+
self.norm = norm_cls(dim)
|
| 383 |
+
if self.fused_add_norm:
|
| 384 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
| 385 |
+
assert isinstance(
|
| 386 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
| 387 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
| 388 |
+
|
| 389 |
+
def forward(
|
| 390 |
+
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|
| 391 |
+
):
|
| 392 |
+
r"""Pass the input through the encoder layer.
|
| 393 |
+
|
| 394 |
+
Args:
|
| 395 |
+
hidden_states: the sequence to the encoder layer (required).
|
| 396 |
+
residual: hidden_states = Mixer(LN(residual))
|
| 397 |
+
"""
|
| 398 |
+
if not self.fused_add_norm:
|
| 399 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
| 400 |
+
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
| 401 |
+
if self.residual_in_fp32:
|
| 402 |
+
residual = residual.to(torch.float32)
|
| 403 |
+
else:
|
| 404 |
+
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
|
| 405 |
+
hidden_states, residual = fused_add_norm_fn(
|
| 406 |
+
hidden_states,
|
| 407 |
+
self.norm.weight,
|
| 408 |
+
self.norm.bias,
|
| 409 |
+
residual=residual,
|
| 410 |
+
prenorm=True,
|
| 411 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 412 |
+
eps=self.norm.eps,
|
| 413 |
+
)
|
| 414 |
+
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
|
| 415 |
+
return hidden_states, residual
|
| 416 |
+
|
| 417 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 418 |
+
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
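The Block docstring above describes the Add -> LN -> Mixer ordering used instead of the usual LN -> Mixer -> Add. Below is a minimal sketch of that pre-norm residual pattern in plain PyTorch, mirroring the non-fused path of Block.forward; the Linear layer is an illustrative stand-in for the Mamba mixer, not the library's implementation.

import torch
import torch.nn as nn

norm = nn.LayerNorm(64)
mixer = nn.Linear(64, 64)            # stand-in for the Mamba mixer

def block_step(hidden_states, residual=None):
    # Add -> LN -> Mixer: the residual carries the running sum of branch outputs
    residual = hidden_states if residual is None else hidden_states + residual
    hidden_states = mixer(norm(residual))
    return hidden_states, residual

h, r = block_step(torch.randn(2, 16, 64))
h, r = block_step(h, r)              # subsequent blocks receive both tensors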
mamba_ssm/ops/__init__.py
ADDED
@@ -0,0 +1,7 @@

import torch
def selective_scan(*args, **kwargs):
    for a in args:
        if isinstance(a, torch.Tensor):
            return a
    return torch.tensor(0.0)
mamba_ssm/ops/selective_scan_interface.py
ADDED
@@ -0,0 +1,17 @@

import torch
def selective_scan_fn(*args, **kwargs):
    for a in args:
        if isinstance(a, torch.Tensor):
            return a
    return torch.tensor(0.0)
def mamba_inner_fn(*args, **kwargs):
    for a in args:
        if isinstance(a, torch.Tensor):
            return a
    return torch.tensor(0.0)
def bimamba_inner_fn(*args, **kwargs):
    for a in args:
        if isinstance(a, torch.Tensor):
            return a
    return torch.tensor(0.0)
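The stubs above simply return their first tensor argument so that the imports in mamba_simple.py succeed on CPU. For reference, a naive recurrence of the kind a real selective scan computes is sketched below; the shapes are assumptions, and the optional z gating and delta softplus handled by the fused CUDA kernels are omitted.

import torch

def selective_scan_ref(u, delta, A, B, C, D):
    # u, delta: (batch, d, L); A: (d, n); B, C: (batch, n, L); D: (d,)
    batch, d, L = u.shape
    n = A.shape[1]
    x = torch.zeros(batch, d, n, dtype=u.dtype, device=u.device)
    ys = []
    for t in range(L):
        dt = delta[:, :, t].unsqueeze(-1)                  # (batch, d, 1)
        dA = torch.exp(dt * A)                             # discretized state transition
        dBu = dt * B[:, :, t].unsqueeze(1) * u[:, :, t].unsqueeze(-1)
        x = dA * x + dBu                                   # state update
        y = (x * C[:, :, t].unsqueeze(1)).sum(-1)          # readout, (batch, d)
        ys.append(y)
    y = torch.stack(ys, dim=-1)                            # (batch, d, L)
    return y + D.unsqueeze(0).unsqueeze(-1) * u            # D "skip" connection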
mamba_ssm/ops/triton/__init__.py
ADDED
File without changes
mamba_ssm/ops/triton/layernorm.py
ADDED
|
@@ -0,0 +1,636 @@
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
# Implement residual + layer_norm / rms_norm.
|
| 3 |
+
|
| 4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
| 5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
| 6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
| 7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.cuda.amp import custom_fwd, custom_bwd
|
| 14 |
+
|
| 15 |
+
import triton
|
| 16 |
+
import triton.language as tl
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
| 20 |
+
dtype = x.dtype
|
| 21 |
+
if upcast:
|
| 22 |
+
weight = weight.float()
|
| 23 |
+
bias = bias.float() if bias is not None else None
|
| 24 |
+
if upcast:
|
| 25 |
+
x = x.float()
|
| 26 |
+
residual = residual.float() if residual is not None else residual
|
| 27 |
+
if residual is not None:
|
| 28 |
+
x = (x + residual).to(x.dtype)
|
| 29 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
| 30 |
+
dtype
|
| 31 |
+
)
|
| 32 |
+
return out if not prenorm else (out, x)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
| 36 |
+
dtype = x.dtype
|
| 37 |
+
if upcast:
|
| 38 |
+
weight = weight.float()
|
| 39 |
+
bias = bias.float() if bias is not None else None
|
| 40 |
+
if upcast:
|
| 41 |
+
x = x.float()
|
| 42 |
+
residual = residual.float() if residual is not None else residual
|
| 43 |
+
if residual is not None:
|
| 44 |
+
x = (x + residual).to(x.dtype)
|
| 45 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
| 46 |
+
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
| 47 |
+
out = out.to(dtype)
|
| 48 |
+
return out if not prenorm else (out, x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@triton.autotune(
|
| 52 |
+
configs=[
|
| 53 |
+
triton.Config({}, num_warps=1),
|
| 54 |
+
triton.Config({}, num_warps=2),
|
| 55 |
+
triton.Config({}, num_warps=4),
|
| 56 |
+
triton.Config({}, num_warps=8),
|
| 57 |
+
triton.Config({}, num_warps=16),
|
| 58 |
+
triton.Config({}, num_warps=32),
|
| 59 |
+
],
|
| 60 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
| 61 |
+
)
|
| 62 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 63 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
| 64 |
+
@triton.jit
|
| 65 |
+
def _layer_norm_fwd_1pass_kernel(
|
| 66 |
+
X, # pointer to the input
|
| 67 |
+
Y, # pointer to the output
|
| 68 |
+
W, # pointer to the weights
|
| 69 |
+
B, # pointer to the biases
|
| 70 |
+
RESIDUAL, # pointer to the residual
|
| 71 |
+
RESIDUAL_OUT, # pointer to the residual
|
| 72 |
+
Mean, # pointer to the mean
|
| 73 |
+
Rstd, # pointer to the 1/std
|
| 74 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 75 |
+
stride_y_row,
|
| 76 |
+
stride_res_row,
|
| 77 |
+
stride_res_out_row,
|
| 78 |
+
N, # number of columns in X
|
| 79 |
+
eps, # epsilon to avoid division by zero
|
| 80 |
+
IS_RMS_NORM: tl.constexpr,
|
| 81 |
+
BLOCK_N: tl.constexpr,
|
| 82 |
+
HAS_RESIDUAL: tl.constexpr,
|
| 83 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
| 84 |
+
HAS_BIAS: tl.constexpr,
|
| 85 |
+
):
|
| 86 |
+
# Map the program id to the row of X and Y it should compute.
|
| 87 |
+
row = tl.program_id(0)
|
| 88 |
+
X += row * stride_x_row
|
| 89 |
+
Y += row * stride_y_row
|
| 90 |
+
if HAS_RESIDUAL:
|
| 91 |
+
RESIDUAL += row * stride_res_row
|
| 92 |
+
if STORE_RESIDUAL_OUT:
|
| 93 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
| 94 |
+
# Compute mean and variance
|
| 95 |
+
cols = tl.arange(0, BLOCK_N)
|
| 96 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 97 |
+
if HAS_RESIDUAL:
|
| 98 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
| 99 |
+
x += residual
|
| 100 |
+
if STORE_RESIDUAL_OUT:
|
| 101 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
| 102 |
+
if not IS_RMS_NORM:
|
| 103 |
+
mean = tl.sum(x, axis=0) / N
|
| 104 |
+
tl.store(Mean + row, mean)
|
| 105 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
| 106 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 107 |
+
else:
|
| 108 |
+
xbar = tl.where(cols < N, x, 0.0)
|
| 109 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
| 110 |
+
rstd = 1 / tl.sqrt(var + eps)
|
| 111 |
+
tl.store(Rstd + row, rstd)
|
| 112 |
+
# Normalize and apply linear transformation
|
| 113 |
+
mask = cols < N
|
| 114 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 115 |
+
if HAS_BIAS:
|
| 116 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
| 117 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 118 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
| 119 |
+
# Write output
|
| 120 |
+
tl.store(Y + cols, y, mask=mask)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _layer_norm_fwd(
|
| 124 |
+
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
| 125 |
+
):
|
| 126 |
+
if residual is not None:
|
| 127 |
+
residual_dtype = residual.dtype
|
| 128 |
+
M, N = x.shape
|
| 129 |
+
assert x.stride(-1) == 1
|
| 130 |
+
if residual is not None:
|
| 131 |
+
assert residual.stride(-1) == 1
|
| 132 |
+
assert residual.shape == (M, N)
|
| 133 |
+
assert weight.shape == (N,)
|
| 134 |
+
assert weight.stride(-1) == 1
|
| 135 |
+
if bias is not None:
|
| 136 |
+
assert bias.stride(-1) == 1
|
| 137 |
+
assert bias.shape == (N,)
|
| 138 |
+
# allocate output
|
| 139 |
+
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
| 140 |
+
assert y.stride(-1) == 1
|
| 141 |
+
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
| 142 |
+
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
| 143 |
+
assert residual_out.stride(-1) == 1
|
| 144 |
+
else:
|
| 145 |
+
residual_out = None
|
| 146 |
+
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
|
| 147 |
+
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
| 148 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 149 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 150 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 151 |
+
if N > BLOCK_N:
|
| 152 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 153 |
+
# heuristics for number of warps
|
| 154 |
+
with torch.cuda.device(x.device.index):
|
| 155 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
| 156 |
+
x,
|
| 157 |
+
y,
|
| 158 |
+
weight,
|
| 159 |
+
bias,
|
| 160 |
+
residual,
|
| 161 |
+
residual_out,
|
| 162 |
+
mean,
|
| 163 |
+
rstd,
|
| 164 |
+
x.stride(0),
|
| 165 |
+
y.stride(0),
|
| 166 |
+
residual.stride(0) if residual is not None else 0,
|
| 167 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
| 168 |
+
N,
|
| 169 |
+
eps,
|
| 170 |
+
is_rms_norm,
|
| 171 |
+
BLOCK_N,
|
| 172 |
+
residual is not None,
|
| 173 |
+
residual_out is not None,
|
| 174 |
+
bias is not None,
|
| 175 |
+
)
|
| 176 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype
|
| 177 |
+
return y, mean, rstd, residual_out if residual_out is not None else x
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@triton.autotune(
|
| 181 |
+
configs=[
|
| 182 |
+
triton.Config({}, num_warps=1),
|
| 183 |
+
triton.Config({}, num_warps=2),
|
| 184 |
+
triton.Config({}, num_warps=4),
|
| 185 |
+
triton.Config({}, num_warps=8),
|
| 186 |
+
triton.Config({}, num_warps=16),
|
| 187 |
+
triton.Config({}, num_warps=32),
|
| 188 |
+
],
|
| 189 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
| 190 |
+
)
|
| 191 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
| 192 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
| 193 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
| 194 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
| 195 |
+
@triton.jit
|
| 196 |
+
def _layer_norm_bwd_kernel(
|
| 197 |
+
X, # pointer to the input
|
| 198 |
+
W, # pointer to the weights
|
| 199 |
+
B, # pointer to the biases
|
| 200 |
+
Y, # pointer to the output to be recomputed
|
| 201 |
+
DY, # pointer to the output gradient
|
| 202 |
+
DX, # pointer to the input gradient
|
| 203 |
+
DW, # pointer to the partial sum of weights gradient
|
| 204 |
+
DB, # pointer to the partial sum of biases gradient
|
| 205 |
+
DRESIDUAL,
|
| 206 |
+
DRESIDUAL_IN,
|
| 207 |
+
Mean, # pointer to the mean
|
| 208 |
+
Rstd, # pointer to the 1/std
|
| 209 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
| 210 |
+
stride_y_row,
|
| 211 |
+
stride_dy_row,
|
| 212 |
+
stride_dx_row,
|
| 213 |
+
stride_dres_row,
|
| 214 |
+
stride_dres_in_row,
|
| 215 |
+
M, # number of rows in X
|
| 216 |
+
N, # number of columns in X
|
| 217 |
+
eps, # epsilon to avoid division by zero
|
| 218 |
+
rows_per_program,
|
| 219 |
+
IS_RMS_NORM: tl.constexpr,
|
| 220 |
+
BLOCK_N: tl.constexpr,
|
| 221 |
+
HAS_DRESIDUAL: tl.constexpr,
|
| 222 |
+
STORE_DRESIDUAL: tl.constexpr,
|
| 223 |
+
HAS_BIAS: tl.constexpr,
|
| 224 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
| 225 |
+
):
|
| 226 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
| 227 |
+
row_block_id = tl.program_id(0)
|
| 228 |
+
row_start = row_block_id * rows_per_program
|
| 229 |
+
cols = tl.arange(0, BLOCK_N)
|
| 230 |
+
mask = cols < N
|
| 231 |
+
X += row_start * stride_x_row
|
| 232 |
+
if HAS_DRESIDUAL:
|
| 233 |
+
DRESIDUAL += row_start * stride_dres_row
|
| 234 |
+
if STORE_DRESIDUAL:
|
| 235 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
| 236 |
+
DY += row_start * stride_dy_row
|
| 237 |
+
DX += row_start * stride_dx_row
|
| 238 |
+
if RECOMPUTE_OUTPUT:
|
| 239 |
+
Y += row_start * stride_y_row
|
| 240 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
| 241 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
| 242 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
| 243 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 244 |
+
if HAS_BIAS:
|
| 245 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
| 246 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
| 247 |
+
for row in range(row_start, row_end):
|
| 248 |
+
# Load data to SRAM
|
| 249 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
| 250 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
| 251 |
+
if not IS_RMS_NORM:
|
| 252 |
+
mean = tl.load(Mean + row)
|
| 253 |
+
rstd = tl.load(Rstd + row)
|
| 254 |
+
# Compute dx
|
| 255 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
| 256 |
+
xhat = tl.where(mask, xhat, 0.0)
|
| 257 |
+
if RECOMPUTE_OUTPUT:
|
| 258 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
| 259 |
+
tl.store(Y + cols, y, mask=mask)
|
| 260 |
+
wdy = w * dy
|
| 261 |
+
dw += dy * xhat
|
| 262 |
+
if HAS_BIAS:
|
| 263 |
+
db += dy
|
| 264 |
+
if not IS_RMS_NORM:
|
| 265 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 266 |
+
c2 = tl.sum(wdy, axis=0) / N
|
| 267 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
| 268 |
+
else:
|
| 269 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
| 270 |
+
dx = (wdy - xhat * c1) * rstd
|
| 271 |
+
if HAS_DRESIDUAL:
|
| 272 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
| 273 |
+
dx += dres
|
| 274 |
+
# Write dx
|
| 275 |
+
if STORE_DRESIDUAL:
|
| 276 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
| 277 |
+
tl.store(DX + cols, dx, mask=mask)
|
| 278 |
+
|
| 279 |
+
X += stride_x_row
|
| 280 |
+
if HAS_DRESIDUAL:
|
| 281 |
+
DRESIDUAL += stride_dres_row
|
| 282 |
+
if STORE_DRESIDUAL:
|
| 283 |
+
DRESIDUAL_IN += stride_dres_in_row
|
| 284 |
+
if RECOMPUTE_OUTPUT:
|
| 285 |
+
Y += stride_y_row
|
| 286 |
+
DY += stride_dy_row
|
| 287 |
+
DX += stride_dx_row
|
| 288 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
| 289 |
+
if HAS_BIAS:
|
| 290 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _layer_norm_bwd(
|
| 294 |
+
dy,
|
| 295 |
+
x,
|
| 296 |
+
weight,
|
| 297 |
+
bias,
|
| 298 |
+
eps,
|
| 299 |
+
mean,
|
| 300 |
+
rstd,
|
| 301 |
+
dresidual=None,
|
| 302 |
+
has_residual=False,
|
| 303 |
+
is_rms_norm=False,
|
| 304 |
+
x_dtype=None,
|
| 305 |
+
recompute_output=False,
|
| 306 |
+
):
|
| 307 |
+
M, N = x.shape
|
| 308 |
+
assert x.stride(-1) == 1
|
| 309 |
+
assert dy.stride(-1) == 1
|
| 310 |
+
assert dy.shape == (M, N)
|
| 311 |
+
if dresidual is not None:
|
| 312 |
+
assert dresidual.stride(-1) == 1
|
| 313 |
+
assert dresidual.shape == (M, N)
|
| 314 |
+
assert weight.shape == (N,)
|
| 315 |
+
assert weight.stride(-1) == 1
|
| 316 |
+
if bias is not None:
|
| 317 |
+
assert bias.stride(-1) == 1
|
| 318 |
+
assert bias.shape == (N,)
|
| 319 |
+
# allocate output
|
| 320 |
+
dx = (
|
| 321 |
+
torch.empty_like(x)
|
| 322 |
+
if x_dtype is None
|
| 323 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
| 324 |
+
)
|
| 325 |
+
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
|
| 326 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
| 327 |
+
|
| 328 |
+
# Less than 64KB per feature: enqueue fused kernel
|
| 329 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
| 330 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
| 331 |
+
if N > BLOCK_N:
|
| 332 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
| 333 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
| 334 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
| 335 |
+
_db = (
|
| 336 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
| 337 |
+
if bias is not None
|
| 338 |
+
else None
|
| 339 |
+
)
|
| 340 |
+
rows_per_program = math.ceil(M / sm_count)
|
| 341 |
+
grid = (sm_count,)
|
| 342 |
+
with torch.cuda.device(x.device.index):
|
| 343 |
+
_layer_norm_bwd_kernel[grid](
|
| 344 |
+
x,
|
| 345 |
+
weight,
|
| 346 |
+
bias,
|
| 347 |
+
y,
|
| 348 |
+
dy,
|
| 349 |
+
dx,
|
| 350 |
+
_dw,
|
| 351 |
+
_db,
|
| 352 |
+
dresidual,
|
| 353 |
+
dresidual_in,
|
| 354 |
+
mean,
|
| 355 |
+
rstd,
|
| 356 |
+
x.stride(0),
|
| 357 |
+
0 if not recompute_output else y.stride(0),
|
| 358 |
+
dy.stride(0),
|
| 359 |
+
dx.stride(0),
|
| 360 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
| 361 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
| 362 |
+
M,
|
| 363 |
+
N,
|
| 364 |
+
eps,
|
| 365 |
+
rows_per_program,
|
| 366 |
+
is_rms_norm,
|
| 367 |
+
BLOCK_N,
|
| 368 |
+
dresidual is not None,
|
| 369 |
+
dresidual_in is not None,
|
| 370 |
+
bias is not None,
|
| 371 |
+
)
|
| 372 |
+
dw = _dw.sum(0).to(weight.dtype)
|
| 373 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
| 374 |
+
# Don't need to compute dresidual_in separately in this case
|
| 375 |
+
if has_residual and dx.dtype == x.dtype:
|
| 376 |
+
dresidual_in = dx
|
| 377 |
+
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class LayerNormFn(torch.autograd.Function):
|
| 381 |
+
@staticmethod
|
| 382 |
+
def forward(
|
| 383 |
+
ctx,
|
| 384 |
+
x,
|
| 385 |
+
weight,
|
| 386 |
+
bias,
|
| 387 |
+
residual=None,
|
| 388 |
+
eps=1e-6,
|
| 389 |
+
prenorm=False,
|
| 390 |
+
residual_in_fp32=False,
|
| 391 |
+
is_rms_norm=False,
|
| 392 |
+
):
|
| 393 |
+
x_shape_og = x.shape
|
| 394 |
+
# reshape input data into 2D tensor
|
| 395 |
+
x = x.reshape(-1, x.shape[-1])
|
| 396 |
+
if x.stride(-1) != 1:
|
| 397 |
+
x = x.contiguous()
|
| 398 |
+
if residual is not None:
|
| 399 |
+
assert residual.shape == x_shape_og
|
| 400 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 401 |
+
if residual.stride(-1) != 1:
|
| 402 |
+
residual = residual.contiguous()
|
| 403 |
+
weight = weight.contiguous()
|
| 404 |
+
if bias is not None:
|
| 405 |
+
bias = bias.contiguous()
|
| 406 |
+
residual_dtype = (
|
| 407 |
+
residual.dtype
|
| 408 |
+
if residual is not None
|
| 409 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 410 |
+
)
|
| 411 |
+
y, mean, rstd, residual_out = _layer_norm_fwd(
|
| 412 |
+
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
|
| 413 |
+
)
|
| 414 |
+
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
|
| 415 |
+
ctx.x_shape_og = x_shape_og
|
| 416 |
+
ctx.eps = eps
|
| 417 |
+
ctx.is_rms_norm = is_rms_norm
|
| 418 |
+
ctx.has_residual = residual is not None
|
| 419 |
+
ctx.prenorm = prenorm
|
| 420 |
+
ctx.x_dtype = x.dtype
|
| 421 |
+
y = y.reshape(x_shape_og)
|
| 422 |
+
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
| 423 |
+
|
| 424 |
+
@staticmethod
|
| 425 |
+
def backward(ctx, dy, *args):
|
| 426 |
+
x, weight, bias, mean, rstd = ctx.saved_tensors
|
| 427 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
| 428 |
+
if dy.stride(-1) != 1:
|
| 429 |
+
dy = dy.contiguous()
|
| 430 |
+
assert dy.shape == x.shape
|
| 431 |
+
if ctx.prenorm:
|
| 432 |
+
dresidual = args[0]
|
| 433 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 434 |
+
if dresidual.stride(-1) != 1:
|
| 435 |
+
dresidual = dresidual.contiguous()
|
| 436 |
+
assert dresidual.shape == x.shape
|
| 437 |
+
else:
|
| 438 |
+
dresidual = None
|
| 439 |
+
dx, dw, db, dresidual_in = _layer_norm_bwd(
|
| 440 |
+
dy,
|
| 441 |
+
x,
|
| 442 |
+
weight,
|
| 443 |
+
bias,
|
| 444 |
+
ctx.eps,
|
| 445 |
+
mean,
|
| 446 |
+
rstd,
|
| 447 |
+
dresidual,
|
| 448 |
+
ctx.has_residual,
|
| 449 |
+
ctx.is_rms_norm,
|
| 450 |
+
x_dtype=ctx.x_dtype,
|
| 451 |
+
)
|
| 452 |
+
return (
|
| 453 |
+
dx.reshape(ctx.x_shape_og),
|
| 454 |
+
dw,
|
| 455 |
+
db,
|
| 456 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 457 |
+
None,
|
| 458 |
+
None,
|
| 459 |
+
None,
|
| 460 |
+
None,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def layer_norm_fn(
|
| 465 |
+
x,
|
| 466 |
+
weight,
|
| 467 |
+
bias,
|
| 468 |
+
residual=None,
|
| 469 |
+
eps=1e-6,
|
| 470 |
+
prenorm=False,
|
| 471 |
+
residual_in_fp32=False,
|
| 472 |
+
is_rms_norm=False,
|
| 473 |
+
):
|
| 474 |
+
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
|
| 478 |
+
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class RMSNorm(torch.nn.Module):
|
| 482 |
+
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
| 483 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 484 |
+
super().__init__()
|
| 485 |
+
self.eps = eps
|
| 486 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
| 487 |
+
self.register_parameter("bias", None)
|
| 488 |
+
self.reset_parameters()
|
| 489 |
+
|
| 490 |
+
def reset_parameters(self):
|
| 491 |
+
torch.nn.init.ones_(self.weight)
|
| 492 |
+
|
| 493 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
| 494 |
+
return rms_norm_fn(
|
| 495 |
+
x,
|
| 496 |
+
self.weight,
|
| 497 |
+
self.bias,
|
| 498 |
+
residual=residual,
|
| 499 |
+
eps=self.eps,
|
| 500 |
+
prenorm=prenorm,
|
| 501 |
+
residual_in_fp32=residual_in_fp32,
|
| 502 |
+
is_rms_norm=True,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
| 507 |
+
@staticmethod
|
| 508 |
+
@custom_fwd
|
| 509 |
+
def forward(
|
| 510 |
+
ctx,
|
| 511 |
+
x,
|
| 512 |
+
norm_weight,
|
| 513 |
+
norm_bias,
|
| 514 |
+
linear_weight,
|
| 515 |
+
linear_bias,
|
| 516 |
+
residual=None,
|
| 517 |
+
eps=1e-6,
|
| 518 |
+
prenorm=False,
|
| 519 |
+
residual_in_fp32=False,
|
| 520 |
+
is_rms_norm=False,
|
| 521 |
+
):
|
| 522 |
+
x_shape_og = x.shape
|
| 523 |
+
# reshape input data into 2D tensor
|
| 524 |
+
x = x.reshape(-1, x.shape[-1])
|
| 525 |
+
if x.stride(-1) != 1:
|
| 526 |
+
x = x.contiguous()
|
| 527 |
+
if residual is not None:
|
| 528 |
+
assert residual.shape == x_shape_og
|
| 529 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
| 530 |
+
if residual.stride(-1) != 1:
|
| 531 |
+
residual = residual.contiguous()
|
| 532 |
+
norm_weight = norm_weight.contiguous()
|
| 533 |
+
if norm_bias is not None:
|
| 534 |
+
norm_bias = norm_bias.contiguous()
|
| 535 |
+
residual_dtype = (
|
| 536 |
+
residual.dtype
|
| 537 |
+
if residual is not None
|
| 538 |
+
else (torch.float32 if residual_in_fp32 else None)
|
| 539 |
+
)
|
| 540 |
+
y, mean, rstd, residual_out = _layer_norm_fwd(
|
| 541 |
+
x,
|
| 542 |
+
norm_weight,
|
| 543 |
+
norm_bias,
|
| 544 |
+
eps,
|
| 545 |
+
residual,
|
| 546 |
+
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
| 547 |
+
residual_dtype=residual_dtype,
|
| 548 |
+
is_rms_norm=is_rms_norm,
|
| 549 |
+
)
|
| 550 |
+
y = y.reshape(x_shape_og)
|
| 551 |
+
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
| 552 |
+
linear_weight = linear_weight.to(dtype)
|
| 553 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
| 554 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
| 555 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
| 556 |
+
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
| 557 |
+
ctx.x_shape_og = x_shape_og
|
| 558 |
+
ctx.eps = eps
|
| 559 |
+
ctx.is_rms_norm = is_rms_norm
|
| 560 |
+
ctx.has_residual = residual is not None
|
| 561 |
+
ctx.prenorm = prenorm
|
| 562 |
+
ctx.x_dtype = x.dtype
|
| 563 |
+
ctx.linear_bias_is_none = linear_bias is None
|
| 564 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
| 565 |
+
|
| 566 |
+
@staticmethod
|
| 567 |
+
@custom_bwd
|
| 568 |
+
def backward(ctx, dout, *args):
|
| 569 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
| 570 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
| 571 |
+
dy = F.linear(dout, linear_weight.t())
|
| 572 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
| 573 |
+
if dy.stride(-1) != 1:
|
| 574 |
+
dy = dy.contiguous()
|
| 575 |
+
assert dy.shape == x.shape
|
| 576 |
+
if ctx.prenorm:
|
| 577 |
+
dresidual = args[0]
|
| 578 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
| 579 |
+
if dresidual.stride(-1) != 1:
|
| 580 |
+
dresidual = dresidual.contiguous()
|
| 581 |
+
assert dresidual.shape == x.shape
|
| 582 |
+
else:
|
| 583 |
+
dresidual = None
|
| 584 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
| 585 |
+
dy,
|
| 586 |
+
x,
|
| 587 |
+
norm_weight,
|
| 588 |
+
norm_bias,
|
| 589 |
+
ctx.eps,
|
| 590 |
+
mean,
|
| 591 |
+
rstd,
|
| 592 |
+
dresidual,
|
| 593 |
+
ctx.has_residual,
|
| 594 |
+
ctx.is_rms_norm,
|
| 595 |
+
x_dtype=ctx.x_dtype,
|
| 596 |
+
recompute_output=True,
|
| 597 |
+
)
|
| 598 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
| 599 |
+
return (
|
| 600 |
+
dx.reshape(ctx.x_shape_og),
|
| 601 |
+
dnorm_weight,
|
| 602 |
+
dnorm_bias,
|
| 603 |
+
dlinear_weight,
|
| 604 |
+
dlinear_bias,
|
| 605 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
| 606 |
+
None,
|
| 607 |
+
None,
|
| 608 |
+
None,
|
| 609 |
+
None,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def layer_norm_linear_fn(
|
| 614 |
+
x,
|
| 615 |
+
norm_weight,
|
| 616 |
+
norm_bias,
|
| 617 |
+
linear_weight,
|
| 618 |
+
linear_bias,
|
| 619 |
+
residual=None,
|
| 620 |
+
eps=1e-6,
|
| 621 |
+
prenorm=False,
|
| 622 |
+
residual_in_fp32=False,
|
| 623 |
+
is_rms_norm=False,
|
| 624 |
+
):
|
| 625 |
+
return LayerNormLinearFn.apply(
|
| 626 |
+
x,
|
| 627 |
+
norm_weight,
|
| 628 |
+
norm_bias,
|
| 629 |
+
linear_weight,
|
| 630 |
+
linear_bias,
|
| 631 |
+
residual,
|
| 632 |
+
eps,
|
| 633 |
+
prenorm,
|
| 634 |
+
residual_in_fp32,
|
| 635 |
+
is_rms_norm,
|
| 636 |
+
)
|
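As a sanity check on the RMSNorm module defined above, a minimal sketch is shown below; it assumes a CUDA device with Triton available (the fused kernels are CUDA-only), and the plain-PyTorch reference here is purely illustrative, not part of this file.

import torch

def rms_norm_ref(x, weight, eps=1e-5):
    # Reference RMS norm: y = x / sqrt(mean(x^2, dim=-1) + eps) * weight, accumulated in fp32
    xf = x.float()
    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xf * rstd).to(x.dtype) * weight

# norm = RMSNorm(hidden_size=256, eps=1e-5, device="cuda", dtype=torch.float16)
# x = torch.randn(4, 64, 256, device="cuda", dtype=torch.float16)
# torch.testing.assert_close(norm(x), rms_norm_ref(x, norm.weight), rtol=1e-2, atol=1e-2)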
mamba_ssm/ops/triton/selective_state_update.py
ADDED
|
@@ -0,0 +1,192 @@
| 1 |
+
# Copyright (c) 2023, Tri Dao.
|
| 2 |
+
|
| 3 |
+
"""We want triton==2.1.0 for this
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
| 17 |
+
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
| 18 |
+
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
| 19 |
+
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
| 20 |
+
@triton.jit
|
| 21 |
+
def _selective_scan_update_kernel(
|
| 22 |
+
# Pointers to matrices
|
| 23 |
+
state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
|
| 24 |
+
# Matrix dimensions
|
| 25 |
+
batch, dim, dstate,
|
| 26 |
+
# Strides
|
| 27 |
+
stride_state_batch, stride_state_dim, stride_state_dstate,
|
| 28 |
+
stride_x_batch, stride_x_dim,
|
| 29 |
+
stride_dt_batch, stride_dt_dim,
|
| 30 |
+
stride_dt_bias_dim,
|
| 31 |
+
stride_A_dim, stride_A_dstate,
|
| 32 |
+
stride_B_batch, stride_B_dstate,
|
| 33 |
+
stride_C_batch, stride_C_dstate,
|
| 34 |
+
stride_D_dim,
|
| 35 |
+
stride_z_batch, stride_z_dim,
|
| 36 |
+
stride_out_batch, stride_out_dim,
|
| 37 |
+
# Meta-parameters
|
| 38 |
+
DT_SOFTPLUS: tl.constexpr,
|
| 39 |
+
BLOCK_SIZE_M: tl.constexpr,
|
| 40 |
+
HAS_DT_BIAS: tl.constexpr,
|
| 41 |
+
HAS_D: tl.constexpr,
|
| 42 |
+
HAS_Z: tl.constexpr,
|
| 43 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
| 44 |
+
):
|
| 45 |
+
pid_m = tl.program_id(axis=0)
|
| 46 |
+
pid_b = tl.program_id(axis=1)
|
| 47 |
+
state_ptr += pid_b * stride_state_batch
|
| 48 |
+
x_ptr += pid_b * stride_x_batch
|
| 49 |
+
dt_ptr += pid_b * stride_dt_batch
|
| 50 |
+
B_ptr += pid_b * stride_B_batch
|
| 51 |
+
C_ptr += pid_b * stride_C_batch
|
| 52 |
+
if HAS_Z:
|
| 53 |
+
z_ptr += pid_b * stride_z_batch
|
| 54 |
+
out_ptr += pid_b * stride_out_batch
|
| 55 |
+
|
| 56 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 57 |
+
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
| 58 |
+
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
|
| 59 |
+
x_ptrs = x_ptr + offs_m * stride_x_dim
|
| 60 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
| 61 |
+
if HAS_DT_BIAS:
|
| 62 |
+
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
| 63 |
+
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
|
| 64 |
+
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
| 65 |
+
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
| 66 |
+
if HAS_D:
|
| 67 |
+
D_ptrs = D_ptr + offs_m * stride_D_dim
|
| 68 |
+
if HAS_Z:
|
| 69 |
+
z_ptrs = z_ptr + offs_m * stride_z_dim
|
| 70 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
| 71 |
+
|
| 72 |
+
state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
|
| 73 |
+
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 74 |
+
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 75 |
+
if HAS_DT_BIAS:
|
| 76 |
+
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 77 |
+
if DT_SOFTPLUS:
|
| 78 |
+
dt = tl.log(1.0 + tl.exp(dt))
|
| 79 |
+
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
| 80 |
+
dA = tl.exp(A * dt[:, None])
|
| 81 |
+
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
| 82 |
+
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
| 83 |
+
if HAS_D:
|
| 84 |
+
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 85 |
+
if HAS_Z:
|
| 86 |
+
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
| 87 |
+
|
| 88 |
+
dB = B[None, :] * dt[:, None]
|
| 89 |
+
state = state * dA + dB * x[:, None]
|
| 90 |
+
tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
|
| 91 |
+
out = tl.sum(state * C[None, :], axis=1)
|
| 92 |
+
if HAS_D:
|
| 93 |
+
out += x * D
|
| 94 |
+
if HAS_Z:
|
| 95 |
+
out *= z * tl.sigmoid(z)
|
| 96 |
+
tl.store(out_ptrs, out, mask=offs_m < dim)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
| 100 |
+
"""
|
| 101 |
+
Argument:
|
| 102 |
+
state: (batch, dim, dstate)
|
| 103 |
+
x: (batch, dim)
|
| 104 |
+
dt: (batch, dim)
|
| 105 |
+
A: (dim, dstate)
|
| 106 |
+
B: (batch, dstate)
|
| 107 |
+
C: (batch, dstate)
|
| 108 |
+
D: (dim,)
|
| 109 |
+
z: (batch, dim)
|
| 110 |
+
dt_bias: (dim,)
|
| 111 |
+
Return:
|
| 112 |
+
out: (batch, dim)
|
| 113 |
+
"""
|
| 114 |
+
batch, dim, dstate = state.shape
|
| 115 |
+
assert x.shape == (batch, dim)
|
| 116 |
+
assert dt.shape == x.shape
|
| 117 |
+
assert A.shape == (dim, dstate)
|
| 118 |
+
assert B.shape == (batch, dstate)
|
| 119 |
+
assert C.shape == B.shape
|
| 120 |
+
if D is not None:
|
| 121 |
+
assert D.shape == (dim,)
|
| 122 |
+
if z is not None:
|
| 123 |
+
assert z.shape == x.shape
|
| 124 |
+
if dt_bias is not None:
|
| 125 |
+
assert dt_bias.shape == (dim,)
|
| 126 |
+
out = torch.empty_like(x)
|
| 127 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
|
| 128 |
+
z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
|
| 129 |
+
# We don't want autotune since it will overwrite the state
|
| 130 |
+
# We instead tune by hand.
|
| 131 |
+
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
|
| 132 |
+
else ((16, 4) if dstate <= 32 else
|
| 133 |
+
((8, 4) if dstate <= 64 else
|
| 134 |
+
((4, 4) if dstate <= 128 else
|
| 135 |
+
((4, 8))))))
|
| 136 |
+
with torch.cuda.device(x.device.index):
|
| 137 |
+
_selective_scan_update_kernel[grid](
|
| 138 |
+
state, x, dt, dt_bias, A, B, C, D, z, out,
|
| 139 |
+
batch, dim, dstate,
|
| 140 |
+
state.stride(0), state.stride(1), state.stride(2),
|
| 141 |
+
x.stride(0), x.stride(1),
|
| 142 |
+
dt.stride(0), dt.stride(1),
|
| 143 |
+
dt_bias.stride(0) if dt_bias is not None else 0,
|
| 144 |
+
A.stride(0), A.stride(1),
|
| 145 |
+
B.stride(0), B.stride(1),
|
| 146 |
+
C.stride(0), C.stride(1),
|
| 147 |
+
D.stride(0) if D is not None else 0,
|
| 148 |
+
z_strides[0], z_strides[1],
|
| 149 |
+
out.stride(0), out.stride(1),
|
| 150 |
+
dt_softplus,
|
| 151 |
+
BLOCK_SIZE_M,
|
| 152 |
+
num_warps=num_warps,
|
| 153 |
+
)
|
| 154 |
+
return out
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
| 158 |
+
"""
|
| 159 |
+
Argument:
|
| 160 |
+
state: (batch, dim, dstate)
|
| 161 |
+
x: (batch, dim)
|
| 162 |
+
dt: (batch, dim)
|
| 163 |
+
A: (dim, dstate)
|
| 164 |
+
B: (batch, dstate)
|
| 165 |
+
C: (batch, dstate)
|
| 166 |
+
D: (dim,)
|
| 167 |
+
z: (batch, dim)
|
| 168 |
+
dt_bias: (dim,)
|
| 169 |
+
Return:
|
| 170 |
+
out: (batch, dim)
|
| 171 |
+
"""
|
| 172 |
+
batch, dim, dstate = state.shape
|
| 173 |
+
assert x.shape == (batch, dim)
|
| 174 |
+
assert dt.shape == x.shape
|
| 175 |
+
assert A.shape == (dim, dstate)
|
| 176 |
+
assert B.shape == (batch, dstate)
|
| 177 |
+
assert C.shape == B.shape
|
| 178 |
+
if D is not None:
|
| 179 |
+
assert D.shape == (dim,)
|
| 180 |
+
if z is not None:
|
| 181 |
+
assert z.shape == x.shape
|
| 182 |
+
if dt_bias is not None:
|
| 183 |
+
assert dt_bias.shape == (dim,)
|
| 184 |
+
dt = dt + dt_bias
|
| 185 |
+
dt = F.softplus(dt) if dt_softplus else dt
|
| 186 |
+
dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
|
| 187 |
+
dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
|
| 188 |
+
state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate
|
| 189 |
+
out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
|
| 190 |
+
if D is not None:
|
| 191 |
+
out += (x * D).to(out.dtype)
|
| 192 |
+
return (out if z is None else out * F.silu(z)).to(x.dtype)
|
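For orientation, here is a minimal sketch of a single recurrent step using the pure-PyTorch selective_state_update_ref defined above (shapes are illustrative; the reference runs on CPU, whereas selective_state_update itself requires CUDA and Triton).

import torch

batch, dim, dstate = 2, 16, 8              # illustrative sizes only
state = torch.zeros(batch, dim, dstate)    # SSM state carried across decoding steps
x = torch.randn(batch, dim)                # current input
dt = torch.randn(batch, dim)               # time step (passed through softplus below)
A = -torch.rand(dim, dstate)               # negative values keep the discretized state stable
B = torch.randn(batch, dstate)
C = torch.randn(batch, dstate)
out = selective_state_update_ref(state, x, dt, A, B, C, dt_softplus=True)
print(out.shape)                           # torch.Size([2, 16]); state is updated in place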
mamba_ssm/utils/__init__.py
ADDED
|
File without changes
|
mamba_ssm/utils/generation.py
ADDED
|
@@ -0,0 +1,377 @@
| 1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Callable, Optional, Sequence, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
| 14 |
+
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class InferenceParams:
|
| 19 |
+
"""Inference parameters that are passed to the main model in order
|
| 20 |
+
to efficiently calculate and store the context during inference."""
|
| 21 |
+
|
| 22 |
+
max_seqlen: int
|
| 23 |
+
max_batch_size: int
|
| 24 |
+
seqlen_offset: int = 0
|
| 25 |
+
batch_size_offset: int = 0
|
| 26 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
| 27 |
+
lengths_per_sample: Optional[Tensor] = None
|
| 28 |
+
|
| 29 |
+
def reset(self, max_seqlen, max_batch_size):
|
| 30 |
+
self.max_seqlen = max_seqlen
|
| 31 |
+
self.max_batch_size = max_batch_size
|
| 32 |
+
self.seqlen_offset = 0
|
| 33 |
+
if self.lengths_per_sample is not None:
|
| 34 |
+
self.lengths_per_sample.zero_()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
| 38 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
| 39 |
+
def modify_logits_for_top_k_filtering(logits, top_k):
|
| 40 |
+
"""Set the logits for none top-k values to -inf. Done in-place."""
|
| 41 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 42 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
| 46 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
| 47 |
+
def modify_logits_for_top_p_filtering(logits, top_p):
|
| 48 |
+
"""Set the logits for none top-p values to -inf. Done in-place."""
|
| 49 |
+
if top_p <= 0.0 or top_p >= 1.0:
|
| 50 |
+
return
|
| 51 |
+
# First sort and calculate cumulative sum of probabilities.
|
| 52 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
| 53 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 54 |
+
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
| 55 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
| 56 |
+
# scatter sorted tensors to original indexing
|
| 57 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 58 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 59 |
+
)
|
| 60 |
+
logits.masked_fill_(indices_to_remove, float("-inf"))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
| 64 |
+
"""Sample from top-k logits.
|
| 65 |
+
Arguments:
|
| 66 |
+
logits: Tensor of shape (batch_size, vocab_size)
|
| 67 |
+
"""
|
| 68 |
+
if top_k == 1: # Short-circuit for greedy decoding
|
| 69 |
+
return logits.argmax(dim=-1)
|
| 70 |
+
else:
|
| 71 |
+
if top_p > 0.0:
|
| 72 |
+
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
| 73 |
+
if top_k > 0:
|
| 74 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 75 |
+
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
| 76 |
+
if temperature != 1.0:
|
| 77 |
+
logits_top /= temperature
|
| 78 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
| 79 |
+
return indices[
|
| 80 |
+
torch.arange(indices.shape[0], device=indices.device),
|
| 81 |
+
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
| 82 |
+
]
|
| 83 |
+
else:
|
| 84 |
+
# Clone so that when we modify for top_p we don't change the original logits
|
| 85 |
+
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
| 86 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
| 87 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
| 88 |
+
dim=-1
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@torch.inference_mode()
|
| 93 |
+
def decode(
|
| 94 |
+
input_ids,
|
| 95 |
+
model,
|
| 96 |
+
max_length,
|
| 97 |
+
top_k=1,
|
| 98 |
+
top_p=0.0,
|
| 99 |
+
temperature=1.0,
|
| 100 |
+
eos_token_id=None,
|
| 101 |
+
teacher_outputs=None,
|
| 102 |
+
vocab_size=None,
|
| 103 |
+
tensor_parallel=1,
|
| 104 |
+
cg=False,
|
| 105 |
+
enable_timing=False,
|
| 106 |
+
):
|
| 107 |
+
"""Decoding, either greedy or with top-k or top-p sampling.
|
| 108 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
| 109 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
| 110 |
+
then top-p.
|
| 111 |
+
We assume that all sequences in the same batch have the same length.
|
| 112 |
+
|
| 113 |
+
Arguments:
|
| 114 |
+
input_ids: (batch, seq_len)
|
| 115 |
+
max_length: int
|
| 116 |
+
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
| 117 |
+
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
| 118 |
+
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
| 119 |
+
sequences: (batch, max_length)
|
| 120 |
+
scores: tuples of (batch, vocab_size)
|
| 121 |
+
"""
|
| 122 |
+
batch_size, seqlen_og = input_ids.shape
|
| 123 |
+
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
| 124 |
+
if cg:
|
| 125 |
+
if not hasattr(model, "_decoding_cache"):
|
| 126 |
+
model._decoding_cache = None
|
| 127 |
+
model._decoding_cache = update_graph_cache(
|
| 128 |
+
model,
|
| 129 |
+
model._decoding_cache,
|
| 130 |
+
batch_size,
|
| 131 |
+
seqlen_og,
|
| 132 |
+
max_length,
|
| 133 |
+
tensor_parallel=tensor_parallel,
|
| 134 |
+
)
|
| 135 |
+
inference_params = model._decoding_cache.inference_params
|
| 136 |
+
inference_params.reset(max_length, batch_size)
|
| 137 |
+
else:
|
| 138 |
+
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
| 139 |
+
|
| 140 |
+
def get_logits(input_ids, inference_params):
|
| 141 |
+
decoding = inference_params.seqlen_offset > 0
|
| 142 |
+
if decoding:
|
| 143 |
+
position_ids = torch.full(
|
| 144 |
+
(batch_size, 1),
|
| 145 |
+
inference_params.seqlen_offset,
|
| 146 |
+
dtype=torch.long,
|
| 147 |
+
device=input_ids.device,
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
position_ids = None
|
| 151 |
+
if not cg or not decoding:
|
| 152 |
+
logits = model(
|
| 153 |
+
input_ids,
|
| 154 |
+
position_ids=position_ids,
|
| 155 |
+
inference_params=inference_params,
|
| 156 |
+
num_last_tokens=1,
|
| 157 |
+
).logits.squeeze(dim=1)
|
| 158 |
+
else:
|
| 159 |
+
logits = model._decoding_cache.run(
|
| 160 |
+
input_ids, position_ids, inference_params.seqlen_offset
|
| 161 |
+
).squeeze(dim=1)
|
| 162 |
+
return logits[..., :vocab_size] if vocab_size is not None else logits
|
| 163 |
+
|
| 164 |
+
def sample_tokens(logits, inference_params):
|
| 165 |
+
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
|
| 166 |
+
token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
|
| 167 |
+
else:
|
| 168 |
+
token = teacher_outputs[:, inference_params.seqlen_offset]
|
| 169 |
+
# return rearrange(token, "b -> b 1")
|
| 170 |
+
return token.unsqueeze(1)
|
| 171 |
+
|
| 172 |
+
def should_stop(current_token, inference_params):
|
| 173 |
+
if inference_params.seqlen_offset == 0:
|
| 174 |
+
return False
|
| 175 |
+
if eos_token_id is not None and (current_token == eos_token_id).all():
|
| 176 |
+
return True
|
| 177 |
+
if inference_params.seqlen_offset >= max_length - 1:
|
| 178 |
+
return True
|
| 179 |
+
return False
|
| 180 |
+
|
| 181 |
+
start = torch.cuda.Event(enable_timing=enable_timing)
|
| 182 |
+
end = torch.cuda.Event(enable_timing=enable_timing)
|
| 183 |
+
|
| 184 |
+
if enable_timing:
|
| 185 |
+
if tensor_parallel > 1:
|
| 186 |
+
torch.distributed.barrier()
|
| 187 |
+
start.record()
|
| 188 |
+
scores, sequences = [], [input_ids]
|
| 189 |
+
while not should_stop(sequences[-1], inference_params):
|
| 190 |
+
scores.append(get_logits(sequences[-1], inference_params))
|
| 191 |
+
inference_params.seqlen_offset += sequences[-1].shape[1]
|
| 192 |
+
sequences.append(sample_tokens(scores[-1], inference_params))
|
| 193 |
+
if enable_timing:
|
| 194 |
+
end.record()
|
| 195 |
+
if tensor_parallel > 1:
|
| 196 |
+
torch.distributed.barrier()
|
| 197 |
+
torch.cuda.synchronize()
|
| 198 |
+
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
| 199 |
+
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
| 200 |
+
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class GenerationMixin:
|
| 204 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 205 |
+
raise NotImplementedError
|
| 206 |
+
|
| 207 |
+
def generate(
|
| 208 |
+
self,
|
| 209 |
+
input_ids,
|
| 210 |
+
max_length,
|
| 211 |
+
top_k=1,
|
| 212 |
+
top_p=0.0,
|
| 213 |
+
temperature=1.0,
|
| 214 |
+
return_dict_in_generate=False,
|
| 215 |
+
output_scores=False,
|
| 216 |
+
**kwargs,
|
| 217 |
+
):
|
| 218 |
+
output = decode(
|
| 219 |
+
input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
|
| 220 |
+
)
|
| 221 |
+
if not output_scores:
|
| 222 |
+
output.scores = None
|
| 223 |
+
return output if return_dict_in_generate else output.sequences
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def allocate_inference_cache(
|
| 227 |
+
max_batch_size,
|
| 228 |
+
max_seqlen,
|
| 229 |
+
nheads,
|
| 230 |
+
headdim,
|
| 231 |
+
layers: Union[int, Sequence],
|
| 232 |
+
device,
|
| 233 |
+
dtype=torch.float16,
|
| 234 |
+
):
|
| 235 |
+
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
|
| 236 |
+
kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
|
| 237 |
+
if isinstance(layers, int):
|
| 238 |
+
layers = range(layers)
|
| 239 |
+
return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@dataclass
|
| 243 |
+
class DecodingCGCache:
|
| 244 |
+
max_batch_size: int = 0
|
| 245 |
+
max_seqlen: int = 0
|
| 246 |
+
device = None
|
| 247 |
+
dtype = None
|
| 248 |
+
callables: dict = field(default_factory=dict)
|
| 249 |
+
mempool = None
|
| 250 |
+
inference_params: Optional[InferenceParams] = None
|
| 251 |
+
run: Optional[Callable] = None
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
@torch.inference_mode()
|
| 255 |
+
def update_graph_cache(
|
| 256 |
+
model,
|
| 257 |
+
cache,
|
| 258 |
+
batch_size,
|
| 259 |
+
seqlen_og,
|
| 260 |
+
max_seqlen,
|
| 261 |
+
decoding_seqlens=(1,),
|
| 262 |
+
tensor_parallel=1,
|
| 263 |
+
dtype=None,
|
| 264 |
+
n_warmups=2,
|
| 265 |
+
):
|
| 266 |
+
if cache is None:
|
| 267 |
+
cache = DecodingCGCache()
|
| 268 |
+
param_example = next(iter(model.parameters()))
|
| 269 |
+
device = param_example.device
|
| 270 |
+
if dtype is None:
|
| 271 |
+
dtype = param_example.dtype
|
| 272 |
+
if (
|
| 273 |
+
(device, dtype) != (cache.device, cache.dtype)
|
| 274 |
+
or batch_size > cache.max_batch_size
|
| 275 |
+
or max_seqlen > cache.max_seqlen
|
| 276 |
+
): # Invalidate the cache
|
| 277 |
+
cache.callables = {}
|
| 278 |
+
cache.mempool = None
|
| 279 |
+
cache.inference_params = None
|
| 280 |
+
gc.collect()
|
| 281 |
+
cache.device, cache.dtype = device, dtype
|
| 282 |
+
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
| 283 |
+
if hasattr(model, "allocate_inference_cache"):
|
| 284 |
+
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
| 285 |
+
else:
|
| 286 |
+
headdim = getattr(
|
| 287 |
+
model.config,
|
| 288 |
+
"head_dim",
|
| 289 |
+
model.config.hidden_size // model.config.num_attention_heads,
|
| 290 |
+
)
|
| 291 |
+
inf_cache = allocate_inference_cache(
|
| 292 |
+
batch_size,
|
| 293 |
+
max_seqlen,
|
| 294 |
+
model.config.num_attention_heads // tensor_parallel,
|
| 295 |
+
headdim,
|
| 296 |
+
model.config.num_hidden_layers,
|
| 297 |
+
device,
|
| 298 |
+
dtype,
|
| 299 |
+
)
|
| 300 |
+
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
| 301 |
+
cache.inference_params = InferenceParams(
|
| 302 |
+
max_seqlen=max_seqlen,
|
| 303 |
+
max_batch_size=batch_size,
|
| 304 |
+
seqlen_offset=seqlen_og,
|
| 305 |
+
key_value_memory_dict=inf_cache,
|
| 306 |
+
lengths_per_sample=lengths_per_sample,
|
| 307 |
+
)
|
| 308 |
+
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
| 309 |
+
for decoding_seqlen in decoding_seqlens:
|
| 310 |
+
if (batch_size, decoding_seqlen) not in cache.callables:
|
| 311 |
+
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
| 312 |
+
model,
|
| 313 |
+
cache.inference_params,
|
| 314 |
+
batch_size,
|
| 315 |
+
max_seqlen,
|
| 316 |
+
decoding_seqlen=decoding_seqlen,
|
| 317 |
+
mempool=cache.mempool,
|
| 318 |
+
n_warmups=n_warmups,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
def dispatch(input_ids, position_ids, seqlen):
|
| 322 |
+
batch_size, decoding_seqlen = input_ids.shape[:2]
|
| 323 |
+
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
| 324 |
+
|
| 325 |
+
cache.run = dispatch
|
| 326 |
+
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
| 327 |
+
return cache
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def capture_graph(
|
| 331 |
+
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
| 332 |
+
):
|
| 333 |
+
device = next(iter(model.parameters())).device
|
| 334 |
+
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
| 335 |
+
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
| 336 |
+
seqlen_offset_og = inference_params.seqlen_offset
|
| 337 |
+
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
| 338 |
+
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
| 339 |
+
|
| 340 |
+
# Warmup before capture
|
| 341 |
+
s = torch.cuda.Stream()
|
| 342 |
+
s.wait_stream(torch.cuda.current_stream())
|
| 343 |
+
with torch.cuda.stream(s):
|
| 344 |
+
for _ in range(n_warmups):
|
| 345 |
+
logits = model(
|
| 346 |
+
input_ids,
|
| 347 |
+
position_ids=position_ids,
|
| 348 |
+
inference_params=inference_params,
|
| 349 |
+
num_last_tokens=decoding_seqlen,
|
| 350 |
+
).logits
|
| 351 |
+
s.synchronize()
|
| 352 |
+
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
| 353 |
+
# which requires that graph launch and non-captured launch to not overlap (I think,
|
| 354 |
+
# that's how I interpret the documentation). I'm not sure if this is required.
|
| 355 |
+
if torch.distributed.is_initialized():
|
| 356 |
+
torch.distributed.barrier()
|
| 357 |
+
torch.cuda.current_stream().wait_stream(s)
|
| 358 |
+
# Captures the graph
|
| 359 |
+
# To allow capture, automatically sets a side stream as the current stream in the context
|
| 360 |
+
graph = torch.cuda.CUDAGraph()
|
| 361 |
+
with torch.cuda.graph(graph, pool=mempool):
|
| 362 |
+
logits = model(
|
| 363 |
+
input_ids,
|
| 364 |
+
position_ids=position_ids,
|
| 365 |
+
inference_params=inference_params,
|
| 366 |
+
num_last_tokens=decoding_seqlen,
|
| 367 |
+
).logits
|
| 368 |
+
|
| 369 |
+
def run(new_input_ids, new_position_ids, seqlen):
|
| 370 |
+
inference_params.lengths_per_sample[:] = seqlen
|
| 371 |
+
input_ids.copy_(new_input_ids)
|
| 372 |
+
position_ids.copy_(new_position_ids)
|
| 373 |
+
graph.replay()
|
| 374 |
+
return logits.clone()
|
| 375 |
+
|
| 376 |
+
inference_params.seqlen_offset = seqlen_offset_og
|
| 377 |
+
return run
|
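A small illustration of the sampling helper defined above (CPU-only, no model required; the logits values are made up).

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # (batch=1, vocab=4), toy values
greedy = sample(logits, top_k=1)                # pure argmax -> tensor([0])
stochastic = sample(logits, top_k=3, top_p=0.9, temperature=0.8)  # top-k, then top-p filtering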
mamba_ssm/utils/hf.py
ADDED
|
@@ -0,0 +1,23 @@
import json

import torch

from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


def load_config_hf(model_name):
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    return json.load(open(resolved_archive_file))


def load_state_dict_hf(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
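For reference, the intended call pattern is sketched below; the repo id is a placeholder, not an actual checkpoint.

# import torch
# repo_id = "some-org/some-mamba-checkpoint"   # placeholder, not a real repo
# config = load_config_hf(repo_id)             # dict parsed from the repo's config.json
# state_dict = load_state_dict_hf(repo_id, device="cuda", dtype=torch.float16)
# model.load_state_dict(state_dict)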
neural_methods/__init__.py
ADDED
|
File without changes
|
neural_methods/loss/NegPearsonLoss.py
ADDED
|
@@ -0,0 +1,23 @@
from __future__ import print_function, division
import torch
import matplotlib.pyplot as plt
import argparse, os
import pandas as pd
import numpy as np
import random
import math
from torchvision import transforms
from torch import nn


class Neg_Pearson(nn.Module):
    def __init__(self):
        super(Neg_Pearson, self).__init__()
        return

    def forward(self, preds, labels):
        cos = nn.CosineSimilarity(dim=0, eps=1e-6)
        pearson = cos(preds - preds.mean(dim=0, keepdim=True), labels - labels.mean(dim=0, keepdim=True))
        return torch.mean(1 - pearson)
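A minimal usage sketch of this loss with toy tensors (the dim=0 reduction implies time-major inputs; shapes here are illustrative only).

import torch

criterion = Neg_Pearson()
preds = torch.randn(160, 4, requires_grad=True)   # (time, batch) toy rPPG predictions
labels = torch.sin(torch.linspace(0, 20, 160)).unsqueeze(1).repeat(1, 4)  # toy PPG target
loss = criterion(preds, labels)                   # mean of (1 - Pearson) over the batch
loss.backward()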
neural_methods/loss/PhysFormerLossComputer.py
ADDED
|
@@ -0,0 +1,120 @@
'''
Adapted from here: https://github.com/ZitongYu/PhysFormer/blob/main/TorchLossComputer.py
Modified based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn
'''
import math
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import pdb
import torch.nn as nn


def normal_sampling(mean, label_k, std):
    return math.exp(-(label_k - mean) ** 2 / (2 * std ** 2)) / (math.sqrt(2 * math.pi) * std)


def kl_loss(inputs, labels):
    # Reshape the labels tensor to match the shape of inputs
    labels = labels.view(1, -1)

    # Compute the KL Div Loss
    criterion = nn.KLDivLoss(reduction='sum')
    loss = criterion(F.log_softmax(inputs, dim=-1), labels)
    return loss


class TorchLossComputer(object):
    @staticmethod
    def compute_complex_absolute_given_k(output, k, N):
        two_pi_n_over_N = torch.autograd.Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N
        hanning = torch.autograd.Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1)

        k = k.type(torch.FloatTensor).cuda()
        two_pi_n_over_N = two_pi_n_over_N.cuda()
        hanning = hanning.cuda()

        output = output.view(1, -1) * hanning
        output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
        k = k.view(1, -1, 1)
        two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
        complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
            + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2

        return complex_absolute

    @staticmethod
    def complex_absolute(output, Fs, bpm_range=None):
        output = output.view(1, -1)

        N = output.size()[1]

        unit_per_hz = Fs / N
        feasible_bpm = bpm_range / 60.0
        k = feasible_bpm / unit_per_hz

        # only calculate feasible PSD range [0.7, 4] Hz
        complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)

        return (1.0 / complex_absolute.sum()) * complex_absolute  # Analogous Softmax operator

    @staticmethod
    def cross_entropy_power_spectrum_loss(inputs, target, Fs):
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        # pdb.set_trace()
        criterion = FocalLoss(gamma=gamma)

        return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
        inputs = inputs.view(1, -1)
        bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        return whole_max_idx

    @staticmethod
    def cross_entropy_power_spectrum_DLDL_softmax2(inputs, target, Fs, std):
        target_distribution = [normal_sampling(int(target), i, std) for i in range(40, 180)]
        target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
        target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))

        inputs = inputs.view(1, -1)
        target = target.view(1, -1)

        bpm_range = torch.arange(40, 180, dtype=torch.float).to(torch.device('cuda'))

        ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        fre_distribution = ca / torch.sum(ca)
        loss_distribution_kl = kl_loss(fre_distribution, target_distribution)

        whole_max_val, whole_max_idx = ca.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)
        return loss_distribution_kl, F.cross_entropy(ca, (target - bpm_range[0]).view(1).type(torch.long)), torch.abs(target[0] - bpm_range[0] - whole_max_idx)
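The DLDL loss above converts the scalar ground-truth heart rate into a Gaussian label distribution over the 40-180 bpm grid via normal_sampling; a small CPU-only sketch of just that step is shown here (the rest of the class requires CUDA).

import torch

gt_hr = 72                                               # illustrative ground-truth HR in bpm
dist = [normal_sampling(gt_hr, k, std=1.0) for k in range(40, 180)]
dist = [p if p > 1e-15 else 1e-15 for p in dist]         # same floor as the loss applies
dist = torch.tensor(dist)
print(int(dist.argmax()) + 40)                           # peaks at the ground-truth bpm (72)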
neural_methods/loss/PhysNetNegPearsonLoss.py
ADDED
|
@@ -0,0 +1,43 @@
from __future__ import print_function, division
import torch
import matplotlib.pyplot as plt
import argparse, os
import pandas as pd
import numpy as np
import random
import math
from torchvision import transforms
from torch import nn


class Neg_Pearson(nn.Module):
    """
    The Neg_Pearson Module is from the original author of PhysNet.
    Code of 'Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks'
    source: https://github.com/ZitongYu/PhysNet/blob/master/NegPearsonLoss.py
    """

    def __init__(self):
        super(Neg_Pearson, self).__init__()
        return

    def forward(self, preds, labels):
        loss = 0
        for i in range(preds.shape[0]):
            sum_x = torch.sum(preds[i])
            sum_y = torch.sum(labels[i])
            sum_xy = torch.sum(preds[i] * labels[i])
            sum_x2 = torch.sum(torch.pow(preds[i], 2))
            sum_y2 = torch.sum(torch.pow(labels[i], 2))
            N = preds.shape[1]
            pearson = (N * sum_xy - sum_x * sum_y) / (torch.sqrt((N * sum_x2 - torch.pow(sum_x, 2)) * (N * sum_y2 - torch.pow(sum_y, 2))))
            loss += 1 - pearson

        loss = loss / preds.shape[0]
        return loss
neural_methods/loss/RythmFormerLossComputer.py
ADDED
|
@@ -0,0 +1,167 @@
| 1 |
+
'''
|
| 2 |
+
Adapted from here: https://github.com/ZitongYu/PhysFormer/TorchLossComputer.py
|
| 3 |
+
Modified based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn
|
| 4 |
+
'''
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
from torch.autograd import Variable
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from evaluation.post_process import calculate_metric_per_video
|
| 12 |
+
|
| 13 |
+
def normal_sampling(mean, label_k, std):
|
| 14 |
+
return math.exp(-(label_k-mean)**2/(2*std**2))/(math.sqrt(2*math.pi)*std)
|
| 15 |
+
|
| 16 |
+
def kl_loss(inputs, labels):
|
| 17 |
+
criterion = nn.KLDivLoss(reduce=False)
|
| 18 |
+
outputs = torch.log(inputs)
|
| 19 |
+
loss = criterion(outputs, labels)
|
| 20 |
+
#loss = loss.sum()/loss.shape[0]
|
| 21 |
+
loss = loss.sum()
|
| 22 |
+
return loss
|
| 23 |
+
|
| 24 |
+
class Neg_Pearson(nn.Module): # Pearson range [-1, 1] so if < 0, abs|loss| ; if >0, 1- loss
|
| 25 |
+
def __init__(self):
|
| 26 |
+
super(Neg_Pearson,self).__init__()
|
| 27 |
+
|
| 28 |
+
def forward(self, preds, labels): # all variable operation
|
| 29 |
+
loss = 0
|
| 30 |
+
for i in range(preds.shape[0]):
|
| 31 |
+
sum_x = torch.sum(preds[i]) # x
|
| 32 |
+
sum_y = torch.sum(labels[i]) # y
|
| 33 |
+
sum_xy = torch.sum(preds[i]*labels[i]) # xy
|
| 34 |
+
sum_x2 = torch.sum(torch.pow(preds[i],2)) # x^2
|
| 35 |
+
sum_y2 = torch.sum(torch.pow(labels[i],2)) # y^2
|
| 36 |
+
N = preds.shape[1]
|
| 37 |
+
pearson = (N*sum_xy - sum_x*sum_y)/(torch.sqrt((N*sum_x2 - torch.pow(sum_x,2))*(N*sum_y2 - torch.pow(sum_y,2))))
|
| 38 |
+
loss += 1 - pearson
|
| 39 |
+
|
| 40 |
+
loss = loss/preds.shape[0]
|
| 41 |
+
return loss
|
| 42 |
+
|
| 43 |
+
class RhythmFormer_Loss(nn.Module):
|
| 44 |
+
def __init__(self):
|
| 45 |
+
super(RhythmFormer_Loss,self).__init__()
|
| 46 |
+
self.criterion_Pearson = Neg_Pearson()
|
| 47 |
+
def forward(self, pred_ppg, labels , epoch , FS , diff_flag):
|
| 48 |
+
loss_time = self.criterion_Pearson(pred_ppg.view(1,-1) , labels.view(1,-1))
|
| 49 |
+
loss_CE , loss_distribution_kl = TorchLossComputer.Frequency_loss(pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0)
|
| 50 |
+
loss_hr = TorchLossComputer.HR_loss(pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0)
|
| 51 |
+
if torch.isnan(loss_time) :
|
| 52 |
+
loss_time = 0
|
| 53 |
+
|
| 54 |
+
loss = 0.2 * loss_time + 1.0 * loss_CE + 1.0 * loss_hr
|
| 55 |
+
return loss
|
| 56 |
+
|
| 57 |
+
class TorchLossComputer(object):
|
| 58 |
+
@staticmethod
|
| 59 |
+
def compute_complex_absolute_given_k(output, k, N):
|
| 60 |
+
two_pi_n_over_N = Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N
|
| 61 |
+
hanning = Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1)
|
| 62 |
+
|
| 63 |
+
k = k.type(torch.FloatTensor).cuda()
|
| 64 |
+
two_pi_n_over_N = two_pi_n_over_N.cuda()
|
| 65 |
+
hanning = hanning.cuda()
|
| 66 |
+
|
| 67 |
+
output = output.view(1, -1) * hanning
|
| 68 |
+
output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
|
| 69 |
+
k = k.view(1, -1, 1)
|
| 70 |
+
two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
|
| 71 |
+
complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
|
| 72 |
+
+ torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2
|
| 73 |
+
|
| 74 |
+
return complex_absolute
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def complex_absolute(output, Fs, bpm_range=None):
|
| 78 |
+
output = output.view(1, -1)
|
| 79 |
+
|
| 80 |
+
N = output.size()[1]
|
| 81 |
+
|
| 82 |
+
unit_per_hz = Fs / N
|
| 83 |
+
feasible_bpm = bpm_range / 60.0
|
| 84 |
+
k = feasible_bpm / unit_per_hz
|
| 85 |
+
|
| 86 |
+
# only calculate feasible PSD range [0.7,4]Hz
|
| 87 |
+
complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)
|
| 88 |
+
|
| 89 |
+
return (1.0 / complex_absolute.sum()) * complex_absolute # Analogous Softmax operator
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def cross_entropy_power_spectrum_loss(inputs, target, Fs):
|
| 94 |
+
inputs = inputs.view(1, -1)
|
| 95 |
+
target = target.view(1, -1)
|
| 96 |
+
bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
|
| 97 |
+
#bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()
|
| 98 |
+
|
| 99 |
+
complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
|
| 100 |
+
|
| 101 |
+
whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
|
| 102 |
+
whole_max_idx = whole_max_idx.type(torch.float)
|
| 103 |
+
|
| 104 |
+
#pdb.set_trace()
|
| 105 |
+
|
| 106 |
+
#return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2
|
| 107 |
+
return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
|
| 111 |
+
inputs = inputs.view(1, -1)
|
| 112 |
+
target = target.view(1, -1)
|
| 113 |
+
bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
|
| 114 |
+
#bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()
|
| 115 |
+
|
| 116 |
+
complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
|
| 117 |
+
|
| 118 |
+
whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
|
| 119 |
+
whole_max_idx = whole_max_idx.type(torch.float)
|
| 120 |
+
|
| 121 |
+
#pdb.set_trace()
|
| 122 |
+
criterion = FocalLoss(gamma=gamma)
|
| 123 |
+
|
| 124 |
+
#return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2
|
| 125 |
+
return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
|
| 130 |
+
inputs = inputs.view(1, -1)
|
| 131 |
+
bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()
|
| 132 |
+
#bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
|
| 133 |
+
#bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()
|
| 134 |
+
|
| 135 |
+
complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
|
| 136 |
+
|
| 137 |
+
whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
|
| 138 |
+
whole_max_idx = whole_max_idx.type(torch.float)
|
| 139 |
+
|
| 140 |
+
return whole_max_idx
|
| 141 |
+
|
| 142 |
+
@staticmethod
|
| 143 |
+
def Frequency_loss(inputs, target, diff_flag , Fs, std):
|
| 144 |
+
hr_gt, pred_hr_peak, SNR, macc = calculate_metric_per_video(inputs.detach().cpu(), target.detach().cpu(), diff_flag = diff_flag, fs=Fs, hr_method='FFT')
|
| 145 |
+
inputs = inputs.view(1, -1)
|
| 146 |
+
target = target.view(1, -1)
|
| 147 |
+
bpm_range = torch.arange(45, 150, dtype=torch.float).to(torch.device('cuda'))
|
| 148 |
+
ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
|
| 149 |
+
sa = ca/torch.sum(ca)
|
| 150 |
+
|
| 151 |
+
target_distribution = [normal_sampling(int(hr_gt), i, std) for i in range(45, 150)]
|
| 152 |
+
target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
|
| 153 |
+
target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
|
| 154 |
+
|
| 155 |
+
hr_gt = torch.tensor(hr_gt-45).view(1).type(torch.long).to(torch.device('cuda'))
|
| 156 |
+
return F.cross_entropy(ca, hr_gt) , kl_loss(sa , target_distribution)
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
def HR_loss(inputs, target, diff_flag , Fs, std):
|
| 160 |
+
psd_gt, psd_pred, SNR, macc = calculate_metric_per_video(inputs.detach().cpu(), target.detach().cpu(), diff_flag = diff_flag, fs=Fs, hr_method='Peak')
|
| 161 |
+
pred_distribution = [normal_sampling(np.argmax(psd_pred), i, std) for i in range(psd_pred.size)]
|
| 162 |
+
pred_distribution = [i if i > 1e-15 else 1e-15 for i in pred_distribution]
|
| 163 |
+
pred_distribution = torch.Tensor(pred_distribution).to(torch.device('cuda'))
|
| 164 |
+
target_distribution = [normal_sampling(np.argmax(psd_gt), i, std) for i in range(psd_gt.size)]
|
| 165 |
+
target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
|
| 166 |
+
target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
|
| 167 |
+
return kl_loss(pred_distribution , target_distribution)
|
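Note (not part of the committed file): complex_absolute converts each candidate heart rate in bpm to a fractional spectral index k = (bpm/60) / (Fs/N) before evaluating the spectrum. A minimal sketch of that mapping, assuming a 30 fps video and 160-frame chunks; both values are illustrative, not taken from the configs in this commit.

# Illustrative check of the bpm -> spectral-index mapping used by complex_absolute.
Fs = 30.0                 # assumed video frame rate (fps)
N = 160                   # assumed chunk length in frames
unit_per_hz = Fs / N      # frequency resolution of the N-point transform, in Hz per bin
for bpm in (45, 75, 150):
    k = (bpm / 60.0) / unit_per_hz     # fractional bin index scored by the loss
    print(bpm, "bpm ->", round(k, 2), "bins")   # e.g. 75 bpm -> 6.67 bins at 30 fps / 160 frames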
neural_methods/loss/__init__.py
ADDED
File without changes
neural_methods/model/BigSmall.py
ADDED
@@ -0,0 +1,177 @@
"""BigSmall: Multitask Network for AU / Respiration / PPG

BigSmall: Efficient Multi-Task Learning
For Physiological Measurements
Girish Narayanswamy, Yujia (Nancy) Liu, Yuzhe Yang, Chengqian (Jack) Ma,
Xin Liu, Daniel McDuff, Shwetak Patel

https://arxiv.org/abs/2303.11573
"""

import torch
import torch.nn as nn


#####################################################
############ Wrapping Time Shift Module #############
#####################################################
class WTSM(nn.Module):
    def __init__(self, n_segment=3, fold_div=3):
        super(WTSM, self).__init__()
        self.n_segment = n_segment
        self.fold_div = fold_div

    def forward(self, x):
        nt, c, h, w = x.size()
        n_batch = nt // self.n_segment
        x = x.view(n_batch, self.n_segment, c, h, w)
        fold = c // self.fold_div
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, -1, :fold] = x[:, 0, :fold]  # wrap left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, 0, fold: 2 * fold] = x[:, -1, fold: 2 * fold]  # wrap right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # no shift for final fold
        return out.view(nt, c, h, w)


#######################################################################################
##################################### BigSmall Model ##################################
#######################################################################################
class BigSmall(nn.Module):

    def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3,
                 dropout_rate1=0.25, dropout_rate2=0.5, dropout_rate3=0.5, pool_size1=(2, 2), pool_size2=(4,4),
                 nb_dense=128, out_size_bvp=1, out_size_resp=1, out_size_au=12, n_segment=3):

        super(BigSmall, self).__init__()

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout_rate1 = dropout_rate1
        self.dropout_rate2 = dropout_rate2
        self.dropout_rate3 = dropout_rate3
        self.pool_size1 = pool_size1
        self.pool_size2 = pool_size2
        self.nb_filters1 = nb_filters1
        self.nb_filters2 = nb_filters2
        self.nb_dense = nb_dense

        self.out_size_bvp = out_size_bvp
        self.out_size_resp = out_size_resp
        self.out_size_au = out_size_au

        self.n_segment = n_segment

        # Big Convolutional Layers
        self.big_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
        self.big_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
        self.big_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
        self.big_conv4 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
        self.big_conv5 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True)
        self.big_conv6 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True)

        # Big Avg Pooling / Dropout Layers
        self.big_avg_pooling1 = nn.AvgPool2d(self.pool_size1)
        self.big_dropout1 = nn.Dropout(self.dropout_rate1)
        self.big_avg_pooling2 = nn.AvgPool2d(self.pool_size1)
        self.big_dropout2 = nn.Dropout(self.dropout_rate2)
        self.big_avg_pooling3 = nn.AvgPool2d(self.pool_size2)
        self.big_dropout3 = nn.Dropout(self.dropout_rate3)

        # TSM layers
        self.TSM_1 = WTSM(n_segment=self.n_segment)
        self.TSM_2 = WTSM(n_segment=self.n_segment)
        self.TSM_3 = WTSM(n_segment=self.n_segment)
        self.TSM_4 = WTSM(n_segment=self.n_segment)

        # Small Convolutional Layers
        self.small_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True)
        self.small_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True)
        self.small_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True)
        self.small_conv4 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1,1), bias=True)

        # AU Fully Connected Layers
        self.au_fc1 = nn.Linear(5184, self.nb_dense, bias=True)
        self.au_fc2 = nn.Linear(self.nb_dense, self.out_size_au, bias=True)

        # BVP Fully Connected Layers
        self.bvp_fc1 = nn.Linear(5184, self.nb_dense, bias=True)
        self.bvp_fc2 = nn.Linear(self.nb_dense, self.out_size_bvp, bias=True)

        # Resp Fully Connected Layers
        self.resp_fc1 = nn.Linear(5184, self.nb_dense, bias=True)
        self.resp_fc2 = nn.Linear(self.nb_dense, self.out_size_resp, bias=True)


    def forward(self, inputs, params=None):

        big_input = inputs[0]  # big res
        small_input = inputs[1]  # small res

        # reshape Big
        nt, c, h, w = big_input.size()
        n_batch = nt // self.n_segment
        big_input = big_input.view(n_batch, self.n_segment, c, h, w)
        big_input = torch.moveaxis(big_input, 1, 2)  # color channel to idx 1, sequence channel to idx 2
        big_input = big_input[:, :, 0, :, :]  # use only first frame in sequences

        # Big Conv block 1
        b1 = nn.functional.relu(self.big_conv1(big_input))
        b2 = nn.functional.relu(self.big_conv2(b1))
        b3 = self.big_avg_pooling1(b2)
        b4 = self.big_dropout1(b3)

        # Big Conv block 2
        b5 = nn.functional.relu(self.big_conv3(b4))
        b6 = nn.functional.relu(self.big_conv4(b5))
        b7 = self.big_avg_pooling2(b6)
        b8 = self.big_dropout2(b7)

        # Big Conv block 3
        b9 = nn.functional.relu(self.big_conv5(b8))
        b10 = nn.functional.relu(self.big_conv6(b9))
        b11 = self.big_avg_pooling3(b10)
        b12 = self.big_dropout3(b11)

        # Reformat Big Shape For Concat w/ Small Branch
        b13 = torch.stack((b12, b12, b12), 2)  # TODO: this is hardcoded for num_segs = 3: change this...
        b14 = torch.moveaxis(b13, 1, 2)
        bN, bD, bC, bH, bW = b14.size()
        b15 = b14.reshape(int(bN*bD), bC, bH, bW)

        # Small Conv block 1
        s1 = self.TSM_1(small_input)
        s2 = nn.functional.relu(self.small_conv1(s1))
        s3 = self.TSM_2(s2)
        s4 = nn.functional.relu(self.small_conv2(s3))

        # Small Conv block 2
        s5 = self.TSM_3(s4)
        s6 = nn.functional.relu(self.small_conv3(s5))
        s7 = self.TSM_4(s6)
        s8 = nn.functional.relu(self.small_conv4(s7))

        # Shared Layers
        concat = b15 + s8  # sum layers

        # share1 = concat.view(concat.size(0), -1) # flatten entire tensors
        share1 = concat.reshape(concat.size(0), -1)

        # AU Output Layers
        aufc1 = nn.functional.relu(self.au_fc1(share1))
        au_out = self.au_fc2(aufc1)

        # BVP Output Layers
        bvpfc1 = nn.functional.relu(self.bvp_fc1(share1))
        bvp_out = self.bvp_fc2(bvpfc1)

        # Resp Output Layers
        respfc1 = nn.functional.relu(self.resp_fc1(share1))
        resp_out = self.resp_fc2(respfc1)

        return au_out, bvp_out, resp_out

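Note (not part of the committed file): the shared 5184-unit linear layers constrain the working resolutions, since 5184 = 64 x 9 x 9; a 9x9 small branch and a 144x144 big branch (144 / 2 / 2 / 4 = 9) are the sizes implied by that flatten, and n_segment = 3 frames per sequence. A minimal shape-check sketch under those assumptions:

import torch
from neural_methods.model.BigSmall import BigSmall

model = BigSmall(n_segment=3)
n_seq = 4                                    # number of 3-frame sequences in the batch
big = torch.randn(n_seq * 3, 3, 144, 144)    # high-res appearance frames
small = torch.randn(n_seq * 3, 3, 9, 9)      # low-res motion frames
au, bvp, resp = model((big, small))
print(au.shape, bvp.shape, resp.shape)       # (12, 12), (12, 1), (12, 1)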
neural_methods/model/DeepPhys.py
ADDED
@@ -0,0 +1,125 @@
"""DeepPhys - 2D Convolutional Attention Network.
DeepPhys: Video-Based Physiological Measurement Using Convolutional Attention Networks
ECCV, 2018
Weixuan Chen, Daniel McDuff
"""

import torch
import torch.nn as nn


class Attention_mask(nn.Module):
    def __init__(self):
        super(Attention_mask, self).__init__()

    def forward(self, x):
        xsum = torch.sum(x, dim=2, keepdim=True)
        xsum = torch.sum(xsum, dim=3, keepdim=True)
        xshape = tuple(x.size())
        return x / xsum * xshape[2] * xshape[3] * 0.5

    def get_config(self):
        """May be generated manually. """
        config = super(Attention_mask, self).get_config()
        return config


class DeepPhys(nn.Module):

    def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
                 dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, img_size=36):
        """Definition of DeepPhys.
        Args:
          in_channels: the number of input channel. Default: 3
          img_size: height/width of each frame. Default: 36.
        Returns:
          DeepPhys model.
        """
        super(DeepPhys, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout_rate1 = dropout_rate1
        self.dropout_rate2 = dropout_rate2
        self.pool_size = pool_size
        self.nb_filters1 = nb_filters1
        self.nb_filters2 = nb_filters2
        self.nb_dense = nb_dense
        # Motion branch convs
        self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
                                      bias=True)
        self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
                                      bias=True)
        self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Apperance branch convs
        self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
                                         padding=(1, 1), bias=True)
        self.apperance_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
                                         padding=(1, 1), bias=True)
        self.apperance_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Attention layers
        self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_1 = Attention_mask()
        self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_2 = Attention_mask()
        # Avg pooling
        self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
        # Dropout layers
        self.dropout_1 = nn.Dropout(self.dropout_rate1)
        self.dropout_2 = nn.Dropout(self.dropout_rate1)
        self.dropout_3 = nn.Dropout(self.dropout_rate1)
        self.dropout_4 = nn.Dropout(self.dropout_rate2)
        # Dense layers
        if img_size == 36:
            self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
        elif img_size == 72:
            self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
        elif img_size == 96:
            self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
        else:
            raise Exception('Unsupported image size')
        self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)

    def forward(self, inputs, params=None):

        diff_input = inputs[:, :3, :, :]
        raw_input = inputs[:, 3:, :, :]

        d1 = torch.tanh(self.motion_conv1(diff_input))
        d2 = torch.tanh(self.motion_conv2(d1))

        r1 = torch.tanh(self.apperance_conv1(raw_input))
        r2 = torch.tanh(self.apperance_conv2(r1))

        g1 = torch.sigmoid(self.apperance_att_conv1(r2))
        g1 = self.attn_mask_1(g1)
        gated1 = d2 * g1

        d3 = self.avg_pooling_1(gated1)
        d4 = self.dropout_1(d3)

        r3 = self.avg_pooling_2(r2)
        r4 = self.dropout_2(r3)

        d5 = torch.tanh(self.motion_conv3(d4))
        d6 = torch.tanh(self.motion_conv4(d5))

        r5 = torch.tanh(self.apperance_conv3(r4))
        r6 = torch.tanh(self.apperance_conv4(r5))

        g2 = torch.sigmoid(self.apperance_att_conv2(r6))
        g2 = self.attn_mask_2(g2)
        gated2 = d6 * g2

        d7 = self.avg_pooling_3(gated2)
        d8 = self.dropout_3(d7)
        d9 = d8.view(d8.size(0), -1)
        d10 = torch.tanh(self.final_dense_1(d9))
        d11 = self.dropout_4(d10)
        out = self.final_dense_2(d11)

        return out
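Note (not part of the committed file): DeepPhys expects each frame as 6 stacked channels, channels 0-2 being the normalized frame difference (motion branch) and channels 3-5 the standardized raw frame (appearance branch). A minimal shape-check sketch, assuming a 72x72 crop:

import torch
from neural_methods.model.DeepPhys import DeepPhys

model = DeepPhys(img_size=72)
frames = torch.randn(8, 6, 72, 72)   # 8 frames, 6 channels (diff + raw), 72x72 crop
bvp = model(frames)
print(bvp.shape)                     # torch.Size([8, 1]) -> one rPPG value per frame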
neural_methods/model/EfficientPhys.py
ADDED
@@ -0,0 +1,128 @@
"""EfficientPhys: Enabling Simple, Fast and Accurate Camera-Based Vitals Measurement
Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2023)
Xin Liu, Brial Hill, Ziheng Jiang, Shwetak Patel, Daniel McDuff
"""

import torch
import torch.nn as nn


class Attention_mask(nn.Module):
    def __init__(self):
        super(Attention_mask, self).__init__()

    def forward(self, x):
        xsum = torch.sum(x, dim=2, keepdim=True)
        xsum = torch.sum(xsum, dim=3, keepdim=True)
        xshape = tuple(x.size())
        return x / xsum * xshape[2] * xshape[3] * 0.5

    def get_config(self):
        """May be generated manually. """
        config = super(Attention_mask, self).get_config()
        return config


class TSM(nn.Module):
    def __init__(self, n_segment=10, fold_div=3):
        super(TSM, self).__init__()
        self.n_segment = n_segment
        self.fold_div = fold_div

    def forward(self, x):
        nt, c, h, w = x.size()
        n_batch = nt // self.n_segment
        x = x.view(n_batch, self.n_segment, c, h, w)
        fold = c // self.fold_div
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
        return out.view(nt, c, h, w)


class EfficientPhys(nn.Module):

    def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
                 dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36, channel='raw'):
        super(EfficientPhys, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout_rate1 = dropout_rate1
        self.dropout_rate2 = dropout_rate2
        self.pool_size = pool_size
        self.nb_filters1 = nb_filters1
        self.nb_filters2 = nb_filters2
        self.nb_dense = nb_dense
        # TSM layers
        self.TSM_1 = TSM(n_segment=frame_depth)
        self.TSM_2 = TSM(n_segment=frame_depth)
        self.TSM_3 = TSM(n_segment=frame_depth)
        self.TSM_4 = TSM(n_segment=frame_depth)
        # Motion branch convs
        self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
                                      bias=True)
        self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
                                      bias=True)
        self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Attention layers
        self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_1 = Attention_mask()
        self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_2 = Attention_mask()
        # Avg pooling
        self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
        # Dropout layers
        self.dropout_1 = nn.Dropout(self.dropout_rate1)
        self.dropout_2 = nn.Dropout(self.dropout_rate1)
        self.dropout_3 = nn.Dropout(self.dropout_rate1)
        self.dropout_4 = nn.Dropout(self.dropout_rate2)
        # Dense layers
        if img_size == 36:
            self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
        elif img_size == 72:
            self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
        elif img_size == 96:
            self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
        else:
            raise Exception('Unsupported image size')
        self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
        self.batch_norm = nn.BatchNorm2d(3)
        self.channel = channel

    def forward(self, inputs, params=None):
        inputs = torch.diff(inputs, dim=0)
        inputs = self.batch_norm(inputs)

        network_input = self.TSM_1(inputs)
        d1 = torch.tanh(self.motion_conv1(network_input))
        d1 = self.TSM_2(d1)
        d2 = torch.tanh(self.motion_conv2(d1))

        g1 = torch.sigmoid(self.apperance_att_conv1(d2))
        g1 = self.attn_mask_1(g1)
        gated1 = d2 * g1

        d3 = self.avg_pooling_1(gated1)
        d4 = self.dropout_1(d3)

        d4 = self.TSM_3(d4)
        d5 = torch.tanh(self.motion_conv3(d4))
        d5 = self.TSM_4(d5)
        d6 = torch.tanh(self.motion_conv4(d5))

        g2 = torch.sigmoid(self.apperance_att_conv2(d6))
        g2 = self.attn_mask_2(g2)
        gated2 = d6 * g2

        d7 = self.avg_pooling_3(gated2)
        d8 = self.dropout_3(d7)
        d9 = d8.view(d8.size(0), -1)
        d10 = torch.tanh(self.final_dense_1(d9))
        d11 = self.dropout_4(d10)
        out = self.final_dense_2(d11)

        return out
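Note (not part of the committed file): because forward() begins with torch.diff along the frame axis, the caller has to append one extra frame so that (batch * frame_depth) + 1 frames in yields exactly (batch * frame_depth) frames for the TSM reshape. A minimal sketch with illustrative sizes:

import torch
from neural_methods.model.EfficientPhys import EfficientPhys

frame_depth, n_chunks = 20, 2
model = EfficientPhys(frame_depth=frame_depth, img_size=72)
frames = torch.randn(n_chunks * frame_depth + 1, 3, 72, 72)  # note the +1 frame
out = model(frames)
print(out.shape)   # torch.Size([40, 1])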
neural_methods/model/FactorizePhys/FSAM.py
ADDED
@@ -0,0 +1,530 @@
| 1 |
+
"""
|
| 2 |
+
FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
|
| 3 |
+
NeurIPS 2024
|
| 4 |
+
Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 11 |
+
import numpy as np
|
| 12 |
+
import neurokit2 as nk
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class _MatrixDecompositionBase(nn.Module):
|
| 16 |
+
def __init__(self, device, md_config, debug=False, dim="3D"):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
self.dim = dim
|
| 20 |
+
self.md_type = md_config["MD_TYPE"]
|
| 21 |
+
if dim == "3D":
|
| 22 |
+
self.transform = md_config["MD_TRANSFORM"]
|
| 23 |
+
self.S = md_config["MD_S"]
|
| 24 |
+
self.R = md_config["MD_R"]
|
| 25 |
+
self.debug = debug
|
| 26 |
+
|
| 27 |
+
self.train_steps = md_config["MD_STEPS"]
|
| 28 |
+
self.eval_steps = md_config["MD_STEPS"]
|
| 29 |
+
|
| 30 |
+
self.inv_t = md_config["INV_T"]
|
| 31 |
+
self.eta = md_config["ETA"]
|
| 32 |
+
|
| 33 |
+
self.rand_init = md_config["RAND_INIT"]
|
| 34 |
+
self.device = device
|
| 35 |
+
|
| 36 |
+
# print('Dimension:', self.dim)
|
| 37 |
+
# print('S', self.S)
|
| 38 |
+
# print('D', self.D)
|
| 39 |
+
# print('R', self.R)
|
| 40 |
+
# print('train_steps', self.train_steps)
|
| 41 |
+
# print('eval_steps', self.eval_steps)
|
| 42 |
+
# print('inv_t', self.inv_t)
|
| 43 |
+
# print('eta', self.eta)
|
| 44 |
+
# print('rand_init', self.rand_init)
|
| 45 |
+
|
| 46 |
+
def _build_bases(self, B, S, D, R):
|
| 47 |
+
raise NotImplementedError
|
| 48 |
+
|
| 49 |
+
def local_step(self, x, bases, coef):
|
| 50 |
+
raise NotImplementedError
|
| 51 |
+
|
| 52 |
+
@torch.no_grad()
|
| 53 |
+
def local_inference(self, x, bases):
|
| 54 |
+
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
| 55 |
+
coef = torch.bmm(x.transpose(1, 2), bases)
|
| 56 |
+
coef = F.softmax(self.inv_t * coef, dim=-1)
|
| 57 |
+
|
| 58 |
+
steps = self.train_steps if self.training else self.eval_steps
|
| 59 |
+
for _ in range(steps):
|
| 60 |
+
bases, coef = self.local_step(x, bases, coef)
|
| 61 |
+
|
| 62 |
+
return bases, coef
|
| 63 |
+
|
| 64 |
+
def compute_coef(self, x, bases, coef):
|
| 65 |
+
raise NotImplementedError
|
| 66 |
+
|
| 67 |
+
def forward(self, x, return_bases=False):
|
| 68 |
+
|
| 69 |
+
if self.debug:
|
| 70 |
+
print("Org x.shape", x.shape)
|
| 71 |
+
|
| 72 |
+
if self.dim == "3D": # (B, C, T, H, W) -> (B * S, D, N)
|
| 73 |
+
B, C, T, H, W = x.shape
|
| 74 |
+
|
| 75 |
+
# t = Time, k = Channels, a & B = height and width
|
| 76 |
+
if self.transform.lower() == "t_kab":
|
| 77 |
+
# # dimension of vector of our interest is T (rPPG signal as T dimension), so forming this as vector
|
| 78 |
+
# # From spatial and channel dimension, which are features, only 2-4 shall be enough to generate the approximated attention matrix
|
| 79 |
+
D = T // self.S
|
| 80 |
+
N = C * H * W
|
| 81 |
+
|
| 82 |
+
elif self.transform.lower() == "tk_ab":
|
| 83 |
+
D = T * C // self.S
|
| 84 |
+
N = H * W
|
| 85 |
+
|
| 86 |
+
elif self.transform.lower() == "k_tab":
|
| 87 |
+
D = C // self.S
|
| 88 |
+
N = T * H * W
|
| 89 |
+
|
| 90 |
+
else:
|
| 91 |
+
print("Invalid MD_TRANSFORM specified:", self.transform)
|
| 92 |
+
exit()
|
| 93 |
+
|
| 94 |
+
# # smoothening the temporal dimension
|
| 95 |
+
# x = x.view(B * self.S, N, D)
|
| 96 |
+
# # print("Intermediate-1 x", x.shape)
|
| 97 |
+
|
| 98 |
+
# sample_1 = x[:, :, 0].unsqueeze(2)
|
| 99 |
+
# sample_2 = x[:, :, -1].unsqueeze(2)
|
| 100 |
+
# x = torch.cat([sample_1, x, sample_2], dim=2)
|
| 101 |
+
# gaussian_kernel = [1.0, 1.0, 1.0]
|
| 102 |
+
# kernels = torch.FloatTensor([[gaussian_kernel]]).repeat(N, N, 1).to(self.device)
|
| 103 |
+
# bias = torch.FloatTensor(torch.zeros(N)).to(self.device)
|
| 104 |
+
# x = F.conv1d(x, kernels, bias=bias, padding="valid")
|
| 105 |
+
# x = (x - x.min()) / (x.max() - x.min())
|
| 106 |
+
|
| 107 |
+
# x = x.permute(0, 2, 1)
|
| 108 |
+
# # print("Intermediate-2 x", x.shape)
|
| 109 |
+
|
| 110 |
+
x = x.view(B * self.S, D, N)
|
| 111 |
+
|
| 112 |
+
elif self.dim == "2D": # (B, C, H, W) -> (B * S, D, N)
|
| 113 |
+
B, C, H, W = x.shape
|
| 114 |
+
D = C // self.S
|
| 115 |
+
N = H * W
|
| 116 |
+
x = x.view(B * self.S, D, N)
|
| 117 |
+
|
| 118 |
+
elif self.dim == "2D_TSM": # (B*frame_depth, C, H, W) -> (B, D, N)
|
| 119 |
+
B, C, H, W = x.shape
|
| 120 |
+
BN = B
|
| 121 |
+
B = B // self.S
|
| 122 |
+
D = self.S
|
| 123 |
+
N = C * H * W
|
| 124 |
+
x = x.view(B, D, N)
|
| 125 |
+
self.S = 1 # re-setting this for local inference
|
| 126 |
+
|
| 127 |
+
elif self.dim == "1D": # (B, C, L) -> (B * S, D, N)
|
| 128 |
+
B, C, L = x.shape
|
| 129 |
+
D = L // self.S
|
| 130 |
+
N = C
|
| 131 |
+
x = x.view(B * self.S, D, N)
|
| 132 |
+
|
| 133 |
+
else:
|
| 134 |
+
print("Dimension not supported")
|
| 135 |
+
exit()
|
| 136 |
+
|
| 137 |
+
if self.debug:
|
| 138 |
+
print("MD_Type", self.md_type)
|
| 139 |
+
print("MD_S", self.S)
|
| 140 |
+
print("MD_D", D)
|
| 141 |
+
print("MD_N", N)
|
| 142 |
+
print("MD_R", self.R)
|
| 143 |
+
print("MD_TRAIN_STEPS", self.train_steps)
|
| 144 |
+
print("MD_EVAL_STEPS", self.eval_steps)
|
| 145 |
+
print("x.view(B * self.S, D, N)", x.shape)
|
| 146 |
+
|
| 147 |
+
if not self.rand_init and not hasattr(self, 'bases'):
|
| 148 |
+
bases = self._build_bases(1, self.S, D, self.R)
|
| 149 |
+
self.register_buffer('bases', bases)
|
| 150 |
+
|
| 151 |
+
# (S, D, R) -> (B * S, D, R)
|
| 152 |
+
if self.rand_init:
|
| 153 |
+
bases = self._build_bases(B, self.S, D, self.R)
|
| 154 |
+
else:
|
| 155 |
+
bases = self.bases.repeat(B, 1, 1).to(self.device)
|
| 156 |
+
|
| 157 |
+
bases, coef = self.local_inference(x, bases)
|
| 158 |
+
|
| 159 |
+
# (B * S, N, R)
|
| 160 |
+
coef = self.compute_coef(x, bases, coef)
|
| 161 |
+
|
| 162 |
+
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
|
| 163 |
+
x = torch.bmm(bases, coef.transpose(1, 2))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if self.dim == "3D":
|
| 167 |
+
|
| 168 |
+
apply_smoothening = False
|
| 169 |
+
if apply_smoothening:
|
| 170 |
+
# smoothening the temporal dimension
|
| 171 |
+
x = x.view(B, D * self.S, N) #Joining temporal dimension for contiguous smoothening
|
| 172 |
+
# print("Intermediate-0 x", x.shape)
|
| 173 |
+
x = x.permute(0, 2, 1)
|
| 174 |
+
# print("Intermediate-1 x", x.shape)
|
| 175 |
+
|
| 176 |
+
sample_1 = x[:, :, 0].unsqueeze(2)
|
| 177 |
+
# sample_2 = x[:, :, 0].unsqueeze(2)
|
| 178 |
+
sample_3 = x[:, :, -1].unsqueeze(2)
|
| 179 |
+
# sample_4 = x[:, :, -1].unsqueeze(2)
|
| 180 |
+
x = torch.cat([sample_1, x, sample_3], dim=2)
|
| 181 |
+
# x = torch.cat([sample_1, sample_2, x, sample_3, sample_4], dim=2)
|
| 182 |
+
# gaussian_kernel = [0.25, 0.50, 0.75, 0.50, 0.25]
|
| 183 |
+
# gaussian_kernel = [0.33, 0.66, 1.00, 0.66, 0.33]
|
| 184 |
+
# gaussian_kernel = [0.3, 0.7, 1.0, 0.7, 0.3]
|
| 185 |
+
# gaussian_kernel = [0.3, 1.0, 1.0, 1.0, 0.3]
|
| 186 |
+
# gaussian_kernel = [0.20, 0.80, 1.00, 0.80, 0.20]
|
| 187 |
+
# gaussian_kernel = [1.0, 1.0, 1.0]
|
| 188 |
+
gaussian_kernel = [0.8, 1.0, 0.8]
|
| 189 |
+
kernels = torch.FloatTensor([[gaussian_kernel]]).repeat(N, N, 1).to(self.device)
|
| 190 |
+
bias = torch.FloatTensor(torch.zeros(N)).to(self.device)
|
| 191 |
+
x = F.conv1d(x, kernels, bias=bias, padding="valid")
|
| 192 |
+
# x = (x - x.min()) / (x.max() - x.min())
|
| 193 |
+
# x = (x - x.mean()) / (x.std())
|
| 194 |
+
# x = x - x.min()
|
| 195 |
+
x = (x - x.min())/(x.std())
|
| 196 |
+
|
| 197 |
+
# print("Intermediate-2 x", x.shape)
|
| 198 |
+
|
| 199 |
+
# (B * S, D, N) -> (B, C, T, H, W)
|
| 200 |
+
x = x.view(B, C, T, H, W)
|
| 201 |
+
elif self.dim == "2D":
|
| 202 |
+
# (B * S, D, N) -> (B, C, H, W)
|
| 203 |
+
x = x.view(B, C, H, W)
|
| 204 |
+
|
| 205 |
+
elif self.dim == "2D_TSM":
|
| 206 |
+
# (B, D, N) -> (B, C, H, W)
|
| 207 |
+
x = x.view(BN, C, H, W)
|
| 208 |
+
|
| 209 |
+
else:
|
| 210 |
+
# (B * S, D, N) -> (B, C, L)
|
| 211 |
+
x = x.view(B, C, L)
|
| 212 |
+
|
| 213 |
+
# (B * L, D, R) -> (B, L, N, D)
|
| 214 |
+
bases = bases.view(B, self.S, D, self.R)
|
| 215 |
+
|
| 216 |
+
if not self.rand_init and not self.training and not return_bases:
|
| 217 |
+
self.online_update(bases)
|
| 218 |
+
|
| 219 |
+
# if not self.rand_init or return_bases:
|
| 220 |
+
# return x, bases
|
| 221 |
+
# else:
|
| 222 |
+
return x
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def online_update(self, bases):
|
| 226 |
+
# (B, S, D, R) -> (S, D, R)
|
| 227 |
+
update = bases.mean(dim=0)
|
| 228 |
+
self.bases += self.eta * (update - self.bases)
|
| 229 |
+
self.bases = F.normalize(self.bases, dim=1)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class NMF(_MatrixDecompositionBase):
|
| 233 |
+
def __init__(self, device, md_config, debug=False, dim="3D"):
|
| 234 |
+
super().__init__(device, md_config, debug=debug, dim=dim)
|
| 235 |
+
self.device = device
|
| 236 |
+
self.inv_t = 1
|
| 237 |
+
|
| 238 |
+
def _build_bases(self, B, S, D, R):
|
| 239 |
+
# bases = torch.rand((B * S, D, R)).to(self.device)
|
| 240 |
+
bases = torch.ones((B * S, D, R)).to(self.device)
|
| 241 |
+
bases = F.normalize(bases, dim=1)
|
| 242 |
+
|
| 243 |
+
return bases
|
| 244 |
+
|
| 245 |
+
@torch.no_grad()
|
| 246 |
+
def local_step(self, x, bases, coef):
|
| 247 |
+
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
| 248 |
+
numerator = torch.bmm(x.transpose(1, 2), bases)
|
| 249 |
+
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
|
| 250 |
+
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
|
| 251 |
+
# Multiplicative Update
|
| 252 |
+
coef = coef * numerator / (denominator + 1e-6)
|
| 253 |
+
|
| 254 |
+
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
|
| 255 |
+
numerator = torch.bmm(x, coef)
|
| 256 |
+
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
|
| 257 |
+
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
|
| 258 |
+
# Multiplicative Update
|
| 259 |
+
bases = bases * numerator / (denominator + 1e-6)
|
| 260 |
+
|
| 261 |
+
return bases, coef
|
| 262 |
+
|
| 263 |
+
def compute_coef(self, x, bases, coef):
|
| 264 |
+
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
| 265 |
+
numerator = torch.bmm(x.transpose(1, 2), bases)
|
| 266 |
+
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
|
| 267 |
+
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
|
| 268 |
+
# multiplication update
|
| 269 |
+
coef = coef * numerator / (denominator + 1e-6)
|
| 270 |
+
|
| 271 |
+
return coef
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class VQ(_MatrixDecompositionBase):
|
| 275 |
+
def __init__(self, device, md_config, debug=False, dim="3D"):
|
| 276 |
+
super().__init__(device, md_config, debug=debug, dim=dim)
|
| 277 |
+
self.device = device
|
| 278 |
+
|
| 279 |
+
def _build_bases(self, B, S, D, R):
|
| 280 |
+
# bases = torch.randn((B * S, D, R)).to(self.device)
|
| 281 |
+
bases = torch.ones((B * S, D, R)).to(self.device)
|
| 282 |
+
bases = F.normalize(bases, dim=1)
|
| 283 |
+
return bases
|
| 284 |
+
|
| 285 |
+
@torch.no_grad()
|
| 286 |
+
def local_step(self, x, bases, _):
|
| 287 |
+
# (B * S, D, N), normalize x along D (for cosine similarity)
|
| 288 |
+
std_x = F.normalize(x, dim=1)
|
| 289 |
+
|
| 290 |
+
# (B * S, D, R), normalize bases along D (for cosine similarity)
|
| 291 |
+
std_bases = F.normalize(bases, dim=1, eps=1e-6)
|
| 292 |
+
|
| 293 |
+
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
| 294 |
+
coef = torch.bmm(std_x.transpose(1, 2), std_bases)
|
| 295 |
+
|
| 296 |
+
# softmax along R
|
| 297 |
+
coef = F.softmax(self.inv_t * coef, dim=-1)
|
| 298 |
+
|
| 299 |
+
# normalize along N
|
| 300 |
+
coef = coef / (1e-6 + coef.sum(dim=1, keepdim=True))
|
| 301 |
+
|
| 302 |
+
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
|
| 303 |
+
bases = torch.bmm(x, coef)
|
| 304 |
+
|
| 305 |
+
return bases, coef
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def compute_coef(self, x, bases, _):
|
| 309 |
+
with torch.no_grad():
|
| 310 |
+
# (B * S, D, N) -> (B * S, 1, N)
|
| 311 |
+
x_norm = x.norm(dim=1, keepdim=True)
|
| 312 |
+
|
| 313 |
+
# (B * S, D, N) / (B * S, 1, N) -> (B * S, D, N)
|
| 314 |
+
std_x = x / (1e-6 + x_norm)
|
| 315 |
+
|
| 316 |
+
# (B * S, D, R), normalize bases along D (for cosine similarity)
|
| 317 |
+
std_bases = F.normalize(bases, dim=1, eps=1e-6)
|
| 318 |
+
|
| 319 |
+
# (B * S, N, D)^T @ (B * S, D, R) -> (B * S, N, R)
|
| 320 |
+
coef = torch.bmm(std_x.transpose(1, 2), std_bases)
|
| 321 |
+
|
| 322 |
+
# softmax along R
|
| 323 |
+
coef = F.softmax(self.inv_t * coef, dim=-1)
|
| 324 |
+
|
| 325 |
+
return coef
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class ConvBNReLU(nn.Module):
|
| 329 |
+
@classmethod
|
| 330 |
+
def _same_paddings(cls, kernel_size, dim):
|
| 331 |
+
if dim == "3D":
|
| 332 |
+
if kernel_size == (1, 1, 1):
|
| 333 |
+
return (0, 0, 0)
|
| 334 |
+
elif kernel_size == (3, 3, 3):
|
| 335 |
+
return (1, 1, 1)
|
| 336 |
+
elif dim == "2D" or dim == "2D_TSM":
|
| 337 |
+
if kernel_size == (1, 1):
|
| 338 |
+
return (0, 0)
|
| 339 |
+
elif kernel_size == (3, 3):
|
| 340 |
+
return (1, 1)
|
| 341 |
+
else:
|
| 342 |
+
if kernel_size == 1:
|
| 343 |
+
return 0
|
| 344 |
+
elif kernel_size == 3:
|
| 345 |
+
return 1
|
| 346 |
+
|
| 347 |
+
def __init__(self, in_c, out_c, dim,
|
| 348 |
+
kernel_size=1, stride=1, padding='same',
|
| 349 |
+
dilation=1, groups=1, act='relu', apply_bn=False, apply_act=True):
|
| 350 |
+
super().__init__()
|
| 351 |
+
|
| 352 |
+
self.apply_bn = apply_bn
|
| 353 |
+
self.apply_act = apply_act
|
| 354 |
+
self.dim = dim
|
| 355 |
+
if dilation == 1:
|
| 356 |
+
if self.dim == "3D":
|
| 357 |
+
dilation = (1, 1, 1)
|
| 358 |
+
elif self.dim == "2D" or dim == "2D_TSM":
|
| 359 |
+
dilation = (1, 1)
|
| 360 |
+
else:
|
| 361 |
+
dilation = 1
|
| 362 |
+
|
| 363 |
+
if kernel_size == 1:
|
| 364 |
+
if self.dim == "3D":
|
| 365 |
+
kernel_size = (1, 1, 1)
|
| 366 |
+
elif self.dim == "2D" or dim == "2D_TSM":
|
| 367 |
+
kernel_size = (1, 1)
|
| 368 |
+
else:
|
| 369 |
+
kernel_size = 1
|
| 370 |
+
|
| 371 |
+
if stride == 1:
|
| 372 |
+
if self.dim == "3D":
|
| 373 |
+
stride = (1, 1, 1)
|
| 374 |
+
elif self.dim == "2D" or dim == "2D_TSM":
|
| 375 |
+
stride = (1, 1)
|
| 376 |
+
else:
|
| 377 |
+
stride = 1
|
| 378 |
+
|
| 379 |
+
if padding == 'same':
|
| 380 |
+
padding = self._same_paddings(kernel_size, dim)
|
| 381 |
+
|
| 382 |
+
if self.dim == "3D":
|
| 383 |
+
self.conv = nn.Conv3d(in_c, out_c,
|
| 384 |
+
kernel_size=kernel_size, stride=stride,
|
| 385 |
+
padding=padding, dilation=dilation,
|
| 386 |
+
groups=groups,
|
| 387 |
+
bias=False)
|
| 388 |
+
elif self.dim == "2D" or dim == "2D_TSM":
|
| 389 |
+
self.conv = nn.Conv2d(in_c, out_c,
|
| 390 |
+
kernel_size=kernel_size, stride=stride,
|
| 391 |
+
padding=padding, dilation=dilation,
|
| 392 |
+
groups=groups,
|
| 393 |
+
bias=False)
|
| 394 |
+
else:
|
| 395 |
+
self.conv = nn.Conv1d(in_c, out_c,
|
| 396 |
+
kernel_size=kernel_size, stride=stride,
|
| 397 |
+
padding=padding, dilation=dilation,
|
| 398 |
+
groups=groups,
|
| 399 |
+
bias=False)
|
| 400 |
+
|
| 401 |
+
if act == "sigmoid":
|
| 402 |
+
self.act = nn.Sigmoid()
|
| 403 |
+
else:
|
| 404 |
+
self.act = nn.ReLU(inplace=True)
|
| 405 |
+
|
| 406 |
+
if self.apply_bn:
|
| 407 |
+
if self.dim == "3D":
|
| 408 |
+
self.bn = nn.InstanceNorm3d(out_c)
|
| 409 |
+
elif self.dim == "2D" or dim == "2D_TSM":
|
| 410 |
+
self.bn = nn.InstanceNorm2d(out_c)
|
| 411 |
+
else:
|
| 412 |
+
self.bn = nn.InstanceNorm1d(out_c)
|
| 413 |
+
|
| 414 |
+
def forward(self, x):
|
| 415 |
+
x = self.conv(x)
|
| 416 |
+
if self.apply_act:
|
| 417 |
+
x = self.act(x)
|
| 418 |
+
if self.apply_bn:
|
| 419 |
+
x = self.bn(x)
|
| 420 |
+
return x
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
class FeaturesFactorizationModule(nn.Module):
|
| 424 |
+
def __init__(self, inC, device, md_config, dim="3D", debug=False):
|
| 425 |
+
super().__init__()
|
| 426 |
+
|
| 427 |
+
self.device = device
|
| 428 |
+
self.dim = dim
|
| 429 |
+
md_type = md_config["MD_TYPE"]
|
| 430 |
+
align_C = md_config["align_channels"] # inC // 2 # // 2 #// 8
|
| 431 |
+
|
| 432 |
+
if self.dim == "3D":
|
| 433 |
+
if "nmf" in md_type.lower():
|
| 434 |
+
self.pre_conv_block = nn.Sequential(
|
| 435 |
+
nn.Conv3d(inC, align_C, (1, 1, 1)),
|
| 436 |
+
nn.ReLU(inplace=True))
|
| 437 |
+
else:
|
| 438 |
+
self.pre_conv_block = nn.Conv3d(inC, align_C, (1, 1, 1))
|
| 439 |
+
elif self.dim == "2D" or self.dim == "2D_TSM":
|
| 440 |
+
if "nmf" in md_type.lower():
|
| 441 |
+
self.pre_conv_block = nn.Sequential(
|
| 442 |
+
nn.Conv2d(inC, align_C, (1, 1)),
|
| 443 |
+
nn.ReLU(inplace=True)
|
| 444 |
+
)
|
| 445 |
+
else:
|
| 446 |
+
self.pre_conv_block = nn.Conv2d(inC, align_C, (1, 1))
|
| 447 |
+
elif self.dim == "1D":
|
| 448 |
+
if "nmf" in md_type.lower():
|
| 449 |
+
self.pre_conv_block = nn.Sequential(
|
| 450 |
+
nn.Conv1d(inC, align_C, 1),
|
| 451 |
+
nn.ReLU(inplace=True)
|
| 452 |
+
)
|
| 453 |
+
else:
|
| 454 |
+
self.pre_conv_block = nn.Conv1d(inC, align_C, 1)
|
| 455 |
+
else:
|
| 456 |
+
print("Dimension not supported")
|
| 457 |
+
|
| 458 |
+
if "nmf" in md_type.lower():
|
| 459 |
+
self.md_block = NMF(self.device, md_config, dim=self.dim, debug=debug)
|
| 460 |
+
elif "vq" in md_type.lower():
|
| 461 |
+
self.md_block = VQ(self.device, md_config, dim=self.dim, debug=debug)
|
| 462 |
+
else:
|
| 463 |
+
print("Unknown type specified for MD_TYPE:", md_type)
|
| 464 |
+
exit()
|
| 465 |
+
|
| 466 |
+
if self.dim == "3D":
|
| 467 |
+
if "nmf" in md_type.lower():
|
| 468 |
+
self.post_conv_block = nn.Sequential(
|
| 469 |
+
ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1),
|
| 470 |
+
nn.Conv3d(align_C, inC, 1, bias=False)
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
self.post_conv_block = nn.Sequential(
|
| 474 |
+
ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False),
|
| 475 |
+
nn.Conv3d(align_C, inC, 1, bias=False)
|
| 476 |
+
)
|
| 477 |
+
elif self.dim == "2D" or self.dim == "2D_TSM":
|
| 478 |
+
if "nmf" in md_type.lower():
|
| 479 |
+
self.post_conv_block = nn.Sequential(
|
| 480 |
+
ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1),
|
| 481 |
+
nn.Conv2d(align_C, inC, 1, bias=False)
|
| 482 |
+
)
|
| 483 |
+
else:
|
| 484 |
+
self.post_conv_block = nn.Sequential(
|
| 485 |
+
ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False),
|
| 486 |
+
nn.Conv2d(align_C, inC, 1, bias=False)
|
| 487 |
+
)
|
| 488 |
+
else:
|
| 489 |
+
if "nmf" in md_type.lower():
|
| 490 |
+
self.post_conv_block = nn.Sequential(
|
| 491 |
+
ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1),
|
| 492 |
+
nn.Conv1d(align_C, inC, 1, bias=False)
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
self.post_conv_block = nn.Sequential(
|
| 496 |
+
ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False),
|
| 497 |
+
nn.Conv1d(align_C, inC, 1, bias=False)
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
self._init_weight()
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def _init_weight(self):
|
| 504 |
+
for m in self.modules():
|
| 505 |
+
if isinstance(m, nn.Conv3d):
|
| 506 |
+
N = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
|
| 507 |
+
m.weight.data.normal_(0, np.sqrt(2. / N))
|
| 508 |
+
elif isinstance(m, nn.Conv2d):
|
| 509 |
+
N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 510 |
+
m.weight.data.normal_(0, np.sqrt(2. / N))
|
| 511 |
+
elif isinstance(m, nn.Conv1d):
|
| 512 |
+
N = m.kernel_size[0] * m.out_channels
|
| 513 |
+
m.weight.data.normal_(0, np.sqrt(2. / N))
|
| 514 |
+
elif isinstance(m, _BatchNorm):
|
| 515 |
+
m.weight.data.fill_(1)
|
| 516 |
+
if m.bias is not None:
|
| 517 |
+
m.bias.data.zero_()
|
| 518 |
+
|
| 519 |
+
def forward(self, x):
|
| 520 |
+
x = self.pre_conv_block(x)
|
| 521 |
+
att = self.md_block(x)
|
| 522 |
+
dist = torch.dist(x, att)
|
| 523 |
+
att = self.post_conv_block(att)
|
| 524 |
+
|
| 525 |
+
return att, dist
|
| 526 |
+
|
| 527 |
+
def online_update(self, bases):
|
| 528 |
+
if hasattr(self.md_block, 'online_update'):
|
| 529 |
+
self.md_block.online_update(bases)
|
| 530 |
+
|
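Note (not part of the committed file): the NMF class above refines its low-rank attention via alternating multiplicative updates in local_step. A standalone sketch of that update rule under the shipped defaults (rank MD_R = 1, MD_STEPS = 4); the shapes and step function name are illustrative only.

# Sketch of the multiplicative NMF update mirrored from FSAM's local_step (illustration only).
# x: (B*S, D, N) non-negative features; bases: (B*S, D, R); coef: (B*S, N, R).
import torch

def nmf_step(x, bases, coef, eps=1e-6):
    # coef <- coef * (x^T bases) / (coef (bases^T bases))
    coef = coef * torch.bmm(x.transpose(1, 2), bases) / (coef.bmm(bases.transpose(1, 2).bmm(bases)) + eps)
    # bases <- bases * (x coef) / (bases (coef^T coef))
    bases = bases * torch.bmm(x, coef) / (bases.bmm(coef.transpose(1, 2).bmm(coef)) + eps)
    return bases, coef

x = torch.rand(2, 160, 8)        # e.g. D = 160 time steps, N = 8 spatial-channel features
bases = torch.ones(2, 160, 1)    # rank R = 1, as in the default MD_R
coef = torch.rand(2, 8, 1)
for _ in range(4):               # MD_STEPS = 4 in the shipped config
    bases, coef = nmf_step(x, bases, coef)
print(torch.dist(x, bases @ coef.transpose(1, 2)))  # approximation error after 4 steps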
neural_methods/model/FactorizePhys/FactorizePhys.py
ADDED
@@ -0,0 +1,251 @@
"""
FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
NeurIPS 2024
Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
"""

import torch
import torch.nn as nn
from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule

nf = [8, 12, 16]

model_config = {
    "MD_FSAM": True,
    "MD_TYPE": "NMF",
    "MD_TRANSFORM": "T_KAB",
    "MD_R": 1,
    "MD_S": 1,
    "MD_STEPS": 4,
    "MD_INFERENCE": False,
    "MD_RESIDUAL": False,
    "INV_T": 1,
    "ETA": 0.9,
    "RAND_INIT": True,
    "in_channels": 3,
    "data_channels": 4,
    "align_channels": nf[2] // 2,
    "height": 72,
    "weight": 72,
    "batch_size": 4,
    "frames": 160,
    "debug": False,
    "assess_latency": False,
    "num_trials": 20,
    "visualize": False,
    "ckpt_path": "",
    "data_path": "",
    "label_path": ""
}


class ConvBlock3D(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(ConvBlock3D, self).__init__()
        self.conv_block_3d = nn.Sequential(
            nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding=padding, bias=False),
            nn.Tanh(),
            nn.InstanceNorm3d(out_channel),
        )

    def forward(self, x):
        return self.conv_block_3d(x)


class rPPG_FeatureExtractor(nn.Module):
    def __init__(self, inCh, dropout_rate=0.1, debug=False):
        super(rPPG_FeatureExtractor, self).__init__()
        # inCh, out_channel, kernel_size, stride, padding

        self.debug = debug
        # Input: #B, inCh, 160, 72, 72
        self.FeatureExtractor = nn.Sequential(
            ConvBlock3D(inCh, nf[0], [3, 3, 3], [1, 1, 1], [1, 1, 1]),     #B, nf[0], 160, 72, 72
            ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]),    #B, nf[1], 160, 35, 35
            ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[1], 160, 33, 33
            nn.Dropout3d(p=dropout_rate),

            ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[1], 160, 31, 31
            ConvBlock3D(nf[1], nf[2], [3, 3, 3], [1, 2, 2], [1, 0, 0]),    #B, nf[2], 160, 15, 15
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 160, 13, 13
            nn.Dropout3d(p=dropout_rate),
        )

    def forward(self, x):
        voxel_embeddings = self.FeatureExtractor(x)
        if self.debug:
            print("rPPG Feature Extractor")
            print("     voxel_embeddings.shape", voxel_embeddings.shape)
        return voxel_embeddings


class BVP_Head(nn.Module):
    def __init__(self, md_config, device, dropout_rate=0.1, debug=False):
        super(BVP_Head, self).__init__()
        self.debug = debug

        self.use_fsam = md_config["MD_FSAM"]
        self.md_type = md_config["MD_TYPE"]
        self.md_infer = md_config["MD_INFERENCE"]
        self.md_res = md_config["MD_RESIDUAL"]

        self.conv_block = nn.Sequential(
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 160, 11, 11
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 160, 9, 9
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 160, 7, 7
            nn.Dropout3d(p=dropout_rate),
        )

        if self.use_fsam:
            inC = nf[2]
            self.fsam = FeaturesFactorizationModule(inC, device, md_config, dim="3D", debug=debug)
            self.fsam_norm = nn.InstanceNorm3d(inC)
            self.bias1 = nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device)
        else:
            inC = nf[2]

        self.final_layer = nn.Sequential(
            ConvBlock3D(inC, nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]),      #B, nf[1], 160, 5, 5
            ConvBlock3D(nf[1], nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[0], 160, 3, 3
            nn.Conv3d(nf[0], 1, (3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),  #B, 1, 160, 1, 1
        )


    def forward(self, voxel_embeddings, batch, length):

        if self.debug:
            print("BVP Head")
            print("     voxel_embeddings.shape", voxel_embeddings.shape)

        voxel_embeddings = self.conv_block(voxel_embeddings)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            if "NMF" in self.md_type:
                att_mask, appx_error = self.fsam(voxel_embeddings - voxel_embeddings.min())  # to make it positive (>= 0)
            else:
                att_mask, appx_error = self.fsam(voxel_embeddings)

            if self.debug:
                print("att_mask.shape", att_mask.shape)

            # # directly use att_mask ---> difficult to converge without Residual connection. Needs high rank
            # factorized_embeddings = self.fsam_norm(att_mask)

            # # Residual connection:
            # factorized_embeddings = voxel_embeddings + self.fsam_norm(att_mask)

            if self.md_res:
                # Multiplication with Residual connection
                x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
                factorized_embeddings = self.fsam_norm(x)
                factorized_embeddings = voxel_embeddings + factorized_embeddings
            else:
                # Multiplication
                x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
                factorized_embeddings = self.fsam_norm(x)

            # # Concatenate
            # factorized_embeddings = torch.cat([voxel_embeddings, self.fsam_norm(x)], dim=1)

            x = self.final_layer(factorized_embeddings)

        else:
            x = self.final_layer(voxel_embeddings)

        rPPG = x.view(-1, length)

        if self.debug:
            print("     rPPG.shape", rPPG.shape)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            return rPPG, factorized_embeddings, appx_error
        else:
            return rPPG



class FactorizePhys(nn.Module):
    def __init__(self, frames, md_config, in_channels=3, dropout=0.1, device=torch.device("cpu"), debug=False):
        super(FactorizePhys, self).__init__()
        self.debug = debug

        self.in_channels = in_channels
        if self.in_channels == 1 or self.in_channels == 3:
            self.norm = nn.InstanceNorm3d(self.in_channels)
        elif self.in_channels == 4:
            self.rgb_norm = nn.InstanceNorm3d(3)
            self.thermal_norm = nn.InstanceNorm3d(1)
        else:
            print("Unsupported input channels")

        self.use_fsam = md_config["MD_FSAM"]
        self.md_infer = md_config["MD_INFERENCE"]

        for key in model_config:
            if key not in md_config:
                md_config[key] = model_config[key]

        if self.debug:
            print("nf:", nf)

        self.rppg_feature_extractor = rPPG_FeatureExtractor(self.in_channels, dropout_rate=dropout, debug=debug)

        self.rppg_head = BVP_Head(md_config, device=device, dropout_rate=dropout, debug=debug)


    def forward(self, x):  # [batch, Features=3, Temp=frames, Width=32, Height=32]

        [batch, channel, length, width, height] = x.shape

        # if self.in_channels == 1:
        #     x = x[:, :, :-1, :, :]
        # else:
        #     x = torch.diff(x, dim=2)

        x = torch.diff(x, dim=2)

        if self.debug:
            print("Input.shape", x.shape)

        if self.in_channels == 1:
            x = self.norm(x[:, -1:, :, :, :])
        elif self.in_channels == 3:
            x = self.norm(x[:, :3, :, :, :])
        elif self.in_channels == 4:
            rgb_x = self.rgb_norm(x[:, :3, :, :, :])
            thermal_x = self.thermal_norm(x[:, -1:, :, :, :])
            x = torch.concat([rgb_x, thermal_x], dim = 1)
        else:
            try:
                print("Specified input channels:", self.in_channels)
                print("Data channels", channel)
                assert self.in_channels <= channel
            except:
                print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels")
                print("Default or specified channels:", self.in_channels)
                print("Data channels [B, C, N, W, H]", x.shape)
                print("Exiting")
                exit()

        if self.debug:
            print("Diff Normalized shape", x.shape)

        voxel_embeddings = self.rppg_feature_extractor(x)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            rPPG, factorized_embeddings, appx_error = self.rppg_head(voxel_embeddings, batch, length-1)
        else:
            rPPG = self.rppg_head(voxel_embeddings, batch, length-1)

        # if self.debug:
        #     print("rppg_feats.shape", rppg_feats.shape)

        # rPPG = rppg_feats.view(-1, length-1)

        if self.debug:
            print("rPPG.shape", rPPG.shape)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            return rPPG, voxel_embeddings, factorized_embeddings, appx_error
        else:
            return rPPG, voxel_embeddings
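A minimal smoke-test sketch for the module above (not part of the committed file; the md_config keys mirror the ones the test scripts later in this commit pass in, while the batch size, device, and frame count are assumptions). It mainly illustrates that the network expects frames + 1 input frames, because torch.diff is applied along the temporal axis inside forward, and that the extra FSAM outputs are returned only when MD_INFERENCE is enabled or the model is in training mode:

import torch
from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys

# Keys mirrored from the smoke-test scripts; remaining defaults are filled in
# from the module-level model_config inside FactorizePhys.__init__.
md_config = {
    "FRAME_NUM": 160, "MD_S": 1, "MD_R": 1, "MD_STEPS": 4,
    "MD_FSAM": True, "MD_TYPE": "NMF", "MD_TRANSFORM": "T_KAB",
    "MD_INFERENCE": False, "MD_RESIDUAL": False,
}

device = torch.device("cpu")
net = FactorizePhys(frames=160, md_config=md_config, in_channels=3, device=device)
net.eval()

dummy = torch.rand(1, 3, 160 + 1, 72, 72)   # [B, C, frames + 1, H, W]; the extra frame feeds torch.diff
with torch.no_grad():
    rPPG, voxel_embeddings = net(dummy)     # two outputs here, since MD_INFERENCE is off and the net is in eval mode
print(rPPG.shape)                           # should come out as [1, 160] if the layer arithmetic in the comments holds

FactorizePhysBig, added next, keeps the same interface but is laid out for 240-frame chunks at 128x128 resolution.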
neural_methods/model/FactorizePhys/FactorizePhysBig.py
ADDED
@@ -0,0 +1,251 @@
"""
FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
NeurIPS 2024
Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
"""

import torch
import torch.nn as nn
from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule

nf = [8, 12, 16]

model_config = {
    "MD_FSAM": True,
    "MD_TYPE": "NMF",
    "MD_TRANSFORM": "T_KAB",
    "MD_R": 1,
    "MD_S": 1,
    "MD_STEPS": 4,
    "MD_INFERENCE": False,
    "MD_RESIDUAL": False,
    "INV_T": 1,
    "ETA": 0.9,
    "RAND_INIT": True,
    "in_channels": 3,
    "data_channels": 4,
    "align_channels": nf[2] // 2,
    "height": 128,
    "weight": 128,
    "batch_size": 4,
    "frames": 240,
    "debug": False,
    "assess_latency": False,
    "num_trials": 20,
    "visualize": False,
    "ckpt_path": "",
    "data_path": "",
    "label_path": ""
}


class ConvBlock3D(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(ConvBlock3D, self).__init__()
        self.conv_block_3d = nn.Sequential(
            nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding=padding, bias=False),
            nn.Tanh(),
            nn.InstanceNorm3d(out_channel),
        )

    def forward(self, x):
        return self.conv_block_3d(x)


class rPPG_FeatureExtractor(nn.Module):
    def __init__(self, inCh, dropout_rate=0.1, debug=False):
        super(rPPG_FeatureExtractor, self).__init__()
        # inCh, out_channel, kernel_size, stride, padding

        self.debug = debug
        # Input: #B, inCh, 240, 128, 128
        self.FeatureExtractor = nn.Sequential(
            ConvBlock3D(inCh, nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]),     #B, nf[0], 240, 126, 126
            ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]),    #B, nf[1], 240, 62, 62
            ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[1], 240, 60, 60
            nn.Dropout3d(p=dropout_rate),

            ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]),    #B, nf[1], 240, 29, 29
            ConvBlock3D(nf[1], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 240, 27, 27
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 240, 25, 25
            nn.Dropout3d(p=dropout_rate),
        )

    def forward(self, x):
        voxel_embeddings = self.FeatureExtractor(x)
        if self.debug:
            print("rPPG Feature Extractor")
            print("     voxel_embeddings.shape", voxel_embeddings.shape)
        return voxel_embeddings


class BVP_Head(nn.Module):
    def __init__(self, md_config, device, dropout_rate=0.1, debug=False):
        super(BVP_Head, self).__init__()
        self.debug = debug

        self.use_fsam = md_config["MD_FSAM"]
        self.md_type = md_config["MD_TYPE"]
        self.md_infer = md_config["MD_INFERENCE"]
        self.md_res = md_config["MD_RESIDUAL"]

        self.conv_block = nn.Sequential(
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 2, 2], [1, 0, 0]),    #B, nf[2], 240, 12, 12
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 240, 10, 10
            ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[2], 240, 8, 8
            nn.Dropout3d(p=dropout_rate),
        )

        if self.use_fsam:
            inC = nf[2]
            self.fsam = FeaturesFactorizationModule(inC, device, md_config, dim="3D", debug=debug)
            self.fsam_norm = nn.InstanceNorm3d(inC)
            self.bias1 = nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device)
        else:
            inC = nf[2]

        self.final_layer = nn.Sequential(
            ConvBlock3D(inC, nf[1], [3, 4, 4], [1, 1, 1], [1, 0, 0]),      #B, nf[1], 240, 5, 5
            ConvBlock3D(nf[1], nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]),    #B, nf[0], 240, 3, 3
            nn.Conv3d(nf[0], 1, (3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),  #B, 1, 240, 1, 1
        )


    def forward(self, voxel_embeddings, batch, length):

        if self.debug:
            print("BVP Head")
            print("     voxel_embeddings.shape", voxel_embeddings.shape)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            if "NMF" in self.md_type:
                att_mask, appx_error = self.fsam(voxel_embeddings - voxel_embeddings.min())  # to make it positive (>= 0)
            else:
                att_mask, appx_error = self.fsam(voxel_embeddings)

            if self.debug:
                print("att_mask.shape", att_mask.shape)

            # # directly use att_mask ---> difficult to converge without Residual connection. Needs high rank
            # factorized_embeddings = self.fsam_norm(att_mask)

            # # Residual connection:
            # factorized_embeddings = voxel_embeddings + self.fsam_norm(att_mask)

            if self.md_res:
                # Multiplication with Residual connection
                x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
                factorized_embeddings = self.fsam_norm(x)
                factorized_embeddings = voxel_embeddings + factorized_embeddings
            else:
                # Multiplication
                x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1)
                factorized_embeddings = self.fsam_norm(x)

            # # Concatenate
            # factorized_embeddings = torch.cat([voxel_embeddings, self.fsam_norm(x)], dim=1)

            x = self.conv_block(factorized_embeddings)
            x = self.final_layer(x)

        else:
            x = self.conv_block(voxel_embeddings)
            x = self.final_layer(x)

        rPPG = x.view(-1, length)

        if self.debug:
            print("     rPPG.shape", rPPG.shape)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            return rPPG, factorized_embeddings, appx_error
        else:
            return rPPG



class FactorizePhysBig(nn.Module):
    def __init__(self, frames, md_config, in_channels=3, dropout=0.1, device=torch.device("cpu"), debug=False):
        super(FactorizePhysBig, self).__init__()
        self.debug = debug

        self.in_channels = in_channels
        if self.in_channels == 1 or self.in_channels == 3:
            self.norm = nn.InstanceNorm3d(self.in_channels)
        elif self.in_channels == 4:
            self.rgb_norm = nn.InstanceNorm3d(3)
            self.thermal_norm = nn.InstanceNorm3d(1)
        else:
            print("Unsupported input channels")

        self.use_fsam = md_config["MD_FSAM"]
        self.md_infer = md_config["MD_INFERENCE"]

        for key in model_config:
            if key not in md_config:
                md_config[key] = model_config[key]

        if self.debug:
            print("nf:", nf)

        self.rppg_feature_extractor = rPPG_FeatureExtractor(self.in_channels, dropout_rate=dropout, debug=debug)

        self.rppg_head = BVP_Head(md_config, device=device, dropout_rate=dropout, debug=debug)


    def forward(self, x):  # [batch, Features=3, Temp=frames, Width=32, Height=32]

        [batch, channel, length, width, height] = x.shape

        # if self.in_channels == 1:
        #     x = x[:, :, :-1, :, :]
        # else:
        #     x = torch.diff(x, dim=2)

        x = torch.diff(x, dim=2)

        if self.debug:
            print("Input.shape", x.shape)

        if self.in_channels == 1:
            x = self.norm(x[:, -1:, :, :, :])
        elif self.in_channels == 3:
            x = self.norm(x[:, :3, :, :, :])
        elif self.in_channels == 4:
            rgb_x = self.rgb_norm(x[:, :3, :, :, :])
            thermal_x = self.thermal_norm(x[:, -1:, :, :, :])
            x = torch.concat([rgb_x, thermal_x], dim = 1)
        else:
            try:
                print("Specified input channels:", self.in_channels)
                print("Data channels", channel)
                assert self.in_channels <= channel
            except:
                print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels")
                print("Default or specified channels:", self.in_channels)
                print("Data channels [B, C, N, W, H]", x.shape)
                print("Exiting")
                exit()

        if self.debug:
            print("Diff Normalized shape", x.shape)

        voxel_embeddings = self.rppg_feature_extractor(x)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            rPPG, factorized_embeddings, appx_error = self.rppg_head(voxel_embeddings, batch, length-1)
        else:
            rPPG = self.rppg_head(voxel_embeddings, batch, length-1)

        # if self.debug:
        #     print("rppg_feats.shape", rppg_feats.shape)

        # rPPG = rppg_feats.view(-1, length-1)

        if self.debug:
            print("rPPG.shape", rPPG.shape)

        if (self.md_infer or self.training or self.debug) and self.use_fsam:
            return rPPG, voxel_embeddings, factorized_embeddings, appx_error
        else:
            return rPPG, voxel_embeddings
neural_methods/model/FactorizePhys/__init__.py
ADDED
File without changes
neural_methods/model/FactorizePhys/test_FactorizePhys.py
ADDED
@@ -0,0 +1,286 @@
"""
FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
NeurIPS 2024
Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
"""

import time
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.signal import resample
import torch
import torch.nn as nn
import torch.nn.functional as F
from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys
# from torch.utils.tensorboard import SummaryWriter

model_config = {
    "MD_FSAM": True,
    "MD_TYPE": "NMF",
    "MD_TRANSFORM": "T_KAB",
    "MD_R": 1,
    "MD_S": 1,
    "MD_STEPS": 4,
    "MD_INFERENCE": True,
    "MD_RESIDUAL": True,
    "in_channels": 3,
    "data_channels": 4,
    "height": 72,
    "weight": 72,
    "batch_size": 2,
    "frames": 160,
    "debug": True,
    "assess_latency": False,
    "num_trials": 20,
    "visualize": False,
    "ckpt_path": "./final_model_release/iBVP_FactorizePhys_FSAM_Res.pth",
    "data_path": "/mnt/sda/data/prep/iBVP_Dataset/iBVP_RGB_160_72x72",
    "label_path": "/mnt/sda/data/prep/iBVP_Dataset/iBVP_RGB_160_72x72"
}


class TestFactorizePhysBig(object):
    def __init__(self) -> None:
        self.ckpt_path = Path(model_config["ckpt_path"])
        self.data_path = Path(model_config["data_path"])
        self.label_path = Path(model_config["label_path"])

        self.use_fsam = model_config["MD_FSAM"]
        self.md_infer = model_config["MD_INFERENCE"]

        self.batch_size = model_config["batch_size"]
        self.frames = model_config["frames"]
        self.in_channels = model_config["in_channels"]
        self.data_channels = model_config["data_channels"]
        self.height = model_config["height"]
        self.width = model_config["weight"]
        self.debug = bool(model_config["debug"])
        self.assess_latency = bool(model_config["assess_latency"])
        self.visualize = model_config["visualize"]

        if self.visualize:
            self.data_files = list(sorted(self.data_path.rglob("*input*.npy")))
            self.label_files = list(sorted(self.data_path.rglob("*label*.npy")))
            self.num_trials = len(self.data_files)

            self.plot_dir = Path.cwd().joinpath("plots").joinpath("inference")
            self.plot_dir.mkdir(parents=True, exist_ok=True)

            self.attention_map_dir = self.plot_dir.joinpath("attention_maps").joinpath(self.data_path.name).joinpath(self.ckpt_path.name)
            self.attention_map_dir.mkdir(parents=True, exist_ok=True)

        else:
            if self.assess_latency:
                self.num_trials = model_config["num_trials"]
            else:
                self.num_trials = 1

        if torch.cuda.is_available():
            self.device = torch.device(0)
        else:
            self.device = torch.device("cpu")

        md_config = {}
        md_config["FRAME_NUM"] = model_config["frames"]
        md_config["MD_S"] = model_config["MD_S"]
        md_config["MD_R"] = model_config["MD_R"]
        md_config["MD_STEPS"] = model_config["MD_STEPS"]
        md_config["MD_FSAM"] = model_config["MD_FSAM"]
        md_config["MD_TYPE"] = model_config["MD_TYPE"]
        md_config["MD_TRANSFORM"] = model_config["MD_TRANSFORM"]
        md_config["MD_INFERENCE"] = model_config["MD_INFERENCE"]
        md_config["MD_RESIDUAL"] = model_config["MD_RESIDUAL"]

        if self.visualize:
            self.net = nn.DataParallel(FactorizePhys(frames=self.frames, md_config=md_config,
                                       device=self.device, in_channels=self.in_channels, debug=self.debug), device_ids=[0]).to(self.device)
            self.net.load_state_dict(torch.load(str(self.ckpt_path), map_location=self.device))
        else:
            self.net = FactorizePhys(frames=self.frames, md_config=md_config,
                                     device=self.device, in_channels=self.in_channels, debug=self.debug).to(self.device)

        self.net.eval()
        if self.assess_latency:
            self.time_vec = []

        if self.debug:
            self.appx_error_list = []


    def load_data(self, num_trial):

        if self.visualize:
            self.np_data = np.load(str(self.data_files[num_trial]))
            self.np_label = np.load(str(self.label_files[num_trial]))
            self.np_label = np.expand_dims(self.np_label, 0)
            self.np_label = torch.tensor(self.np_label)

            # print("Chunk data shape", self.np_data.shape)
            # print("Chunk label shape", self.np_label.shape)
            # print("Min Max of input data:", np.min(self.np_data), np.max(self.np_data))
            # exit()

            self.test_data = np.transpose(self.np_data, (3, 0, 1, 2))
            self.test_data = torch.from_numpy(self.test_data)
            self.test_data = self.test_data.unsqueeze(0)

            last_frame = torch.unsqueeze(self.test_data[:, :, -1, :, :], 2).repeat(1, 1, 1, 1, 1)
            self.test_data = torch.cat((self.test_data, last_frame), 2)
            self.test_data = self.test_data.to(torch.float32).to(self.device)
        else:
            self.test_data = torch.rand(self.batch_size, self.data_channels, self.frames + 1, self.height, self.width)
            self.test_data = self.test_data.to(torch.float32).to(self.device)


    def run_inference(self, num_trial):

        if self.visualize:
            print("Processing:", self.data_files[num_trial].name)
        if self.assess_latency:
            t0 = time.time()

        if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
            self.pred, self.vox_embed, self.factorized_embed, self.appx_error = self.net(self.test_data)
        else:
            self.pred, self.vox_embed = self.net(self.test_data)

        if self.assess_latency:
            t1 = time.time()
            self.time_vec.append(t1-t0)

        if self.debug:
            print("pred.shape", self.pred.shape)
            if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
                self.appx_error_list.append(self.appx_error.item())

        if self.visualize:
            self.save_attention_maps(num_trial)


    def save_attention_maps(self, num_trial):
        b, channels, enc_frames, enc_height, enc_width = self.vox_embed.shape
        label_matrix = self.np_label.unsqueeze(0).repeat(1, channels, 1).unsqueeze(
            2).unsqueeze(2).permute(0, 1, 4, 3, 2).repeat(1, 1, 1, enc_height, enc_width)
        label_matrix = label_matrix.to(device=self.device)
        corr_matrix = F.cosine_similarity(self.vox_embed, label_matrix, dim=2).abs()

        # avg_emb = torch.mean(self.vox_embed, dim=1)
        # b, enc_frames, enc_height, enc_width = avg_emb.shape
        # label_matrix = np_label.unsqueeze(0).unsqueeze(2).permute(0, 3, 2, 1).repeat(1, 1, enc_height, enc_width)
        # label_matrix = label_matrix.to(device=device)
        # corr_matrix = F.cosine_similarity(avg_emb, label_matrix, dim=1)

        if self.debug:
            print("corr_matrix.shape", corr_matrix.shape)
            print("self.test_data.shape:", self.test_data.shape)
            print("self.vox_embed.shape:", self.vox_embed.shape)

        self.test_data = self.test_data.detach().cpu().numpy()
        self.vox_embed = self.vox_embed.detach().cpu().numpy()
        corr_matrix = corr_matrix.detach().cpu().numpy()

        fig, ax = plt.subplots(4, 4, figsize=[16, 16])
        fig.tight_layout()

        ax[0, 0].imshow(self.np_data[enc_frames//2, ...].astype(np.uint8))
        ax[0, 0].axis('off')
        cmap = "coolwarm"

        ch = 0
        ax[0, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[0, 1].axis('off')

        ch = 1
        ax[0, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[0, 2].axis('off')

        ch = 2
        ax[0, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[0, 3].axis('off')

        ch = 3
        ax[1, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 0].axis('off')

        ch = 4
        ax[1, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 1].axis('off')

        ch = 5
        ax[1, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 2].axis('off')

        ch = 6
        ax[1, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 3].axis('off')

        ch = 7
        ax[2, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 0].axis('off')

        ch = 8
        ax[2, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 1].axis('off')

        ch = 9
        ax[2, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 2].axis('off')

        ch = 10
        ax[2, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 3].axis('off')

        ch = 11
        ax[3, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 0].axis('off')

        ch = 12
        ax[3, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 1].axis('off')

        ch = 13
        ax[3, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 2].axis('off')

        ch = 14
        ax[3, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 3].axis('off')

        # plt.show()
        plt.savefig(str(self.attention_map_dir.joinpath(str(self.data_files[num_trial].name.replace(".npy", "_attention_map.jpg")))))
        plt.close(fig)


    def output_summary_results(self):
        if self.assess_latency:
            print("Median time: ", np.median(self.time_vec))
            plt.plot(self.time_vec)
            plt.savefig(str(self.plot_dir.joinpath("Latency.jpg")))

        if self.debug:
            if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
                print("Median error:", np.median(self.appx_error_list))

            pytorch_total_params = sum(p.numel() for p in self.net.parameters())
            print("Total parameters = ", pytorch_total_params)

            pytorch_trainable_params = sum(p.numel()
                                           for p in self.net.parameters() if p.requires_grad)
            print("Trainable parameters = ", pytorch_trainable_params)


if __name__ == "__main__":

    testObj = TestFactorizePhysBig()

    print("testObj.num_trials:", testObj.num_trials)
    for trial_num in range(testObj.num_trials):
        testObj.load_data(trial_num)
        testObj.run_inference(trial_num)

    testObj.output_summary_results()

    # writer.add_graph(net, test_data)
    # writer.close()
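The script above is self-contained as long as visualize and assess_latency stay off: it builds a random [batch, 4, frames + 1, 72, 72] tensor, so no dataset or checkpoint is needed, only the repository on the Python path. A hedged sketch of driving it programmatically rather than through __main__ (the module alias and the single override are assumptions, not something the script requires):

import neural_methods.model.FactorizePhys.test_FactorizePhys as smoke_test

# model_config is module-level and read once in __init__, so any override must come first.
smoke_test.model_config["debug"] = True          # keep shape prints, FSAM error, and parameter counts

tester = smoke_test.TestFactorizePhysBig()       # class name as committed (here it wraps FactorizePhys)
for trial in range(tester.num_trials):           # a single trial with the committed defaults
    tester.load_data(trial)
    tester.run_inference(trial)
tester.output_summary_results()

Note that the latency and visualization branches assume extra state (preprocessed .npy chunks on disk, and plot_dir is only created inside the visualize branch), so enabling either of them needs the corresponding paths set up first.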
neural_methods/model/FactorizePhys/test_FactorizePhysBig.py
ADDED
@@ -0,0 +1,292 @@
"""
FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
NeurIPS 2024
Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
"""

import time
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.signal import resample
import torch
import torch.nn as nn
import torch.nn.functional as F

from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig
# from torch.utils.tensorboard import SummaryWriter

model_config = {
    "MD_FSAM": True,
    "MD_TYPE": "NMF",
    "MD_TRANSFORM": "T_KAB",
    "MD_R": 1,
    "MD_S": 1,
    "MD_STEPS": 4,
    "MD_INFERENCE": True,
    "MD_RESIDUAL": True,
    "in_channels": 3,
    "data_channels": 4,
    "height": 128,
    "weight": 128,
    "batch_size": 1,
    "frames": 240,
    "debug": True,
    "assess_latency": False,
    "num_trials": 20,
    "visualize": False,
    # "ckpt_path": "./final_model_release/UBFC-rPPG_Intra_FactorizePhys_Base_HighRes.pth",
    "ckpt_path": "./final_model_release/UBFC-rPPG_Intra_FactorizePhys_FSAM_Res_HighRes.pth",
    "data_path": "/mnt/sda/data/prep/UBFC-rPPG/UBFC-rPPG_Raw_240_128x128",
    "label_path": "/mnt/sda/data/prep/UBFC-rPPG/UBFC-rPPG_Raw_240_128x128"
}

# default `log_dir` is "runs" - we'll be more specific here
# writer = SummaryWriter('runs/FactorizePhys')

class TestFactorizePhysBig(object):
    def __init__(self) -> None:
        self.ckpt_path = Path(model_config["ckpt_path"])
        self.data_path = Path(model_config["data_path"])
        self.label_path = Path(model_config["label_path"])

        self.use_fsam = model_config["MD_FSAM"]
        self.md_infer = model_config["MD_INFERENCE"]

        self.batch_size = model_config["batch_size"]
        self.frames = model_config["frames"]
        self.in_channels = model_config["in_channels"]
        self.data_channels = model_config["data_channels"]
        self.height = model_config["height"]
        self.width = model_config["weight"]
        self.debug = bool(model_config["debug"])
        self.assess_latency = bool(model_config["assess_latency"])
        self.visualize = model_config["visualize"]

        if self.visualize:
            # self.data_files = list(sorted(self.data_path.rglob("*subject12*input*.npy")))
            # self.label_files = list(sorted(self.data_path.rglob("*subject12*label*.npy")))
            self.data_files = list(sorted(self.data_path.rglob("*input*.npy")))
            self.label_files = list(sorted(self.data_path.rglob("*label*.npy")))
            self.num_trials = len(self.data_files)

            self.plot_dir = Path.cwd().joinpath("plots").joinpath("inference")
            self.plot_dir.mkdir(parents=True, exist_ok=True)

            self.attention_map_dir = self.plot_dir.joinpath("attention_maps").joinpath(self.data_path.name).joinpath(self.ckpt_path.name)
            self.attention_map_dir.mkdir(parents=True, exist_ok=True)

        else:
            if self.assess_latency:
                self.num_trials = model_config["num_trials"]
            else:
                self.num_trials = 1

        if torch.cuda.is_available():
            self.device = torch.device(0)
        else:
            self.device = torch.device("cpu")

        md_config = {}
        md_config["FRAME_NUM"] = model_config["frames"]
        md_config["MD_S"] = model_config["MD_S"]
        md_config["MD_R"] = model_config["MD_R"]
        md_config["MD_STEPS"] = model_config["MD_STEPS"]
        md_config["MD_FSAM"] = model_config["MD_FSAM"]
        md_config["MD_TYPE"] = model_config["MD_TYPE"]
        md_config["MD_TRANSFORM"] = model_config["MD_TRANSFORM"]
        md_config["MD_INFERENCE"] = model_config["MD_INFERENCE"]
        md_config["MD_RESIDUAL"] = model_config["MD_RESIDUAL"]

        if self.visualize:
            self.net = nn.DataParallel(FactorizePhysBig(frames=self.frames, md_config=md_config,
                                       device=self.device, in_channels=self.in_channels, debug=self.debug), device_ids=[0]).to(self.device)
            self.net.load_state_dict(torch.load(str(self.ckpt_path), map_location=self.device))
        else:
            self.net = FactorizePhysBig(frames=self.frames, md_config=md_config,
                                        device=self.device, in_channels=self.in_channels, debug=self.debug).to(self.device)

        self.net.eval()
        if self.assess_latency:
            self.time_vec = []

        if self.debug:
            self.appx_error_list = []


    def load_data(self, num_trial):

        if self.visualize:
            self.np_data = np.load(str(self.data_files[num_trial]))
            self.np_label = np.load(str(self.label_files[num_trial]))
            self.np_label = np.expand_dims(self.np_label, 0)
            self.np_label = torch.tensor(self.np_label)

            # print("Chunk data shape", self.np_data.shape)
            # print("Chunk label shape", self.np_label.shape)
            # print("Min Max of input data:", np.min(self.np_data), np.max(self.np_data))
            # exit()

            self.test_data = np.transpose(self.np_data, (3, 0, 1, 2))
            self.test_data = torch.from_numpy(self.test_data)
            self.test_data = self.test_data.unsqueeze(0)

            last_frame = torch.unsqueeze(self.test_data[:, :, -1, :, :], 2).repeat(1, 1, 1, 1, 1)
            self.test_data = torch.cat((self.test_data, last_frame), 2)
            self.test_data = self.test_data.to(torch.float32).to(self.device)
        else:
            self.test_data = torch.rand(self.batch_size, self.data_channels, self.frames + 1, self.height, self.width)
            self.test_data = self.test_data.to(torch.float32).to(self.device)


    def run_inference(self, num_trial):

        if self.visualize:
            print("Processing:", self.data_files[num_trial].name)
        if self.assess_latency:
            t0 = time.time()

        if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
            self.pred, self.vox_embed, self.factorized_embed, self.appx_error = self.net(self.test_data)
        else:
            self.pred, self.vox_embed = self.net(self.test_data)

        if self.assess_latency:
            t1 = time.time()
            self.time_vec.append(t1-t0)

        if self.debug:
            print("pred.shape", self.pred.shape)
            if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
                self.appx_error_list.append(self.appx_error.item())

        if self.visualize:
            self.save_attention_maps(num_trial)


    def save_attention_maps(self, num_trial):
        b, channels, enc_frames, enc_height, enc_width = self.vox_embed.shape
        label_matrix = self.np_label.unsqueeze(0).repeat(1, channels, 1).unsqueeze(
            2).unsqueeze(2).permute(0, 1, 4, 3, 2).repeat(1, 1, 1, enc_height, enc_width)
        label_matrix = label_matrix.to(device=self.device)
        corr_matrix = F.cosine_similarity(self.vox_embed, label_matrix, dim=2).abs()

        # avg_emb = torch.mean(self.vox_embed, dim=1)
        # b, enc_frames, enc_height, enc_width = avg_emb.shape
        # label_matrix = np_label.unsqueeze(0).unsqueeze(2).permute(0, 3, 2, 1).repeat(1, 1, enc_height, enc_width)
        # label_matrix = label_matrix.to(device=device)
        # corr_matrix = F.cosine_similarity(avg_emb, label_matrix, dim=1)

        if self.debug:
            print("corr_matrix.shape", corr_matrix.shape)
            print("self.test_data.shape:", self.test_data.shape)
            print("self.vox_embed.shape:", self.vox_embed.shape)

        self.test_data = self.test_data.detach().cpu().numpy()
        self.vox_embed = self.vox_embed.detach().cpu().numpy()
        corr_matrix = corr_matrix.detach().cpu().numpy()

        fig, ax = plt.subplots(4, 4, figsize=[16, 16])
        fig.tight_layout()
        cmap = "coolwarm"

        ax[0, 0].imshow(self.np_data[enc_frames//2, ...].astype(np.uint8))
        ax[0, 0].axis('off')

        ch = 0
        ax[0, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[0, 1].axis('off')

        ch = 1
        ax[0, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[0, 2].axis('off')

        ch = 2
        ax[0, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[0, 3].axis('off')

        ch = 3
        ax[1, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 0].axis('off')

        ch = 4
        ax[1, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 1].axis('off')

        ch = 5
        ax[1, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 2].axis('off')

        ch = 6
        ax[1, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[1, 3].axis('off')

        ch = 7
        ax[2, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 0].axis('off')

        ch = 8
        ax[2, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 1].axis('off')

        ch = 9
        ax[2, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 2].axis('off')

        ch = 10
        ax[2, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[2, 3].axis('off')

        ch = 11
        ax[3, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 0].axis('off')

        ch = 12
        ax[3, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 1].axis('off')

        ch = 13
        ax[3, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 2].axis('off')

        ch = 14
        ax[3, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1)
        ax[3, 3].axis('off')

        # plt.show()
        plt.savefig(str(self.attention_map_dir.joinpath(str(self.data_files[num_trial].name.replace(".npy", "_attention_map.jpg")))))
        plt.close(fig)


    def output_summary_results(self):
        if self.assess_latency:
            print("Median time: ", np.median(self.time_vec))
            plt.plot(self.time_vec)
            plt.savefig(str(self.plot_dir.joinpath("Latency.jpg")))

        if self.debug:
            if (self.md_infer or self.net.training or self.debug) and self.use_fsam:
                print("Median error:", np.median(self.appx_error_list))

            pytorch_total_params = sum(p.numel() for p in self.net.parameters())
            print("Total parameters = ", pytorch_total_params)

            pytorch_trainable_params = sum(p.numel()
                                           for p in self.net.parameters() if p.requires_grad)
            print("Trainable parameters = ", pytorch_trainable_params)


if __name__ == "__main__":

    testObj = TestFactorizePhysBig()

    print("testObj.num_trials:", testObj.num_trials)
    for trial_num in range(testObj.num_trials):
        testObj.load_data(trial_num)
        testObj.run_inference(trial_num)

    testObj.output_summary_results()

    # writer.add_graph(net, test_data)
    # writer.close()
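The two smoke-test scripts differ mainly in the wrapped model class and in the chunk geometry taken from their model_config blocks; a small sketch of the dummy tensors each one builds when visualize is off (the +1 frame matches the temporal torch.diff inside both models; the variable names are only for illustration):

import torch

# test_FactorizePhys.py defaults: batch 2, 4 data channels, 160-frame chunks at 72x72
small_chunk = torch.rand(2, 4, 160 + 1, 72, 72)
# test_FactorizePhysBig.py defaults: batch 1, 4 data channels, 240-frame chunks at 128x128
big_chunk = torch.rand(1, 4, 240 + 1, 128, 128)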
neural_methods/model/PhysFormer.py
ADDED
@@ -0,0 +1,313 @@
"""This file is a combination of Physformer.py and transformer_layer.py
in the official PhysFormer implementation here:
https://github.com/ZitongYu/PhysFormer

model.py - Model and module class for ViT.
They are built to mirror those in the official Jax implementation.
"""

import numpy as np
from typing import Optional
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
import math

def as_tuple(x):
    return x if isinstance(x, tuple) else (x, x)

'''
Temporal Center-difference based Convolutional layer (3D version)
theta: control the percentage of original convolution and centeral-difference convolution
'''
class CDC_T(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.6):

        super(CDC_T, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):
        out_normal = self.conv(x)

        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal
        else:
            [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape

            # only CD works on temporal kernel size>1
            if self.conv.weight.shape[2] > 1:
                kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum(
                    2).sum(2)
                kernel_diff = kernel_diff[:, :, None, None, None]
                out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
                                    padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
                return out_normal - self.theta * out_diff

            else:
                return out_normal


def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)


def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)

class MultiHeadedSelfAttention_TDC_gra_sharp(nn.Module):
    """Multi-Headed Dot Product Attention with depth-wise Conv3d"""
    def __init__(self, dim, num_heads, dropout, theta):
        super().__init__()

        self.proj_q = nn.Sequential(
            CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta),
            nn.BatchNorm3d(dim),
        )
        self.proj_k = nn.Sequential(
            CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta),
            nn.BatchNorm3d(dim),
        )
        self.proj_v = nn.Sequential(
            nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False),
        )

        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # for visualization

    def forward(self, x, gra_sharp):    # [B, 4*4*40, 128]
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)

        [B, P, C]=x.shape
        x = x.transpose(1, 2).view(B, C, P//16, 4, 4)      # [B, dim, 40, 4, 4]
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q = q.flatten(2).transpose(1, 2)  # [B, 4*4*40, dim]
        k = k.flatten(2).transpose(1, 2)  # [B, 4*4*40, dim]
        v = v.flatten(2).transpose(1, 2)  # [B, 4*4*40, dim]

        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / gra_sharp

        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h, scores




class PositionWiseFeedForward_ST(nn.Module):
    """FeedForward Neural Networks for each position"""
    def __init__(self, dim, ff_dim):
        super().__init__()

        self.fc1 = nn.Sequential(
            nn.Conv3d(dim, ff_dim, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm3d(ff_dim),
            nn.ELU(),
        )

        self.STConv = nn.Sequential(
            nn.Conv3d(ff_dim, ff_dim, 3, stride=1, padding=1, groups=ff_dim, bias=False),
            nn.BatchNorm3d(ff_dim),
            nn.ELU(),
        )

        self.fc2 = nn.Sequential(
            nn.Conv3d(ff_dim, dim, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm3d(dim),
        )

    def forward(self, x):    # [B, 4*4*40, 128]
        [B, P, C]=x.shape
        x = x.transpose(1, 2).view(B, C, P//16, 4, 4)      # [B, dim, 40, 4, 4]
        x = self.fc1(x)                   # x [B, ff_dim, 40, 4, 4]
        x = self.STConv(x)                # x [B, ff_dim, 40, 4, 4]
        x = self.fc2(x)                   # x [B, dim, 40, 4, 4]
        x = x.flatten(2).transpose(1, 2)  # [B, 4*4*40, dim]

        return x

class Block_ST_TDC_gra_sharp(nn.Module):
    """Transformer Block"""
    def __init__(self, dim, num_heads, ff_dim, dropout, theta):
        super().__init__()
        self.attn = MultiHeadedSelfAttention_TDC_gra_sharp(dim, num_heads, dropout, theta)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward_ST(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, gra_sharp):
        Atten, Score = self.attn(self.norm1(x), gra_sharp)
        h = self.drop(self.proj(Atten))
        x = x + h
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x, Score

class Transformer_ST_TDC_gra_sharp(nn.Module):
    """Transformer with Self-Attentive Blocks"""
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout, theta):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block_ST_TDC_gra_sharp(dim, num_heads, ff_dim, dropout, theta) for _ in range(num_layers)])

    def forward(self, x, gra_sharp):
        for block in self.blocks:
            x, Score = block(x, gra_sharp)
        return x, Score

# stem_3DCNN + ST-ViT with local Depthwise Spatio-Temporal MLP
class ViT_ST_ST_Compact3_TDC_gra_sharp(nn.Module):

    def __init__(
        self,
        name: Optional[str] = None,
        pretrained: bool = False,
        patches: int = 16,
        dim: int = 768,
        ff_dim: int = 3072,
        num_heads: int = 12,
        num_layers: int = 12,
        attention_dropout_rate: float = 0.0,
        dropout_rate: float = 0.2,
        representation_size: Optional[int] = None,
        load_repr_layer: bool = False,
        classifier: str = 'token',
        #positional_embedding: str = '1d',
        in_channels: int = 3,
        frame: int = 160,
        theta: float = 0.2,
        image_size: Optional[int] = None,
    ):
        super().__init__()


        self.image_size = image_size
        self.frame = frame
        self.dim = dim

        # Image and patch sizes
        t, h, w = as_tuple(image_size)  # tube sizes
        ft, fh, fw = as_tuple(patches)  # patch sizes, ft = 4 ==> 160/4=40
        gt, gh, gw = t//ft, h // fh, w // fw  # number of patches
        seq_len = gh * gw * gt

        # Patch embedding    [4x16x16]conv
        self.patch_embedding = nn.Conv3d(dim, dim, kernel_size=(ft, fh, fw), stride=(ft, fh, fw))

        # Transformer
        self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
                                                         ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
        # Transformer
        self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
                                                         ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
        # Transformer
        self.transformer3 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
                                                         ff_dim=ff_dim, dropout=dropout_rate, theta=theta)



        self.Stem0 = nn.Sequential(
            nn.Conv3d(3, dim//4, [1, 5, 5], stride=1, padding=[0,2,2]),
            nn.BatchNorm3d(dim//4),
            nn.ReLU(inplace=True),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )

        self.Stem1 = nn.Sequential(
            nn.Conv3d(dim//4, dim//2, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim//2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )
        self.Stem2 = nn.Sequential(
            nn.Conv3d(dim//2, dim, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
+
|
| 209 |
+
self.image_size = image_size
|
| 210 |
+
self.frame = frame
|
| 211 |
+
self.dim = dim
|
| 212 |
+
|
| 213 |
+
# Image and patch sizes
|
| 214 |
+
t, h, w = as_tuple(image_size) # tube sizes
|
| 215 |
+
ft, fh, fw = as_tuple(patches) # patch sizes, ft = 4 ==> 160/4=40
|
| 216 |
+
gt, gh, gw = t//ft, h // fh, w // fw # number of patches
|
| 217 |
+
seq_len = gh * gw * gt
|
| 218 |
+
|
| 219 |
+
# Patch embedding [4x16x16]conv
|
| 220 |
+
self.patch_embedding = nn.Conv3d(dim, dim, kernel_size=(ft, fh, fw), stride=(ft, fh, fw))
|
| 221 |
+
|
| 222 |
+
# Transformer
|
| 223 |
+
self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
|
| 224 |
+
ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
|
| 225 |
+
# Transformer
|
| 226 |
+
self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
|
| 227 |
+
ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
|
| 228 |
+
# Transformer
|
| 229 |
+
self.transformer3 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads,
|
| 230 |
+
ff_dim=ff_dim, dropout=dropout_rate, theta=theta)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
self.Stem0 = nn.Sequential(
|
| 235 |
+
nn.Conv3d(3, dim//4, [1, 5, 5], stride=1, padding=[0,2,2]),
|
| 236 |
+
nn.BatchNorm3d(dim//4),
|
| 237 |
+
nn.ReLU(inplace=True),
|
| 238 |
+
nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
self.Stem1 = nn.Sequential(
|
| 242 |
+
nn.Conv3d(dim//4, dim//2, [3, 3, 3], stride=1, padding=1),
|
| 243 |
+
nn.BatchNorm3d(dim//2),
|
| 244 |
+
nn.ReLU(inplace=True),
|
| 245 |
+
nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
|
| 246 |
+
)
|
| 247 |
+
self.Stem2 = nn.Sequential(
|
| 248 |
+
nn.Conv3d(dim//2, dim, [3, 3, 3], stride=1, padding=1),
|
| 249 |
+
nn.BatchNorm3d(dim),
|
| 250 |
+
nn.ReLU(inplace=True),
|
| 251 |
+
nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
self.upsample = nn.Sequential(
|
| 255 |
+
nn.Upsample(scale_factor=(2,1,1)),
|
| 256 |
+
nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1,0,0)),
|
| 257 |
+
nn.BatchNorm3d(dim),
|
| 258 |
+
nn.ELU(),
|
| 259 |
+
)
|
| 260 |
+
self.upsample2 = nn.Sequential(
|
| 261 |
+
nn.Upsample(scale_factor=(2,1,1)),
|
| 262 |
+
nn.Conv3d(dim, dim//2, [3, 1, 1], stride=1, padding=(1,0,0)),
|
| 263 |
+
nn.BatchNorm3d(dim//2),
|
| 264 |
+
nn.ELU(),
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
self.ConvBlockLast = nn.Conv1d(dim//2, 1, 1,stride=1, padding=0)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# Initialize weights
|
| 271 |
+
self.init_weights()
|
| 272 |
+
|
| 273 |
+
@torch.no_grad()
|
| 274 |
+
def init_weights(self):
|
| 275 |
+
def _init(m):
|
| 276 |
+
if isinstance(m, nn.Linear):
|
| 277 |
+
nn.init.xavier_uniform_(m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal
|
| 278 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 279 |
+
nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0)
|
| 280 |
+
self.apply(_init)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def forward(self, x, gra_sharp):
|
| 284 |
+
|
| 285 |
+
# b is batch number, c channels, t frame, fh frame height, and fw frame width
|
| 286 |
+
b, c, t, fh, fw = x.shape
|
| 287 |
+
|
| 288 |
+
x = self.Stem0(x)
|
| 289 |
+
x = self.Stem1(x)
|
| 290 |
+
x = self.Stem2(x) # [B, 64, 160, 64, 64]
|
| 291 |
+
|
| 292 |
+
x = self.patch_embedding(x) # [B, 64, 40, 4, 4]
|
| 293 |
+
x = x.flatten(2).transpose(1, 2) # [B, 40*4*4, 64]
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
Trans_features, Score1 = self.transformer1(x, gra_sharp) # [B, 4*4*40, 64]
|
| 297 |
+
Trans_features2, Score2 = self.transformer2(Trans_features, gra_sharp) # [B, 4*4*40, 64]
|
| 298 |
+
Trans_features3, Score3 = self.transformer3(Trans_features2, gra_sharp) # [B, 4*4*40, 64]
|
| 299 |
+
|
| 300 |
+
# upsampling heads
|
| 301 |
+
#features_last = Trans_features3.transpose(1, 2).view(b, self.dim, 40, 4, 4) # [B, 64, 40, 4, 4]
|
| 302 |
+
features_last = Trans_features3.transpose(1, 2).view(b, self.dim, t//4, 4, 4) # [B, 64, 40, 4, 4]
|
| 303 |
+
|
| 304 |
+
features_last = self.upsample(features_last) # x [B, 64, 7*7, 80]
|
| 305 |
+
features_last = self.upsample2(features_last) # x [B, 32, 7*7, 160]
|
| 306 |
+
|
| 307 |
+
features_last = torch.mean(features_last,3) # x [B, 32, 160, 4]
|
| 308 |
+
features_last = torch.mean(features_last,3) # x [B, 32, 160]
|
| 309 |
+
rPPG = self.ConvBlockLast(features_last) # x [B, 1, 160]
|
| 310 |
+
|
| 311 |
+
rPPG = rPPG.squeeze(1)
|
| 312 |
+
|
| 313 |
+
return rPPG, Score1, Score2, Score3
|
neural_methods/model/PhysMamba.py
ADDED
@@ -0,0 +1,246 @@
import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from mamba_ssm import Mamba
from torch.nn import functional as F

class ChannelAttention3D(nn.Module):
    def __init__(self, in_channels, reduction):
        super(ChannelAttention3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)

        self.fc = nn.Sequential(
            nn.Conv3d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        attention = self.sigmoid(out)
        return x * attention

class LateralConnection(nn.Module):
    def __init__(self, fast_channels=32, slow_channels=64):
        super(LateralConnection, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(fast_channels, slow_channels, [3, 1, 1], stride=[2, 1, 1], padding=[1, 0, 0]),
            nn.BatchNorm3d(slow_channels),  # use the slow_channels argument instead of a hard-coded 64
            nn.ReLU(),
        )

    def forward(self, slow_path, fast_path):
        fast_path = self.conv(fast_path)
        return fast_path + slow_path

class CDC_T(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.2):

        super(CDC_T, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):

        out_normal = self.conv(x)

        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal
        else:
            [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape

            # only CD works on temporal kernel size>1
            if self.conv.weight.shape[2] > 1:
                kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum(
                    2).sum(2)
                kernel_diff = kernel_diff[:, :, None, None, None]
                out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
                                    padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
                return out_normal - self.theta * out_diff

            else:
                return out_normal
class MambaLayer(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, channel_token=False):
        super(MambaLayer, self).__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        drop_path = 0
        self.mamba = Mamba(
            d_model=dim,      # Model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,    # Local convolution width
            expand=expand,    # Block expansion factor
            bimamba=True,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_patch_token(self, x):
        B, C, nf, H, W = x.shape
        B, d_model = x.shape[:2]
        assert d_model == self.dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2)
        x_norm = self.norm1(x_flat)
        x_mamba = self.mamba(x_norm)
        x_out = self.norm2(x_flat + self.drop_path(x_mamba))
        out = x_out.transpose(-1, -2).reshape(B, d_model, *img_dims)
        return out

    def forward(self, x):
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.type(torch.float32)
        out = self.forward_patch_token(x)
        return out

def conv_block(in_channels, out_channels, kernel_size, stride, padding, bn=True, activation='relu'):
    layers = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)]
    if bn:
        layers.append(nn.BatchNorm3d(out_channels))
    if activation == 'relu':
        layers.append(nn.ReLU(inplace=True))
    elif activation == 'elu':
        layers.append(nn.ELU(inplace=True))
    return nn.Sequential(*layers)


class PhysMamba(nn.Module):
    def __init__(self, theta=0.5, drop_rate1=0.25, drop_rate2=0.5, frames=128):
        super(PhysMamba, self).__init__()

        self.ConvBlock1 = conv_block(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2])
        self.ConvBlock2 = conv_block(16, 32, [3, 3, 3], stride=1, padding=1)
        self.ConvBlock3 = conv_block(32, 64, [3, 3, 3], stride=1, padding=1)
        self.ConvBlock4 = conv_block(64, 64, [4, 1, 1], stride=[4, 1, 1], padding=0)
        self.ConvBlock5 = conv_block(64, 32, [2, 1, 1], stride=[2, 1, 1], padding=0)
        self.ConvBlock6 = conv_block(32, 32, [3, 1, 1], stride=1, padding=[1, 0, 0], activation='elu')

        # Temporal Difference Mamba Blocks
        # Slow Stream
        self.Block1 = self._build_block(64, theta)
        self.Block2 = self._build_block(64, theta)
        self.Block3 = self._build_block(64, theta)
        # Fast Stream
        self.Block4 = self._build_block(32, theta)
        self.Block5 = self._build_block(32, theta)
        self.Block6 = self._build_block(32, theta)

        # Upsampling
        self.upsample1 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(64, 64, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(96, 48, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(48),
            nn.ELU(),
        )

        self.ConvBlockLast = nn.Conv3d(48, 1, [1, 1, 1], stride=1, padding=0)
        self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
        self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2)

        self.fuse_1 = LateralConnection(fast_channels=32, slow_channels=64)
        self.fuse_2 = LateralConnection(fast_channels=32, slow_channels=64)

        self.drop_1 = nn.Dropout(drop_rate1)
        self.drop_2 = nn.Dropout(drop_rate1)
        self.drop_3 = nn.Dropout(drop_rate2)
        self.drop_4 = nn.Dropout(drop_rate2)
        self.drop_5 = nn.Dropout(drop_rate2)
        self.drop_6 = nn.Dropout(drop_rate2)

        self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))

    def _build_block(self, channels, theta):
        return nn.Sequential(
            CDC_T(channels, channels, theta=theta),
            nn.BatchNorm3d(channels),
            nn.ReLU(),
            MambaLayer(dim=channels),
            ChannelAttention3D(in_channels=channels, reduction=2),
        )

    def forward(self, x):
        [batch, channel, length, width, height] = x.shape

        x = self.ConvBlock1(x)
        x = self.MaxpoolSpa(x)
        x = self.ConvBlock2(x)
        x = self.ConvBlock3(x)
        x = self.MaxpoolSpa(x)

        # Process streams
        s_x = self.ConvBlock4(x)  # Slow stream
        f_x = self.ConvBlock5(x)  # Fast stream

        # First set of blocks and fusion
        s_x1 = self.Block1(s_x)
        s_x1 = self.MaxpoolSpa(s_x1)
        s_x1 = self.drop_1(s_x1)

        f_x1 = self.Block4(f_x)
        f_x1 = self.MaxpoolSpa(f_x1)
        f_x1 = self.drop_2(f_x1)

        s_x1 = self.fuse_1(s_x1, f_x1)  # LateralConnection

        # Second set of blocks and fusion
        s_x2 = self.Block2(s_x1)
        s_x2 = self.MaxpoolSpa(s_x2)
        s_x2 = self.drop_3(s_x2)

        f_x2 = self.Block5(f_x1)
        f_x2 = self.MaxpoolSpa(f_x2)
        f_x2 = self.drop_4(f_x2)

        s_x2 = self.fuse_2(s_x2, f_x2)  # LateralConnection

        # Third blocks and upsampling
        s_x3 = self.Block3(s_x2)
        s_x3 = self.upsample1(s_x3)
        s_x3 = self.drop_5(s_x3)

        f_x3 = self.Block6(f_x2)
        f_x3 = self.ConvBlock6(f_x3)
        f_x3 = self.drop_6(f_x3)

        # Final fusion and upsampling
        x_fusion = torch.cat((f_x3, s_x3), dim=1)
        x_final = self.upsample2(x_fusion)

        x_final = self.poolspa(x_final)
        x_final = self.ConvBlockLast(x_final)

        rPPG = x_final.view(-1, length)

        return rPPG
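A minimal smoke test for PhysMamba (illustrative only, not part of the committed file). It assumes the bundled mamba_ssm package imports cleanly; its selective-scan kernels may require a CUDA device.

import torch

model = PhysMamba(frames=128).eval()
clip = torch.randn(1, 3, 128, 128, 128)   # [B, C, T, H, W]; T should equal `frames`
with torch.no_grad():
    rppg = model(clip)                    # one rPPG sample per input frame
print(rppg.shape)                         # torch.Size([1, 128])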
neural_methods/model/PhysNet.py
ADDED
@@ -0,0 +1,124 @@
""" PhysNet
We replicate the network pipeline of the original paper, but set the input as diff-normalized data.
original source:
Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks
British Machine Vision Conference (BMVC), 2019
By Zitong Yu, 2019/05/05
Only for research purpose, and commercial use is not allowed.
MIT License
Copyright (c) 2019
"""

import math
import pdb

import torch
import torch.nn as nn
from torch.nn.modules.utils import _triple


class PhysNet_padding_Encoder_Decoder_MAX(nn.Module):
    def __init__(self, frames=128):
        super(PhysNet_padding_Encoder_Decoder_MAX, self).__init__()

        self.ConvBlock1 = nn.Sequential(
            nn.Conv3d(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2]),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
        )

        self.ConvBlock2 = nn.Sequential(
            nn.Conv3d(16, 32, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
        )
        self.ConvBlock3 = nn.Sequential(
            nn.Conv3d(32, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )

        self.ConvBlock4 = nn.Sequential(
            nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvBlock5 = nn.Sequential(
            nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvBlock6 = nn.Sequential(
            nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvBlock7 = nn.Sequential(
            nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvBlock8 = nn.Sequential(
            nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvBlock9 = nn.Sequential(
            nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )

        self.upsample = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[4, 1, 1],
                               stride=[2, 1, 1], padding=[1, 0, 0]),  # [1, 128, 32]
            nn.BatchNorm3d(64),
            nn.ELU(),
        )
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[4, 1, 1],
                               stride=[2, 1, 1], padding=[1, 0, 0]),  # [1, 128, 32]
            nn.BatchNorm3d(64),
            nn.ELU(),
        )

        self.ConvBlock10 = nn.Conv3d(64, 1, [1, 1, 1], stride=1, padding=0)

        self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
        self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2)

        # self.poolspa = nn.AdaptiveMaxPool3d((frames,1,1))    # pool only spatial space
        self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))

    def forward(self, x):  # Batch_size*[3, T, 128,128]
        x_visual = x
        [batch, channel, length, width, height] = x.shape

        x = self.ConvBlock1(x)  # x [3, T, 128,128]
        x = self.MaxpoolSpa(x)  # x [16, T, 64,64]

        x = self.ConvBlock2(x)  # x [32, T, 64,64]
        x_visual6464 = self.ConvBlock3(x)  # x [32, T, 64,64]
        # x [32, T/2, 32,32]    Temporal halve
        x = self.MaxpoolSpaTem(x_visual6464)

        x = self.ConvBlock4(x)  # x [64, T/2, 32,32]
        x_visual3232 = self.ConvBlock5(x)  # x [64, T/2, 32,32]
        x = self.MaxpoolSpaTem(x_visual3232)  # x [64, T/4, 16,16]

        x = self.ConvBlock6(x)  # x [64, T/4, 16,16]
        x_visual1616 = self.ConvBlock7(x)  # x [64, T/4, 16,16]
        x = self.MaxpoolSpa(x_visual1616)  # x [64, T/4, 8,8]

        x = self.ConvBlock8(x)  # x [64, T/4, 8, 8]
        x = self.ConvBlock9(x)  # x [64, T/4, 8, 8]
        x = self.upsample(x)    # x [64, T/2, 8, 8]
        x = self.upsample2(x)   # x [64, T, 8, 8]

        # x [64, T, 1,1]    -->  groundtruth left and right - 7
        x = self.poolspa(x)
        x = self.ConvBlock10(x)  # x [1, T, 1,1]

        rPPG = x.view(-1, length)

        return rPPG, x_visual, x_visual3232, x_visual1616
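A minimal smoke test for PhysNet (illustrative only, not part of the committed file); besides the waveform, the model also returns intermediate feature maps that the original paper used for visualisation.

import torch

model = PhysNet_padding_Encoder_Decoder_MAX(frames=128).eval()
clip = torch.randn(1, 3, 128, 128, 128)   # [B, C, T, H, W]
with torch.no_grad():
    rppg, x_vis, x_32, x_16 = model(clip)
print(rppg.shape)                         # torch.Size([1, 128])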
neural_methods/model/RhythmFormer.py
ADDED
@@ -0,0 +1,418 @@
"""
RhythmFormer: Extracting rPPG Signals Based on Hierarchical Temporal Periodic Transformer
"""
from typing import Optional
import torch
from torch import nn, Tensor, LongTensor
from torch.nn import functional as F
import math
from typing import Tuple, Union
from timm.models.layers import trunc_normal_, DropPath


"""
Adapted from here: https://github.com/rayleizhu/BiFormer
"""
import torch
from torch import Tensor, LongTensor, nn
import torch.nn.functional as F
from typing import Optional, Tuple

def _grid2seq(x: Tensor, region_size: Tuple[int], num_heads: int):
    """
    Args:
        x: BCTHW tensor
        region size: int
        num_heads: number of attention heads
    Return:
        out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
        region_t, region_h, region_w: number of regions per t/col/row
    """
    B, C, T, H, W = x.size()
    region_t, region_h, region_w = T // region_size[0], H // region_size[1], W // region_size[2]
    x = x.view(B, num_heads, C // num_heads, region_t, region_size[0], region_h, region_size[1], region_w, region_size[2])
    x = torch.einsum('bmdtohpwq->bmthwopqd', x).flatten(2, 4).flatten(-4, -2)  # (bs, nhead, nregion, reg_size, head_dim)
    return x, region_t, region_h, region_w


def _seq2grid(x: Tensor, region_t: int, region_h: int, region_w: int, region_size: Tuple[int]):
    """
    Args:
        x: (bs, nhead, nregion, reg_size^2, head_dim)
    Return:
        x: (bs, C, T, H, W)
    """
    bs, nhead, nregion, reg_size_square, head_dim = x.size()
    x = x.view(bs, nhead, region_t, region_h, region_w, region_size[0], region_size[1], region_size[2], head_dim)
    x = torch.einsum('bmthwopqd->bmdtohpwq', x).reshape(bs, nhead * head_dim,
                                                        region_t * region_size[0], region_h * region_size[1], region_w * region_size[2])
    return x


def video_regional_routing_attention_torch(
        query: Tensor, key: Tensor, value: Tensor, scale: float,
        region_graph: LongTensor, region_size: Tuple[int],
        kv_region_size: Optional[Tuple[int]] = None,
        auto_pad=False) -> Tensor:
    """
    Args:
        query, key, value: (B, C, T, H, W) tensor
        scale: the scale/temperature for dot product attention
        region_graph: (B, nhead, t_q*h_q*w_q, topk) tensor, topk <= t_k*h_k*w_k
        region_size: region/window size for queries, (rt, rh, rw)
        key_region_size: optional, if None, key_region_size=region_size
    Return:
        output: (B, C, T, H, W) tensor
        attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_size) attention matrix
    """
    kv_region_size = kv_region_size or region_size
    bs, nhead, q_nregion, topk = region_graph.size()

    # # Auto pad to deal with any input size
    # q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
    # if auto_pad:
    #     _, _, Hq, Wq = query.size()
    #     q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
    #     q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
    #     if (q_pad_b > 0 or q_pad_r > 0):
    #         query = F.pad(query, (0, q_pad_r, 0, q_pad_b))  # zero padding
    #
    #     _, _, Hk, Wk = key.size()
    #     kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
    #     kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
    #     if (kv_pad_r > 0 or kv_pad_b > 0):
    #         key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b))      # zero padding
    #         value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b))  # zero padding

    # to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
    query, q_region_t, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
    key, _, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
    value, _, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)

    # gather key and values.
    # torch.gather does not support broadcasting, hence we do it manually
    bs, nhead, kv_nregion, kv_region_size, head_dim = key.size()
    broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1).\
        expand(-1, -1, -1, -1, kv_region_size, head_dim)
    key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).
                         expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
                         index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
    value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).
                           expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
                           index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)

    # token-to-token attention
    # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_size)
    # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_size)
    attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
    attn = torch.softmax(attn, dim=-1)
    # (bs, nhead, q_nregion, reg_size, topk*kv_region_size) @ (bs, nhead, q_nregion, topk*kv_region_size, head_dim)
    # -> (bs, nhead, q_nregion, reg_size, head_dim)
    output = attn @ value_g.flatten(-3, -2)

    # to BCTHW format
    output = _seq2grid(output, region_t=q_region_t, region_h=q_region_h, region_w=q_region_w, region_size=region_size)

    # remove paddings if needed
    # if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
    #     output = output[:, :, :Hq, :Wq]

    return output, attn
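A toy illustration (not part of the committed file) of the region-routing step that feeds the function above: pooled query/key descriptors form a region-to-region affinity matrix, and torch.topk keeps only the k most relevant key/value regions for each query region before token-level attention is computed.

import torch

q_r = torch.randn(1, 6, 32)        # (batch, n_query_regions, channels)
k_r = torch.randn(1, 32, 6)        # (batch, channels, n_key_regions)
a_r = q_r @ k_r                    # (1, 6, 6) region-to-region affinity
_, idx_r = torch.topk(a_r, k=2, dim=-1)
print(idx_r.shape)                 # torch.Size([1, 6, 2]): two routed regions per query region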
class CDC_T(nn.Module):
    """
    The CDC_T Module is from here: https://github.com/ZitongYu/PhysFormer/model/transformer_layer.py
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.6):

        super(CDC_T, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):
        out_normal = self.conv(x)

        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal
        else:
            # pdb.set_trace()
            [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape

            # only CD works on temporal kernel size>1
            if self.conv.weight.shape[2] > 1:
                kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum(
                    2).sum(2)
                kernel_diff = kernel_diff[:, :, None, None, None]
                out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
                                    padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
                return out_normal - self.theta * out_diff

            else:
                return out_normal

class video_BRA(nn.Module):

    def __init__(self, dim, num_heads=8, t_patch=8, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False, attn_backend='torch'):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
        self.head_dim = self.dim // self.num_heads
        self.scale = qk_scale or self.dim ** -0.5
        self.topk = topk
        self.t_patch = t_patch  # frame of patch
        ################ side_dwconv (i.e. LCE in Shunted Transformer) ###########
        self.lepe = nn.Conv3d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)
        ##########################################
        self.qkv_linear = nn.Conv3d(self.dim, 3 * self.dim, kernel_size=1)
        self.output_linear = nn.Conv3d(self.dim, self.dim, kernel_size=1)
        self.proj_q = nn.Sequential(
            CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2),
            nn.BatchNorm3d(dim),
        )
        self.proj_k = nn.Sequential(
            CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2),
            nn.BatchNorm3d(dim),
        )
        self.proj_v = nn.Sequential(
            nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False),
        )
        if attn_backend == 'torch':
            self.attn_fn = video_regional_routing_attention_torch
        else:
            raise ValueError('CUDA implementation is not available yet. Please stay tuned.')

    def forward(self, x: Tensor):

        N, C, T, H, W = x.size()
        t_region = max(4 // self.t_patch, 1)
        region_size = (t_region, H // 4, W // 4)

        # STEP 1: linear projection
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)

        # STEP 2: pre attention
        q_r = F.avg_pool3d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
        k_r = F.avg_pool3d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)  # ncthw
        q_r: Tensor = q_r.permute(0, 2, 3, 4, 1).flatten(1, 3)  # n(thw)c
        k_r: Tensor = k_r.flatten(2, 4)  # nc(thw)
        a_r = q_r @ k_r  # n(thw)(thw)
        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)  # n(thw)k
        idx_r: LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)

        # STEP 3: refined attention
        output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
                                        region_graph=idx_r, region_size=region_size)

        output = output + self.lepe(v)  # nctHW
        output = self.output_linear(output)  # nctHW

        return output

class video_BiFormerBlock(nn.Module):
    def __init__(self, dim, drop_path=0., num_heads=4, t_patch=1, qk_scale=None, topk=4, mlp_ratio=2, side_dwconv=5):
        super().__init__()
        self.t_patch = t_patch
        self.norm1 = nn.BatchNorm3d(dim)
        self.attn = video_BRA(dim=dim, num_heads=num_heads, t_patch=t_patch, qk_scale=qk_scale, topk=topk, side_dwconv=side_dwconv)
        self.norm2 = nn.BatchNorm3d(dim)
        self.mlp = nn.Sequential(nn.Conv3d(dim, int(mlp_ratio * dim), kernel_size=1),
                                 nn.BatchNorm3d(int(mlp_ratio * dim)),
                                 nn.GELU(),
                                 nn.Conv3d(int(mlp_ratio * dim), int(mlp_ratio * dim), 3, stride=1, padding=1),
                                 nn.BatchNorm3d(int(mlp_ratio * dim)),
                                 nn.GELU(),
                                 nn.Conv3d(int(mlp_ratio * dim), dim, kernel_size=1),
                                 nn.BatchNorm3d(dim),
                                 )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class Fusion_Stem(nn.Module):
    def __init__(self, apha=0.5, belta=0.5):
        super(Fusion_Stem, self).__init__()

        self.stem11 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
                                    nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                    nn.ReLU(inplace=True),
                                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
                                    )

        self.stem12 = nn.Sequential(nn.Conv2d(12, 64, kernel_size=5, stride=2, padding=2),
                                    nn.BatchNorm2d(64),
                                    nn.ReLU(inplace=True),
                                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
                                    )

        self.stem21 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.stem22 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.apha = apha
        self.belta = belta

    def forward(self, x):
        """Definition of Fusion_Stem.
        Args:
          x [N,D,C,H,W]
        Returns:
          fusion_x [N*D,C,H/4,W/4]
        """
        N, D, C, H, W = x.shape
        x1 = torch.cat([x[:, :1, :, :, :], x[:, :1, :, :, :], x[:, :D - 2, :, :, :]], 1)
        x2 = torch.cat([x[:, :1, :, :, :], x[:, :D - 1, :, :, :]], 1)
        x3 = x
        x4 = torch.cat([x[:, 1:, :, :, :], x[:, D - 1:, :, :, :]], 1)
        x5 = torch.cat([x[:, 2:, :, :, :], x[:, D - 1:, :, :, :], x[:, D - 1:, :, :, :]], 1)
        x_diff = self.stem12(torch.cat([x2 - x1, x3 - x2, x4 - x3, x5 - x4], 2).view(N * D, 12, H, W))
        x3 = x3.contiguous().view(N * D, C, H, W)
        x = self.stem11(x3)

        # fusion layer1
        x_path1 = self.apha * x + self.belta * x_diff
        x_path1 = self.stem21(x_path1)
        # fusion layer2
        x_path2 = self.stem22(x_diff)
        x = self.apha * x_path1 + self.belta * x_path2

        return x

class TPT_Block(nn.Module):
    def __init__(self, dim, depth, num_heads, t_patch, topk,
                 mlp_ratio=4., drop_path=0., side_dwconv=5):
        super().__init__()
        self.dim = dim
        self.depth = depth
        ############ downsample layers & upsample layers #####################
        self.downsample_layers = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()
        self.layer_n = int(math.log(t_patch, 2))
        for i in range(self.layer_n):
            downsample_layer = nn.Sequential(
                nn.BatchNorm3d(dim),
                nn.Conv3d(dim, dim, kernel_size=(2, 1, 1), stride=(2, 1, 1)),
            )
            self.downsample_layers.append(downsample_layer)
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=(2, 1, 1)),
                nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1, 0, 0)),
                nn.BatchNorm3d(dim),
                nn.ELU(),
            )
            self.upsample_layers.append(upsample_layer)
        ######################################################################
        self.blocks = nn.ModuleList([
            video_BiFormerBlock(
                dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                num_heads=num_heads,
                t_patch=t_patch,
                topk=topk,
                mlp_ratio=mlp_ratio,
                side_dwconv=side_dwconv,
            )
            for i in range(depth)
        ])

    def forward(self, x: torch.Tensor):
        """Definition of TPT_Block.
        Args:
          x [N,C,D,H,W]
        Returns:
          x [N,C,D,H,W]
        """
        for i in range(self.layer_n):
            x = self.downsample_layers[i](x)
        for blk in self.blocks:
            x = blk(x)
        for i in range(self.layer_n):
            x = self.upsample_layers[i](x)

        return x

class RhythmFormer(nn.Module):

    def __init__(
        self,
        name: Optional[str] = None,
        pretrained: bool = False,
        dim: int = 64, frame: int = 160,
        image_size: Optional[int] = (160, 128, 128),
        in_chans=64, head_dim=16,
        stage_n=3,
        embed_dim=[64, 64, 64], mlp_ratios=[1.5, 1.5, 1.5],
        depth=[2, 2, 2],
        t_patchs: Union[int, Tuple[int]] = (2, 4, 8),
        topks: Union[int, Tuple[int]] = (40, 40, 40),
        side_dwconv: int = 3,
        drop_path_rate=0.,
        use_checkpoint_stages=[],
    ):
        super().__init__()

        self.image_size = image_size
        self.frame = frame
        self.dim = dim
        self.stage_n = stage_n

        self.Fusion_Stem = Fusion_Stem()
        self.patch_embedding = nn.Conv3d(in_chans, embed_dim[0], kernel_size=(1, 4, 4), stride=(1, 4, 4))
        self.ConvBlockLast = nn.Conv1d(embed_dim[-1], 1, kernel_size=1, stride=1, padding=0)

        ##########################################################################
        self.stages = nn.ModuleList()
        nheads = [dim // head_dim for dim in embed_dim]
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
        for i in range(stage_n):
            stage = TPT_Block(dim=embed_dim[i],
                              depth=depth[i],
                              num_heads=nheads[i],
                              mlp_ratio=mlp_ratios[i],
                              drop_path=dp_rates[sum(depth[:i]):sum(depth[:i + 1])],
                              t_patch=t_patchs[i], topk=topks[i], side_dwconv=side_dwconv
                              )
            self.stages.append(stage)
        ##########################################################################

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        N, D, C, H, W = x.shape
        x = self.Fusion_Stem(x)  # [N*D 64 H/4 W/4]
        x = x.view(N, D, 64, H // 4, W // 4).permute(0, 2, 1, 3, 4)
        x = self.patch_embedding(x)  # [N 64 D 8 8]
        for i in range(3):
            x = self.stages[i](x)  # [N 64 D 8 8]
        features_last = torch.mean(x, 3)  # [N, 64, D, 8]
        features_last = torch.mean(features_last, 3)  # [N, 64, D]
        rPPG = self.ConvBlockLast(features_last)  # [N, 1, D]
        rPPG = rPPG.squeeze(1)
        return rPPG
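A minimal smoke test for RhythmFormer with its default configuration (illustrative only, not part of the committed file); the clip length should be divisible by the largest temporal patch size (8) so the stage-wise temporal down/upsampling round-trips cleanly.

import torch

model = RhythmFormer().eval()
clip = torch.randn(1, 160, 3, 128, 128)   # [N, D, C, H, W]
with torch.no_grad():
    rppg = model(clip)
print(rppg.shape)                         # torch.Size([1, 160])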
neural_methods/model/TS_CAN.py
ADDED
@@ -0,0 +1,269 @@
| 1 |
+
"""Temporal Shift Convolutional Attention Network (TS-CAN).
|
| 2 |
+
Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement
|
| 3 |
+
NeurIPS, 2020
|
| 4 |
+
Xin Liu, Josh Fromm, Shwetak Patel, Daniel McDuff
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Attention_mask(nn.Module):
|
| 12 |
+
def __init__(self):
|
| 13 |
+
super(Attention_mask, self).__init__()
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
xsum = torch.sum(x, dim=2, keepdim=True)
|
| 17 |
+
xsum = torch.sum(xsum, dim=3, keepdim=True)
|
| 18 |
+
xshape = tuple(x.size())
|
| 19 |
+
return x / xsum * xshape[2] * xshape[3] * 0.5
|
| 20 |
+
|
| 21 |
+
def get_config(self):
|
| 22 |
+
"""May be generated manually. """
|
| 23 |
+
config = super(Attention_mask, self).get_config()
|
| 24 |
+
return config
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TSM(nn.Module):
|
| 28 |
+
def __init__(self, n_segment=10, fold_div=3):
|
| 29 |
+
super(TSM, self).__init__()
|
| 30 |
+
self.n_segment = n_segment
|
| 31 |
+
self.fold_div = fold_div
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
nt, c, h, w = x.size()
|
| 35 |
+
n_batch = nt // self.n_segment
|
| 36 |
+
x = x.view(n_batch, self.n_segment, c, h, w)
|
| 37 |
+
fold = c // self.fold_div
|
| 38 |
+
out = torch.zeros_like(x)
|
| 39 |
+
out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
|
| 40 |
+
out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
|
| 41 |
+
out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
|
| 42 |
+
return out.view(nt, c, h, w)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TSCAN(nn.Module):
|
| 46 |
+
|
| 47 |
+
def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
|
| 48 |
+
dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36):
|
| 49 |
+
"""Definition of TS_CAN.
|
| 50 |
+
Args:
|
| 51 |
+
in_channels: the number of input channel. Default: 3
|
| 52 |
+
frame_depth: the number of frame (window size) used in temport shift. Default: 20
|
| 53 |
+
img_size: height/width of each frame. Default: 36.
|
| 54 |
+
Returns:
|
| 55 |
+
TS_CAN model.
|
| 56 |
+
"""
|
| 57 |
+
super(TSCAN, self).__init__()
|
| 58 |
+
self.in_channels = in_channels
|
| 59 |
+
self.kernel_size = kernel_size
|
| 60 |
+
self.dropout_rate1 = dropout_rate1
|
| 61 |
+
self.dropout_rate2 = dropout_rate2
|
| 62 |
+
self.pool_size = pool_size
|
| 63 |
+
self.nb_filters1 = nb_filters1
|
| 64 |
+
self.nb_filters2 = nb_filters2
|
| 65 |
+
self.nb_dense = nb_dense
|
| 66 |
+
# TSM layers
|
| 67 |
+
self.TSM_1 = TSM(n_segment=frame_depth)
|
| 68 |
+
self.TSM_2 = TSM(n_segment=frame_depth)
|
| 69 |
+
self.TSM_3 = TSM(n_segment=frame_depth)
|
| 70 |
+
self.TSM_4 = TSM(n_segment=frame_depth)
|
| 71 |
+
# Motion branch convs
|
| 72 |
+
self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
|
| 73 |
+
bias=True)
|
| 74 |
+
self.motion_conv2 = nn.Conv2d(
|
| 75 |
+
self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
|
| 76 |
+
self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
|
| 77 |
+
bias=True)
|
| 78 |
+
self.motion_conv4 = nn.Conv2d(
|
| 79 |
+
self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
|
| 80 |
+
# Apperance branch convs
|
| 81 |
+
self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
|
| 82 |
+
padding=(1, 1), bias=True)
|
| 83 |
+
self.apperance_conv2 = nn.Conv2d(
|
| 84 |
+
self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
|
| 85 |
+
self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
|
| 86 |
+
padding=(1, 1), bias=True)
|
| 87 |
+
self.apperance_conv4 = nn.Conv2d(
|
| 88 |
+
self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
|
| 89 |
+
# Attention layers
|
| 90 |
+
self.apperance_att_conv1 = nn.Conv2d(
|
| 91 |
+
self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
|
| 92 |
+
self.attn_mask_1 = Attention_mask()
|
| 93 |
+
self.apperance_att_conv2 = nn.Conv2d(
|
| 94 |
+
self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
|
| 95 |
+
self.attn_mask_2 = Attention_mask()
|
| 96 |
+
# Avg pooling
|
| 97 |
+
self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
|
| 98 |
+
self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
|
| 99 |
+
self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
|
| 100 |
+
# Dropout layers
|
| 101 |
+
self.dropout_1 = nn.Dropout(self.dropout_rate1)
|
| 102 |
+
self.dropout_2 = nn.Dropout(self.dropout_rate1)
|
| 103 |
+
self.dropout_3 = nn.Dropout(self.dropout_rate1)
|
| 104 |
+
self.dropout_4 = nn.Dropout(self.dropout_rate2)
|
| 105 |
+
# Dense layers
|
| 106 |
+
if img_size == 36:
|
| 107 |
+
self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)
|
| 108 |
+
elif img_size == 72:
|
| 109 |
+
self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)
|
| 110 |
+
elif img_size == 96:
|
| 111 |
+
self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)
|
| 112 |
+
elif img_size == 128:
|
| 113 |
+
self.final_dense_1 = nn.Linear(57600, self.nb_dense, bias=True)
|
| 114 |
+
else:
|
| 115 |
+
raise Exception('Unsupported image size')
|
| 116 |
+
self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
|
| 117 |
+
|
| 118 |
+
def forward(self, inputs, params=None):
|
| 119 |
+
diff_input = inputs[:, :3, :, :]
|
| 120 |
+
raw_input = inputs[:, 3:, :, :]
|
| 121 |
+
|
| 122 |
+
diff_input = self.TSM_1(diff_input)
|
| 123 |
+
d1 = torch.tanh(self.motion_conv1(diff_input))
|
| 124 |
+
d1 = self.TSM_2(d1)
|
| 125 |
+
d2 = torch.tanh(self.motion_conv2(d1))
|
| 126 |
+
|
| 127 |
+
r1 = torch.tanh(self.apperance_conv1(raw_input))
|
| 128 |
+
r2 = torch.tanh(self.apperance_conv2(r1))
|
| 129 |
+
|
| 130 |
+
g1 = torch.sigmoid(self.apperance_att_conv1(r2))
|
| 131 |
+
g1 = self.attn_mask_1(g1)
|
| 132 |
+
gated1 = d2 * g1
|
| 133 |
+
|
| 134 |
+
d3 = self.avg_pooling_1(gated1)
|
| 135 |
+
d4 = self.dropout_1(d3)
|
| 136 |
+
|
| 137 |
+
r3 = self.avg_pooling_2(r2)
|
| 138 |
+
r4 = self.dropout_2(r3)
|
| 139 |
+
|
| 140 |
+
        d4 = self.TSM_3(d4)
        d5 = torch.tanh(self.motion_conv3(d4))
        d5 = self.TSM_4(d5)
        d6 = torch.tanh(self.motion_conv4(d5))

        r5 = torch.tanh(self.apperance_conv3(r4))
        r6 = torch.tanh(self.apperance_conv4(r5))

        g2 = torch.sigmoid(self.apperance_att_conv2(r6))
        g2 = self.attn_mask_2(g2)
        gated2 = d6 * g2

        d7 = self.avg_pooling_3(gated2)
        d8 = self.dropout_3(d7)
        d9 = d8.view(d8.size(0), -1)
        d10 = torch.tanh(self.final_dense_1(d9))
        d11 = self.dropout_4(d10)
        out = self.final_dense_2(d11)

        return out


class MTTS_CAN(nn.Module):
    """MTTS_CAN is the multi-task (respiration) version of TS-CAN"""

    def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
                 dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20):
        super(MTTS_CAN, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout_rate1 = dropout_rate1
        self.dropout_rate2 = dropout_rate2
        self.pool_size = pool_size
        self.nb_filters1 = nb_filters1
        self.nb_filters2 = nb_filters2
        self.nb_dense = nb_dense
        # TSM layers
        self.TSM_1 = TSM(n_segment=frame_depth)
        self.TSM_2 = TSM(n_segment=frame_depth)
        self.TSM_3 = TSM(n_segment=frame_depth)
        self.TSM_4 = TSM(n_segment=frame_depth)
        # Motion branch convs
        self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1),
                                      bias=True)
        self.motion_conv2 = nn.Conv2d(
            self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1),
                                      bias=True)
        self.motion_conv4 = nn.Conv2d(
            self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Apperance branch convs
        self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
                                         padding=(1, 1), bias=True)
        self.apperance_conv2 = nn.Conv2d(
            self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
                                         padding=(1, 1), bias=True)
        self.apperance_conv4 = nn.Conv2d(
            self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Attention layers
        self.apperance_att_conv1 = nn.Conv2d(
            self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_1 = Attention_mask()
        self.apperance_att_conv2 = nn.Conv2d(
            self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_2 = Attention_mask()
        # Avg pooling
        self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
        # Dropout layers
        self.dropout_1 = nn.Dropout(self.dropout_rate1)
        self.dropout_2 = nn.Dropout(self.dropout_rate1)
        self.dropout_3 = nn.Dropout(self.dropout_rate1)
        self.dropout_4_y = nn.Dropout(self.dropout_rate2)
        self.dropout_4_r = nn.Dropout(self.dropout_rate2)

        # Dense layers
        self.final_dense_1_y = nn.Linear(16384, self.nb_dense, bias=True)
        self.final_dense_2_y = nn.Linear(self.nb_dense, 1, bias=True)
        self.final_dense_1_r = nn.Linear(16384, self.nb_dense, bias=True)
        self.final_dense_2_r = nn.Linear(self.nb_dense, 1, bias=True)

    def forward(self, inputs, params=None):
        diff_input = inputs[:, :3, :, :]
        raw_input = inputs[:, 3:, :, :]

        diff_input = self.TSM_1(diff_input)
        d1 = torch.tanh(self.motion_conv1(diff_input))
        d1 = self.TSM_2(d1)
        d2 = torch.tanh(self.motion_conv2(d1))

        r1 = torch.tanh(self.apperance_conv1(raw_input))
        r2 = torch.tanh(self.apperance_conv2(r1))

        g1 = torch.sigmoid(self.apperance_att_conv1(r2))
        g1 = self.attn_mask_1(g1)
        gated1 = d2 * g1

        d3 = self.avg_pooling_1(gated1)
        d4 = self.dropout_1(d3)

        r3 = self.avg_pooling_2(r2)
        r4 = self.dropout_2(r3)

        d4 = self.TSM_3(d4)
        d5 = torch.tanh(self.motion_conv3(d4))
        d5 = self.TSM_4(d5)
        d6 = torch.tanh(self.motion_conv4(d5))

        r5 = torch.tanh(self.apperance_conv3(r4))
        r6 = torch.tanh(self.apperance_conv4(r5))

        g2 = torch.sigmoid(self.apperance_att_conv2(r6))
        g2 = self.attn_mask_2(g2)
        gated2 = d6 * g2

        d7 = self.avg_pooling_3(gated2)
        d8 = self.dropout_3(d7)
        d9 = d8.view(d8.size(0), -1)

        d10 = torch.tanh(self.final_dense_1_y(d9))
        d11 = self.dropout_4_y(d10)
        out_y = self.final_dense_2_y(d11)

        d10 = torch.tanh(self.final_dense_1_r(d9))
        d11 = self.dropout_4_r(d10)
        out_r = self.final_dense_2_r(d11)

        return out_y, out_r
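A minimal smoke-test sketch (not part of the repository file) may help clarify the expected input layout: it assumes the usual TS-CAN convention of 72x72 crops with 3 diff-normalized plus 3 raw channels concatenated, and a (batch * frame_depth) leading axis that matches the TSM segment length.

import torch

if __name__ == "__main__":
    frame_depth = 20                                  # must match n_segment of the TSM layers
    batch, h, w = 4, 72, 72                           # assumed preprocessed crop size
    clips = torch.rand(batch * frame_depth, 6, h, w)  # [N*D, diff(3)+raw(3), H, W]

    model = MTTS_CAN(in_channels=3, frame_depth=frame_depth)
    bvp, resp = model(clips)                          # per-frame BVP and respiration regressions
    print(bvp.shape, resp.shape)                      # torch.Size([80, 1]) torch.Size([80, 1])

With 72x72 inputs, the flattened feature map after the second pooling stage is 64 x 16 x 16 = 16384, which is why the dense layers above are sized 16384 -> nb_dense.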
neural_methods/model/__init__.py ADDED
File without changes

neural_methods/model/iBVPNet.py ADDED
@@ -0,0 +1,194 @@
"""iBVPNet - 3D Convolutional Network.
Proposed along with the iBVP Dataset, see https://doi.org/10.3390/electronics13071334

Joshi, Jitesh, and Youngjun Cho. 2024. "iBVP Dataset: RGB-Thermal rPPG Dataset with High Resolution Signal Quality Labels" Electronics 13, no. 7: 1334.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock3D(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(ConvBlock3D, self).__init__()
        self.conv_block_3d = nn.Sequential(
            nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding),
            nn.Tanh(),
            nn.InstanceNorm3d(out_channel),
        )

    def forward(self, x):
        return self.conv_block_3d(x)


class DeConvBlock3D(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(DeConvBlock3D, self).__init__()
        k_t, k_s1, k_s2 = kernel_size
        s_t, s_s1, s_s2 = stride
        self.deconv_block_3d = nn.Sequential(
            nn.ConvTranspose3d(in_channel, in_channel, (k_t, 1, 1), (s_t, 1, 1), padding),
            nn.Tanh(),
            nn.InstanceNorm3d(in_channel),

            nn.Conv3d(in_channel, out_channel, (1, k_s1, k_s2), (1, s_s1, s_s2), padding),
            nn.Tanh(),
            nn.InstanceNorm3d(out_channel),
        )

    def forward(self, x):
        return self.deconv_block_3d(x)


# num_filters
nf = [8, 16, 24, 40, 64]


class encoder_block(nn.Module):
    def __init__(self, in_channel, debug=False):
        super(encoder_block, self).__init__()
        # in_channel, out_channel, kernel_size, stride, padding

        self.debug = debug
        self.spatio_temporal_encoder = nn.Sequential(
            ConvBlock3D(in_channel, nf[0], [1, 3, 3], [1, 1, 1], [0, 1, 1]),
            ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 1, 1], [1, 1, 1]),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
            ConvBlock3D(nf[1], nf[2], [1, 3, 3], [1, 1, 1], [0, 1, 1]),
            ConvBlock3D(nf[2], nf[3], [3, 3, 3], [1, 1, 1], [1, 1, 1]),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
            ConvBlock3D(nf[3], nf[4], [1, 3, 3], [1, 1, 1], [0, 1, 1]),
            ConvBlock3D(nf[4], nf[4], [3, 3, 3], [1, 1, 1], [1, 1, 1]),
        )

        self.temporal_encoder = nn.Sequential(
            ConvBlock3D(nf[4], nf[4], [11, 1, 1], [1, 1, 1], [5, 0, 0]),
            ConvBlock3D(nf[4], nf[4], [11, 3, 3], [1, 1, 1], [5, 1, 1]),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),
            ConvBlock3D(nf[4], nf[4], [11, 1, 1], [1, 1, 1], [5, 0, 0]),
            ConvBlock3D(nf[4], nf[4], [11, 3, 3], [1, 1, 1], [5, 1, 1]),
            nn.MaxPool3d((2, 2, 2), stride=(2, 1, 1)),
            ConvBlock3D(nf[4], nf[4], [7, 1, 1], [1, 1, 1], [3, 0, 0]),
            ConvBlock3D(nf[4], nf[4], [7, 3, 3], [1, 1, 1], [3, 1, 1])
        )

    def forward(self, x):
        if self.debug:
            print("Encoder")
            print("x.shape", x.shape)
        st_x = self.spatio_temporal_encoder(x)
        if self.debug:
            print("st_x.shape", st_x.shape)
        t_x = self.temporal_encoder(st_x)
        if self.debug:
            print("t_x.shape", t_x.shape)
        return t_x


class decoder_block(nn.Module):
    def __init__(self, debug=False):
        super(decoder_block, self).__init__()
        self.debug = debug
        self.decoder_block = nn.Sequential(
            DeConvBlock3D(nf[4], nf[3], [7, 3, 3], [2, 2, 2], [2, 1, 1]),
            DeConvBlock3D(nf[3], nf[2], [7, 3, 3], [2, 2, 2], [2, 1, 1])
        )

    def forward(self, x):
        if self.debug:
            print("Decoder")
            print("x.shape", x.shape)
        x = self.decoder_block(x)
        if self.debug:
            print("x.shape", x.shape)
        return x


class iBVPNet(nn.Module):
    def __init__(self, frames, in_channels=3, debug=False):
        super(iBVPNet, self).__init__()
        self.debug = debug

        self.in_channels = in_channels
        if self.in_channels == 1 or self.in_channels == 3:
            self.norm = nn.InstanceNorm3d(self.in_channels)
        elif self.in_channels == 4:
            self.rgb_norm = nn.InstanceNorm3d(3)
            self.thermal_norm = nn.InstanceNorm3d(1)
        else:
            print("Unsupported input channels")

        self.ibvpnet = nn.Sequential(
            encoder_block(in_channels, debug),
            decoder_block(debug),
            # spatial adaptive pooling
            nn.AdaptiveMaxPool3d((frames, 1, 1)),
            nn.Conv3d(nf[2], 1, [1, 1, 1], stride=1, padding=0)
        )

    def forward(self, x):  # [batch, Features=3, Temp=frames, Width=32, Height=32]
        [batch, channel, length, width, height] = x.shape

        x = torch.diff(x, dim=2)

        if self.debug:
            print("Input.shape", x.shape)

        if self.in_channels == 1:
            x = self.norm(x[:, -1:, :, :, :])
        elif self.in_channels == 3:
            x = self.norm(x[:, :3, :, :, :])
        elif self.in_channels == 4:
            rgb_x = self.rgb_norm(x[:, :3, :, :, :])
            thermal_x = self.thermal_norm(x[:, -1:, :, :, :])
            x = torch.concat([rgb_x, thermal_x], dim=1)
        else:
            try:
                print("Specified input channels:", self.in_channels)
                print("Data channels", channel)
                assert self.in_channels <= channel
            except:
                print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels")
                print("Default or specified channels:", self.in_channels)
                print("Data channels [B, C, N, W, H]", x.shape)
                print("Exiting")
                exit()

        if self.debug:
            print("Diff Normalized shape", x.shape)

        feats = self.ibvpnet(x)
        if self.debug:
            print("feats.shape", feats.shape)
        rPPG = feats.view(-1, length-1)
        return rPPG


if __name__ == "__main__":
    import torch
    from torch.utils.tensorboard import SummaryWriter

    # default `log_dir` is "runs" - we'll be more specific here
    writer = SummaryWriter('runs/iBVPNet')

    duration = 8
    fs = 25
    batch_size = 4
    frames = duration*fs
    in_channels = 1
    height = 64
    width = 64
    test_data = torch.rand(batch_size, in_channels, frames, height, width)

    net = iBVPNet(in_channels=in_channels, frames=frames, debug=True)
    # print("-"*100)
    # print(net)
    # print("-"*100)
    pred = net(test_data)

    print(pred.shape)

    writer.add_graph(net, test_data)
    writer.close()
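Illustrative post-processing sketch (not from this repo): the rPPG waveform returned by iBVPNet.forward is typically turned into a heart-rate estimate by locating the spectral peak inside a plausible pulse band; the 0.75-2.5 Hz band and fs=25 below are assumptions.

import numpy as np
from scipy.signal import periodogram

def estimate_hr_bpm(rppg, fs=25.0, lo=0.75, hi=2.5):
    """Return the dominant frequency of a 1-D rPPG signal in beats per minute."""
    freqs, power = periodogram(rppg - np.mean(rppg), fs=fs)
    band = (freqs >= lo) & (freqs <= hi)          # keep only the plausible pulse band
    return 60.0 * freqs[band][np.argmax(power[band])]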
neural_methods/trainer/BaseTrainer.py ADDED
@@ -0,0 +1,108 @@
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import os
import pickle


class BaseTrainer:
    @staticmethod
    def add_trainer_args(parser):
        """Adds arguments to parser for the training process"""
        parser.add_argument('--lr', default=None, type=float)
        parser.add_argument('--model_file_name', default=None, type=str)
        return parser

    def __init__(self):
        pass

    def train(self, data_loader):
        pass

    def valid(self, data_loader):
        pass

    def test(self):
        pass

    def save_test_outputs(self, predictions, labels, config):
        output_dir = config.TEST.OUTPUT_SAVE_DIR
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        # Filename ID to be used in any output files that get saved
        if config.TOOLBOX_MODE == 'train_and_test':
            filename_id = self.model_file_name
        elif config.TOOLBOX_MODE == 'only_test':
            model_file_root = config.INFERENCE.MODEL_PATH.split("/")[-1].split(".pth")[0]
            filename_id = model_file_root + "_" + config.TEST.DATA.DATASET
        else:
            raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!')
        output_path = os.path.join(output_dir, filename_id + '_outputs.pickle')

        data = dict()
        data['predictions'] = predictions
        data['labels'] = labels
        data['label_type'] = config.TEST.DATA.PREPROCESS.LABEL_TYPE
        data['fs'] = config.TEST.DATA.FS

        with open(output_path, 'wb') as handle:  # save out frame dict pickle file
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        print('Saving outputs to:', output_path)

    def plot_losses_and_lrs(self, train_loss, valid_loss, lrs, config):
        output_dir = os.path.join(config.LOG.PATH, config.TRAIN.DATA.EXP_DATA_NAME, 'plots')
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        # Filename ID to be used in plots that get saved
        if config.TOOLBOX_MODE == 'train_and_test':
            filename_id = self.model_file_name
        else:
            raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!')

        # Create a single plot for training and validation losses
        plt.figure(figsize=(10, 6))
        epochs = range(0, len(train_loss))  # Integer values for x-axis
        plt.plot(epochs, train_loss, label='Training Loss')
        if len(valid_loss) > 0:
            plt.plot(epochs, valid_loss, label='Validation Loss')
        else:
            print("The list of validation losses is empty. The validation loss will not be plotted!")
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'{filename_id} Losses')
        plt.legend()
        plt.xticks(epochs)

        # Set y-axis ticks with more granularity
        ax = plt.gca()
        ax.yaxis.set_major_locator(MaxNLocator(integer=False, prune='both'))

        loss_plot_filename = os.path.join(output_dir, filename_id + '_losses.pdf')
        plt.savefig(loss_plot_filename, dpi=300)
        plt.close()

        # Create a separate plot for learning rates
        plt.figure(figsize=(6, 4))
        scheduler_steps = range(0, len(lrs))
        plt.plot(scheduler_steps, lrs, label='Learning Rate')
        plt.xlabel('Scheduler Step')
        plt.ylabel('Learning Rate')
        plt.title(f'{filename_id} LR Schedule')
        plt.legend()

        # Set y-axis values in scientific notation
        ax = plt.gca()
        ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True, useOffset=False))
        ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))  # Force scientific notation

        lr_plot_filename = os.path.join(output_dir, filename_id + '_learning_rates.pdf')
        plt.savefig(lr_plot_filename, bbox_inches='tight', dpi=300)
        plt.close()

        print('Saving plots of losses and learning rates to:', output_dir)
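A small companion sketch (not part of BaseTrainer): the pickle written by save_test_outputs can be read back using the same keys it stores.

import pickle

def load_test_outputs(output_path):
    """Load the predictions/labels dict saved by BaseTrainer.save_test_outputs."""
    with open(output_path, 'rb') as handle:
        data = pickle.load(handle)
    return data['predictions'], data['labels'], data['label_type'], data['fs']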
neural_methods/trainer/BigSmallTrainer.py ADDED
@@ -0,0 +1,484 @@
"""Trainer for BigSmall Multitask Models"""

# Training / Eval Imports
import torch
import torch.optim as optim
from neural_methods.trainer.BaseTrainer import BaseTrainer
from neural_methods import loss
from neural_methods.model.BigSmall import BigSmall
from evaluation.bigsmall_multitask_metrics import (calculate_bvp_metrics,
                                                   calculate_resp_metrics,
                                                   calculate_bp4d_au_metrics)

# Other Imports
from collections import OrderedDict
import numpy as np
import os
from tqdm import tqdm


class BigSmallTrainer(BaseTrainer):

    def define_model(self, config):
        # BigSmall Model
        model = BigSmall(n_segment=3)

        if self.using_TSM:
            self.frame_depth = config.MODEL.BIGSMALL.FRAME_DEPTH
            self.base_len = self.num_of_gpu * self.frame_depth

        return model

    def format_data_shape(self, data, labels):
        # reshape big data
        data_big = data[0]
        N, D, C, H, W = data_big.shape
        data_big = data_big.view(N * D, C, H, W)

        # reshape small data
        data_small = data[1]
        N, D, C, H, W = data_small.shape
        data_small = data_small.view(N * D, C, H, W)

        # reshape labels
        if len(labels.shape) != 3:  # this training format requires labels that are of shape N_label, D_label, C_label
            labels = torch.unsqueeze(labels, dim=-1)
        N_label, D_label, C_label = labels.shape
        labels = labels.view(N_label * D_label, C_label)

        # If using temporal shift module
        if self.using_TSM:
            data_big = data_big[:(N * D) // self.base_len * self.base_len]
            data_small = data_small[:(N * D) // self.base_len * self.base_len]
            labels = labels[:(N * D) // self.base_len * self.base_len]

        data[0] = data_big
        data[1] = data_small
        labels = torch.unsqueeze(labels, dim=-1)

        return data, labels

    def send_data_to_device(self, data, labels):
        big_data = data[0].to(self.device)
        small_data = data[1].to(self.device)
        labels = labels.to(self.device)
        data = (big_data, small_data)
        return data, labels

    def get_label_idxs(self, label_list, used_labels):
        label_idxs = []
        for l in used_labels:
            idx = label_list.index(l)
            label_idxs.append(idx)
        return label_idxs

    def remove_data_parallel(self, old_state_dict):
        new_state_dict = OrderedDict()
        for k, v in old_state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        return new_state_dict

    def save_model(self, index):
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
        torch.save(self.model.state_dict(), model_path)
        print('Saved Model Path: ', model_path)
        print('')

    def __init__(self, config, data_loader):
        print('')
        print('Init BigSmall Multitask Trainer\n\n')

        self.config = config  # save config file

        # Set up GPU/CPU compute device
        if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0:
            self.device = torch.device(config.DEVICE)  # set device to primary GPU
            self.num_of_gpu = config.NUM_OF_GPU_TRAIN  # set number of used GPUs
        else:
            self.device = "cpu"  # if no GPUs set device is CPU
            self.num_of_gpu = 0  # no GPUs used

        # Defining model
        self.using_TSM = True
        self.model = self.define_model(config)  # define the model

        if torch.cuda.device_count() > 1 and config.NUM_OF_GPU_TRAIN > 1:  # distribute model across GPUs
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))  # data parallel model

        self.model = self.model.to(self.device)  # send model to primary GPU

        # Training parameters
        self.batch_size = config.TRAIN.BATCH_SIZE
        self.max_epoch_num = config.TRAIN.EPOCHS
        self.LR = config.TRAIN.LR

        # Set Loss and Optimizer
        AU_weights = torch.as_tensor([9.64, 11.74, 16.77, 1.05, 0.53, 0.56,
                                      0.75, 0.69, 8.51, 6.94, 5.03, 25.00]).to(self.device)

        self.criterionAU = torch.nn.BCEWithLogitsLoss(pos_weight=AU_weights).to(self.device)
        self.criterionBVP = torch.nn.MSELoss().to(self.device)
        self.criterionRESP = torch.nn.MSELoss().to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0)

        # self.scaler = torch.cuda.amp.GradScaler() # Loss scalar

        # Model info (model dir, chunk len, best epoch, etc.)
        self.model_dir = config.MODEL.MODEL_DIR
        self.model_file_name = config.TRAIN.MODEL_FILE_NAME
        self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH

        # Epoch To Use For Test
        self.used_epoch = 0

        # Indices corresponding to used labels
        label_list = ['bp_wave', 'HR_bpm', 'systolic_bp', 'diastolic_bp', 'mean_bp',
                      'resp_wave', 'resp_bpm', 'eda',
                      'AU01', 'AU02', 'AU04', 'AU05', 'AU06', 'AU06int', 'AU07', 'AU09', 'AU10', 'AU10int',
                      'AU11', 'AU12', 'AU12int', 'AU13', 'AU14', 'AU14int', 'AU15', 'AU16', 'AU17', 'AU17int',
                      'AU18', 'AU19', 'AU20', 'AU22', 'AU23', 'AU24', 'AU27', 'AU28', 'AU29', 'AU30', 'AU31',
                      'AU32', 'AU33', 'AU34', 'AU35', 'AU36', 'AU37', 'AU38', 'AU39',
                      'pos_bvp', 'pos_env_norm_bvp']

        used_labels = ['bp_wave', 'AU01', 'AU02', 'AU04', 'AU06', 'AU07', 'AU10', 'AU12',
                       'AU14', 'AU15', 'AU17', 'AU23', 'AU24',
                       'pos_env_norm_bvp', 'resp_wave']

        # Get indices for labels from npy array
        au_label_list = [label for label in used_labels if 'AU' in label]
        bvp_label_list_train = [label for label in used_labels if 'bvp' in label]
        bvp_label_list_test = [label for label in used_labels if 'bp_wave' in label]
        resp_label_list = [label for label in used_labels if 'resp' in label]

        self.label_idx_train_au = self.get_label_idxs(label_list, au_label_list)
        self.label_idx_valid_au = self.get_label_idxs(label_list, au_label_list)
        self.label_idx_test_au = self.get_label_idxs(label_list, au_label_list)

        self.label_idx_train_bvp = self.get_label_idxs(label_list, bvp_label_list_train)
        self.label_idx_valid_bvp = self.get_label_idxs(label_list, bvp_label_list_train)
        self.label_idx_test_bvp = self.get_label_idxs(label_list, bvp_label_list_test)

        self.label_idx_train_resp = self.get_label_idxs(label_list, resp_label_list)
        self.label_idx_valid_resp = self.get_label_idxs(label_list, resp_label_list)
        self.label_idx_test_resp = self.get_label_idxs(label_list, resp_label_list)

    def train(self, data_loader):
        """Model Training"""
        if data_loader["train"] is None:
            raise ValueError("No data for train")

        print('Starting Training Routine')
        print('')

        # Init min validation loss as infinity
        min_valid_loss = np.inf  # minimum validation loss

        # ARRAYS TO SAVE (LOSS ARRAYS)
        train_loss_dict = dict()
        train_au_loss_dict = dict()
        train_bvp_loss_dict = dict()
        train_resp_loss_dict = dict()

        val_loss_dict = dict()
        val_au_loss_dict = dict()
        val_bvp_loss_dict = dict()
        val_resp_loss_dict = dict()

        # TODO: Expand tracking and subsequent plotting of these losses for BigSmall
        mean_training_losses = []
        mean_valid_losses = []
        lrs = []

        # ITERATE THROUGH EPOCHS
        for epoch in range(self.max_epoch_num):
            print(f"====Training Epoch: {epoch}====")

            # INIT PARAMS FOR TRAINING
            running_loss = 0.0  # tracks avg loss over mini batches of 100
            train_loss = []
            train_au_loss = []
            train_bvp_loss = []
            train_resp_loss = []
            self.model.train()  # put model in train mode

            # MODEL TRAINING
            tbar = tqdm(data_loader["train"], ncols=80)
            for idx, batch in enumerate(tbar):
                tbar.set_description("Train epoch %s" % epoch)

                # GATHER AND FORMAT BATCH DATA
                data, labels = batch[0], batch[1]
                data, labels = self.format_data_shape(data, labels)
                data, labels = self.send_data_to_device(data, labels)

                # FORWARD AND BACK PROPAGATE THROUGH MODEL
                self.optimizer.zero_grad()
                au_out, bvp_out, resp_out = self.model(data)
                au_loss = self.criterionAU(au_out, labels[:, self.label_idx_train_au, 0])  # au loss
                bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_train_bvp, 0])  # bvp loss
                resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_train_resp, 0])  # resp loss
                loss = au_loss + bvp_loss + resp_loss  # sum losses
                loss.backward()

                # Append the current learning rate to the list
                lrs.append(self.scheduler.get_last_lr())

                self.optimizer.step()
                # self.scaler.scale(loss).backward() # Loss scaling
                # self.scaler.step(self.optimizer)
                # self.scaler.update()

                # UPDATE RUNNING LOSS AND PRINTED TERMINAL OUTPUT AND SAVED LOSSES
                train_loss.append(loss.item())
                train_au_loss.append(au_loss.item())
                train_bvp_loss.append(bvp_loss.item())
                train_resp_loss.append(resp_loss.item())

                running_loss += loss.item()
                if idx % 100 == 99:  # print every 100 mini-batches
                    print(f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0

                tbar.set_postfix({"loss:": loss.item(), "lr:": self.optimizer.param_groups[0]["lr"]})

            # APPEND EPOCH LOSS LIST TO TRAINING LOSS DICTIONARY
            train_loss_dict[epoch] = train_loss
            train_au_loss_dict[epoch] = train_au_loss
            train_bvp_loss_dict[epoch] = train_bvp_loss
            train_resp_loss_dict[epoch] = train_resp_loss

            print('')

            # Append the mean training loss for the epoch
            mean_training_losses.append(np.mean(train_loss))

            # SAVE MODEL FOR THIS EPOCH
            self.save_model(epoch)

            # VALIDATION (IF ENABLED)
            if not self.config.TEST.USE_LAST_EPOCH:
                # Get validation losses
                valid_loss, valid_au_loss, valid_bvp_loss, valid_resp_loss = self.valid(data_loader)
                mean_valid_losses.append(valid_loss)
                val_loss_dict[epoch] = valid_loss
                val_au_loss_dict[epoch] = valid_au_loss
                val_bvp_loss_dict[epoch] = valid_bvp_loss
                val_resp_loss_dict[epoch] = valid_resp_loss
                print('validation loss: ', valid_loss)

                # Update used model
                if self.model_to_use == 'best_epoch' and (valid_loss < min_valid_loss):
                    min_valid_loss = valid_loss
                    self.used_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.used_epoch))
                elif self.model_to_use == 'last_epoch':
                    self.used_epoch = epoch

            # VALIDATION (NOT ENABLED)
            else:
                self.used_epoch = epoch

            print('')

        if self.config.TRAIN.PLOT_LOSSES_AND_LR:
            self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)

        # PRINT MODEL TO BE USED FOR TESTING
        print("Used model trained epoch:{}, val_loss:{}".format(self.used_epoch, min_valid_loss))
        print('')

    def valid(self, data_loader):
        """ Model evaluation on the validation dataset."""
        if data_loader["valid"] is None:
            raise ValueError("No data for valid")

        print("===Validating===")

        # INIT PARAMS FOR VALIDATION
        valid_loss = []
        valid_au_loss = []
        valid_bvp_loss = []
        valid_resp_loss = []
        self.model.eval()

        # MODEL VALIDATION
        with torch.no_grad():
            vbar = tqdm(data_loader["valid"], ncols=80)
            for valid_idx, valid_batch in enumerate(vbar):
                vbar.set_description("Validation")

                # GATHER AND FORMAT BATCH DATA
                data, labels = valid_batch[0], valid_batch[1]
                data, labels = self.format_data_shape(data, labels)
                data, labels = self.send_data_to_device(data, labels)

                au_out, bvp_out, resp_out = self.model(data)
                au_loss = self.criterionAU(au_out, labels[:, self.label_idx_valid_au, 0])  # au loss
                bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_valid_bvp, 0])  # bvp loss
                resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_valid_resp, 0])  # resp loss
                loss = au_loss + bvp_loss + resp_loss  # sum losses

                # APPEND VAL LOSS
                valid_loss.append(loss.item())
                valid_au_loss.append(au_loss.item())
                valid_bvp_loss.append(bvp_loss.item())
                valid_resp_loss.append(resp_loss.item())
                vbar.set_postfix(loss=loss.item())

        valid_loss = np.asarray(valid_loss)
        valid_au_loss = np.asarray(valid_au_loss)
        valid_bvp_loss = np.asarray(valid_bvp_loss)
        valid_resp_loss = np.asarray(valid_resp_loss)
        return np.mean(valid_loss), np.mean(valid_au_loss), np.mean(valid_bvp_loss), np.mean(valid_resp_loss)

    def test(self, data_loader):
        """ Model evaluation on the testing dataset."""
        print("===Testing===")
        print('')

        # SETUP
        if data_loader["test"] is None:
            raise ValueError("No data for test")

        # Change chunk length to be test chunk length
        self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH

        # ARRAYS TO SAVE (PREDICTIONS AND METRICS ARRAYS)
        preds_dict_au = dict()
        labels_dict_au = dict()
        preds_dict_bvp = dict()
        labels_dict_bvp = dict()
        preds_dict_resp = dict()
        labels_dict_resp = dict()

        # IF ONLY_TEST MODE LOAD PRETRAINED MODEL
        if self.config.TOOLBOX_MODE == "only_test":
            model_path = self.config.INFERENCE.MODEL_PATH
            print("Testing uses pretrained model!")
            print('Model path:', model_path)
            if not os.path.exists(model_path):
                raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")

        # IF USING MODEL FROM TRAINING
        else:
            model_path = os.path.join(self.model_dir,
                                      self.model_file_name + '_Epoch' + str(self.used_epoch) + '.pth')
            print("Testing uses non-pretrained model!")
            print('Model path:', model_path)
            if not os.path.exists(model_path):
                raise ValueError("Something went wrong... cant find trained model...")
        print('')

        # LOAD ABOVE SPECIFIED MODEL FOR TESTING
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        self.model = self.model.to(self.device)
        self.model.eval()

        # MODEL TESTING
        print("Running model evaluation on the testing dataset!")
        with torch.no_grad():
            for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):

                # PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA
                batch_size = test_batch[1].shape[0]  # get batch size

                # GATHER AND FORMAT BATCH DATA
                data, labels = test_batch[0], test_batch[1]
                data, labels = self.format_data_shape(data, labels)
                data, labels = self.send_data_to_device(data, labels)

                # Weird dataloader bug is causing the final training batch to be of size 0...
                if labels.shape[0] == 0:
                    continue

                # GET MODEL PREDICTIONS
                au_out, bvp_out, resp_out = self.model(data)
                au_out = torch.sigmoid(au_out)

                # GATHER AND SLICE LABELS USED FOR TEST DATASET
                TEST_AU = False
                if len(self.label_idx_test_au) > 0:  # if test dataset has AU
                    TEST_AU = True
                    labels_au = labels[:, self.label_idx_test_au]
                else:  # if not, set whole AU labels array to -1
                    labels_au = np.ones((batch_size, len(self.label_idx_train_au)))
                    labels_au = -1 * labels_au
                    # labels_au = torch.from_numpy(labels_au)

                TEST_BVP = False
                if len(self.label_idx_test_bvp) > 0:  # if test dataset has BVP
                    TEST_BVP = True
                    labels_bvp = labels[:, self.label_idx_test_bvp]
                else:  # if not, set whole BVP labels array to -1
                    labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp)))
                    labels_bvp = -1 * labels_bvp
                    # labels_bvp = torch.from_numpy(labels_bvp)

                TEST_RESP = False
                if len(self.label_idx_test_resp) > 0:  # if test dataset has RESP
                    TEST_RESP = True
                    labels_resp = labels[:, self.label_idx_test_resp]
                else:  # if not, set whole RESP labels array to -1
                    labels_resp = np.ones((batch_size, len(self.label_idx_train_resp)))
                    labels_resp = -1 * labels_resp
                    # labels_resp = torch.from_numpy(labels_resp)

                # ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY
                for idx in range(batch_size):

                    # if the labels are cut off due to TSM data formatting
                    if idx * self.chunk_len >= labels.shape[0] and self.using_TSM:
                        continue

                    subj_index = test_batch[2][idx]
                    sort_index = int(test_batch[3][idx])

                    # add subject to prediction / label arrays
                    if subj_index not in preds_dict_bvp.keys():
                        preds_dict_au[subj_index] = dict()
                        labels_dict_au[subj_index] = dict()
                        preds_dict_bvp[subj_index] = dict()
                        labels_dict_bvp[subj_index] = dict()
                        preds_dict_resp[subj_index] = dict()
                        labels_dict_resp[subj_index] = dict()

                    # append predictions and labels to subject dict
                    preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
                    labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len]
                    preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
                    labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
                    preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
                    labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len]

        # Calculate Eval Metrics
        bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config)
        resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config)
        au_metric_dict = calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config)
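Worked example (illustrative only) of the TSM trimming performed in format_data_shape above: the flattened N*D frames are cut to the largest multiple of base_len = num_of_gpu * frame_depth so every temporal-shift segment stays complete.

N, D = 4, 20                                 # 4 clips of 20 frames -> 80 frames after flattening
num_of_gpu, frame_depth = 1, 3               # hypothetical single-GPU setup with frame depth 3
base_len = num_of_gpu * frame_depth          # 3
kept = (N * D) // base_len * base_len        # 78 -> the last 2 frames are dropped
print(base_len, kept)                        # 3 78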
neural_methods/trainer/BigSmallTrainer.py.backup ADDED
@@ -0,0 +1,484 @@
|
| 390 |
+
print("Testing uses non-pretrained model!")
|
| 391 |
+
print('Model path:', model_path)
|
| 392 |
+
if not os.path.exists(model_path):
|
| 393 |
+
raise ValueError("Something went wrong... cant find trained model...")
|
| 394 |
+
print('')
|
| 395 |
+
|
| 396 |
+
# LOAD ABOVE-SPECIFIED MODEL FOR TESTING
|
| 397 |
+
self.model.load_state_dict(torch.load(model_path))
|
| 398 |
+
self.model = self.model.to(self.device)
|
| 399 |
+
self.model.eval()
|
| 400 |
+
|
| 401 |
+
# MODEL TESTING
|
| 402 |
+
print("Running model evaluation on the testing dataset!")
|
| 403 |
+
with torch.no_grad():
|
| 404 |
+
for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
|
| 405 |
+
|
| 406 |
+
# PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA
|
| 407 |
+
batch_size = test_batch[1].shape[0] # get batch size
|
| 408 |
+
|
| 409 |
+
# GATHER AND FORMAT BATCH DATA
|
| 410 |
+
data, labels = test_batch[0], test_batch[1]
|
| 411 |
+
data, labels = self.format_data_shape(data, labels)
|
| 412 |
+
data, labels = self.send_data_to_device(data, labels)
|
| 413 |
+
|
| 414 |
+
# Weird dataloader bug is causing the final training batch to be of size 0...
|
| 415 |
+
if labels.shape[0] == 0:
|
| 416 |
+
continue
|
| 417 |
+
|
| 418 |
+
# GET MODEL PREDICTIONS
|
| 419 |
+
au_out, bvp_out, resp_out = self.model(data)
|
| 420 |
+
au_out = torch.sigmoid(au_out)
|
| 421 |
+
|
| 422 |
+
# GATHER AND SLICE LABELS USED FOR TEST DATASET
|
| 423 |
+
TEST_AU = False
|
| 424 |
+
if len(self.label_idx_test_au) > 0: # if test dataset has AU
|
| 425 |
+
TEST_AU = True
|
| 426 |
+
labels_au = labels[:, self.label_idx_test_au]
|
| 427 |
+
else: # if not, set whole AU labels array to -1
|
| 428 |
+
labels_au = np.ones((batch_size, len(self.label_idx_train_au)))
|
| 429 |
+
labels_au = -1 * labels_au
|
| 430 |
+
# labels_au = torch.from_numpy(labels_au)
|
| 431 |
+
|
| 432 |
+
TEST_BVP = False
|
| 433 |
+
if len(self.label_idx_test_bvp) > 0: # if test dataset has BVP
|
| 434 |
+
TEST_BVP = True
|
| 435 |
+
labels_bvp = labels[:, self.label_idx_test_bvp]
|
| 436 |
+
else: # if not, set whole BVP labels array to -1
|
| 437 |
+
labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp)))
|
| 438 |
+
labels_bvp = -1 * labels_bvp
|
| 439 |
+
# labels_bvp = torch.from_numpy(labels_bvp)
|
| 440 |
+
|
| 441 |
+
TEST_RESP = False
|
| 442 |
+
if len(self.label_idx_test_resp) > 0: # if test dataset has RESP
|
| 443 |
+
TEST_RESP = True
|
| 444 |
+
labels_resp = labels[:, self.label_idx_test_resp]
|
| 445 |
+
else: # if not, set whole RESP labels array to -1
|
| 446 |
+
labels_resp = np.ones((batch_size, len(self.label_idx_train_resp)))
|
| 447 |
+
labels_resp = -1 * labels_resp
|
| 448 |
+
# labels_resp = torch.from_numpy(labels_resp)
|
| 449 |
+
|
| 450 |
+
# ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY
|
| 451 |
+
for idx in range(batch_size):
|
| 452 |
+
|
| 453 |
+
# if the labels are cut off due to TSM data formatting
|
| 454 |
+
if idx * self.chunk_len >= labels.shape[0] and self.using_TSM:
|
| 455 |
+
continue
|
| 456 |
+
|
| 457 |
+
subj_index = test_batch[2][idx]
|
| 458 |
+
sort_index = int(test_batch[3][idx])
|
| 459 |
+
|
| 460 |
+
# add subject to prediction / label arrays
|
| 461 |
+
if subj_index not in preds_dict_bvp.keys():
|
| 462 |
+
preds_dict_au[subj_index] = dict()
|
| 463 |
+
labels_dict_au[subj_index] = dict()
|
| 464 |
+
preds_dict_bvp[subj_index] = dict()
|
| 465 |
+
labels_dict_bvp[subj_index] = dict()
|
| 466 |
+
preds_dict_resp[subj_index] = dict()
|
| 467 |
+
labels_dict_resp[subj_index] = dict()
|
| 468 |
+
|
| 469 |
+
# append predictions and labels to subject dict
|
| 470 |
+
preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 471 |
+
labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 472 |
+
preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 473 |
+
labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 474 |
+
preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 475 |
+
labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 476 |
+
|
| 477 |
+
# Calculate Eval Metrics
|
| 478 |
+
bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config)
|
| 479 |
+
resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config)
|
| 480 |
+
au_metric_dict = calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
|
neural_methods/trainer/DeepPhysTrainer.py
ADDED
|
@@ -0,0 +1,209 @@
| 1 |
+
"""Trainer for DeepPhys."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
from evaluation.metrics import calculate_metrics
|
| 11 |
+
from neural_methods.loss.NegPearsonLoss import Neg_Pearson
|
| 12 |
+
from neural_methods.model.DeepPhys import DeepPhys
|
| 13 |
+
from neural_methods.trainer.BaseTrainer import BaseTrainer
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DeepPhysTrainer(BaseTrainer):
|
| 18 |
+
|
| 19 |
+
def __init__(self, config, data_loader):
|
| 20 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.device = torch.device(config.DEVICE)
|
| 23 |
+
self.max_epoch_num = config.TRAIN.EPOCHS
|
| 24 |
+
self.model_dir = config.MODEL.MODEL_DIR
|
| 25 |
+
self.model_file_name = config.TRAIN.MODEL_FILE_NAME
|
| 26 |
+
self.batch_size = config.TRAIN.BATCH_SIZE
|
| 27 |
+
self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
|
| 28 |
+
self.config = config
|
| 29 |
+
self.min_valid_loss = None
|
| 30 |
+
self.best_epoch = 0
|
| 31 |
+
|
| 32 |
+
if config.TOOLBOX_MODE == "train_and_test":
|
| 33 |
+
self.model = DeepPhys(img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(self.device)
|
| 34 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 35 |
+
|
| 36 |
+
self.num_train_batches = len(data_loader["train"])
|
| 37 |
+
self.criterion = torch.nn.MSELoss()
|
| 38 |
+
self.optimizer = optim.AdamW(
|
| 39 |
+
self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
|
| 40 |
+
# See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
|
| 41 |
+
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 42 |
+
self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
|
| 43 |
+
elif config.TOOLBOX_MODE == "only_test":
|
| 44 |
+
self.model = DeepPhys(img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(self.device)
|
| 45 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError("DeepPhys trainer initialized in incorrect toolbox mode!")
|
| 48 |
+
|
| 49 |
+
def train(self, data_loader):
|
| 50 |
+
"""Training routine for model"""
|
| 51 |
+
if data_loader["train"] is None:
|
| 52 |
+
raise ValueError("No data for train")
|
| 53 |
+
|
| 54 |
+
mean_training_losses = []
|
| 55 |
+
mean_valid_losses = []
|
| 56 |
+
lrs = []
|
| 57 |
+
for epoch in range(self.max_epoch_num):
|
| 58 |
+
print('')
|
| 59 |
+
print(f"====Training Epoch: {epoch}====")
|
| 60 |
+
running_loss = 0.0
|
| 61 |
+
train_loss = []
|
| 62 |
+
self.model.train()
|
| 63 |
+
# Model Training
|
| 64 |
+
tbar = tqdm(data_loader["train"], ncols=80)
|
| 65 |
+
for idx, batch in enumerate(tbar):
|
| 66 |
+
tbar.set_description("Train epoch %s" % epoch)
|
| 67 |
+
data, labels = batch[0].to(
|
| 68 |
+
self.device), batch[1].to(self.device)
|
| 69 |
+
N, D, C, H, W = data.shape
|
| 70 |
+
data = data.view(N * D, C, H, W)
|
| 71 |
+
labels = labels.view(-1, 1)
|
| 72 |
+
self.optimizer.zero_grad()
|
| 73 |
+
pred_ppg = self.model(data)
|
| 74 |
+
loss = self.criterion(pred_ppg, labels)
|
| 75 |
+
loss.backward()
|
| 76 |
+
|
| 77 |
+
# Append the current learning rate to the list
|
| 78 |
+
lrs.append(self.scheduler.get_last_lr())
|
| 79 |
+
|
| 80 |
+
self.optimizer.step()
|
| 81 |
+
self.scheduler.step()
|
| 82 |
+
running_loss += loss.item()
|
| 83 |
+
if idx % 100 == 99: # print every 100 mini-batches
|
| 84 |
+
print(
|
| 85 |
+
f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
|
| 86 |
+
running_loss = 0.0
|
| 87 |
+
train_loss.append(loss.item())
|
| 88 |
+
tbar.set_postfix({"loss": loss.item(), "lr": self.optimizer.param_groups[0]["lr"]})
|
| 89 |
+
|
| 90 |
+
# Append the mean training loss for the epoch
|
| 91 |
+
mean_training_losses.append(np.mean(train_loss))
|
| 92 |
+
|
| 93 |
+
self.save_model(epoch)
|
| 94 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 95 |
+
valid_loss = self.valid(data_loader)
|
| 96 |
+
mean_valid_losses.append(valid_loss)
|
| 97 |
+
print('validation loss: ', valid_loss)
|
| 98 |
+
if self.min_valid_loss is None:
|
| 99 |
+
self.min_valid_loss = valid_loss
|
| 100 |
+
self.best_epoch = epoch
|
| 101 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 102 |
+
elif (valid_loss < self.min_valid_loss):
|
| 103 |
+
self.min_valid_loss = valid_loss
|
| 104 |
+
self.best_epoch = epoch
|
| 105 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 106 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 107 |
+
print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
|
| 108 |
+
if self.config.TRAIN.PLOT_LOSSES_AND_LR:
|
| 109 |
+
self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
|
| 110 |
+
|
| 111 |
+
def valid(self, data_loader):
|
| 112 |
+
""" Model evaluation on the validation dataset."""
|
| 113 |
+
if data_loader["valid"] is None:
|
| 114 |
+
raise ValueError("No data for valid")
|
| 115 |
+
|
| 116 |
+
print('')
|
| 117 |
+
print("===Validating===")
|
| 118 |
+
valid_loss = []
|
| 119 |
+
self.model.eval()
|
| 120 |
+
valid_step = 0
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
vbar = tqdm(data_loader["valid"], ncols=80)
|
| 123 |
+
for valid_idx, valid_batch in enumerate(vbar):
|
| 124 |
+
vbar.set_description("Validation")
|
| 125 |
+
data_valid, labels_valid = valid_batch[0].to(
|
| 126 |
+
self.device), valid_batch[1].to(self.device)
|
| 127 |
+
N, D, C, H, W = data_valid.shape
|
| 128 |
+
data_valid = data_valid.view(N * D, C, H, W)
|
| 129 |
+
labels_valid = labels_valid.view(-1, 1)
|
| 130 |
+
pred_ppg_valid = self.model(data_valid)
|
| 131 |
+
loss = self.criterion(pred_ppg_valid, labels_valid)
|
| 132 |
+
valid_loss.append(loss.item())
|
| 133 |
+
valid_step += 1
|
| 134 |
+
vbar.set_postfix(loss=loss.item())
|
| 135 |
+
valid_loss = np.asarray(valid_loss)
|
| 136 |
+
return np.mean(valid_loss)
|
| 137 |
+
|
| 138 |
+
def test(self, data_loader):
|
| 139 |
+
""" Model evaluation on the testing dataset."""
|
| 140 |
+
if data_loader["test"] is None:
|
| 141 |
+
raise ValueError("No data for test")
|
| 142 |
+
config = self.config
|
| 143 |
+
|
| 144 |
+
print('')
|
| 145 |
+
print("===Testing===")
|
| 146 |
+
|
| 147 |
+
# Change chunk length to be test chunk length
|
| 148 |
+
self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
|
| 149 |
+
|
| 150 |
+
predictions = dict()
|
| 151 |
+
labels = dict()
|
| 152 |
+
if self.config.TOOLBOX_MODE == "only_test":
|
| 153 |
+
if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
|
| 154 |
+
raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
|
| 155 |
+
self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")))
|
| 156 |
+
print("Testing uses pretrained model!")
|
| 157 |
+
else:
|
| 158 |
+
if self.config.TEST.USE_LAST_EPOCH:
|
| 159 |
+
last_epoch_model_path = os.path.join(
|
| 160 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
|
| 161 |
+
print("Testing uses last epoch as non-pretrained model!")
|
| 162 |
+
print(last_epoch_model_path)
|
| 163 |
+
self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")))
|
| 164 |
+
else:
|
| 165 |
+
best_model_path = os.path.join(
|
| 166 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
|
| 167 |
+
print("Testing uses best epoch selected using model selection as non-pretrained model!")
|
| 168 |
+
print(best_model_path)
|
| 169 |
+
self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")))
|
| 170 |
+
|
| 171 |
+
self.model = self.model.to(self.config.DEVICE)
|
| 172 |
+
self.model.eval()
|
| 173 |
+
print("Running model evaluation on the testing dataset!")
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
|
| 176 |
+
batch_size = test_batch[0].shape[0]
|
| 177 |
+
data_test, labels_test = test_batch[0].to(
|
| 178 |
+
self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
|
| 179 |
+
N, D, C, H, W = data_test.shape
|
| 180 |
+
data_test = data_test.view(N * D, C, H, W)
|
| 181 |
+
labels_test = labels_test.view(-1, 1)
|
| 182 |
+
pred_ppg_test = self.model(data_test)
|
| 183 |
+
|
| 184 |
+
if self.config.TEST.OUTPUT_SAVE_DIR:
|
| 185 |
+
labels_test = labels_test.cpu()
|
| 186 |
+
pred_ppg_test = pred_ppg_test.cpu()
|
| 187 |
+
|
| 188 |
+
for idx in range(batch_size):
|
| 189 |
+
subj_index = test_batch[2][idx]
|
| 190 |
+
sort_index = int(test_batch[3][idx])
|
| 191 |
+
if subj_index not in predictions.keys():
|
| 192 |
+
predictions[subj_index] = dict()
|
| 193 |
+
labels[subj_index] = dict()
|
| 194 |
+
predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 195 |
+
labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 196 |
+
|
| 197 |
+
print('')
|
| 198 |
+
calculate_metrics(predictions, labels, self.config)
|
| 199 |
+
if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
|
| 200 |
+
self.save_test_outputs(predictions, labels, self.config)
|
| 201 |
+
|
| 202 |
+
def save_model(self, index):
|
| 203 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 204 |
+
if not os.path.exists(self.model_dir):
|
| 205 |
+
os.makedirs(self.model_dir)
|
| 206 |
+
model_path = os.path.join(
|
| 207 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
|
| 208 |
+
torch.save(self.model.state_dict(), model_path)
|
| 209 |
+
|
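DeepPhysTrainer pairs AdamW with a OneCycleLR schedule and steps the scheduler once per mini-batch (steps_per_epoch=len(data_loader["train"])), not once per epoch. The standalone sketch below, using a made-up model, learning rate, and batch count, isolates that pattern so the ordering of optimizer.step() and scheduler.step() is easy to verify.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=9e-3, weight_decay=0)
epochs, steps_per_epoch = 2, 5
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=9e-3, epochs=epochs, steps_per_epoch=steps_per_epoch)

for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped every batch, mirroring the trainer above
print(scheduler.get_last_lr())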
neural_methods/trainer/DeepPhysTrainer.py.backup
ADDED
|
@@ -0,0 +1,209 @@
| 1 |
+
"""Trainer for DeepPhys."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
from evaluation.metrics import calculate_metrics
|
| 11 |
+
from neural_methods.loss.NegPearsonLoss import Neg_Pearson
|
| 12 |
+
from neural_methods.model.DeepPhys import DeepPhys
|
| 13 |
+
from neural_methods.trainer.BaseTrainer import BaseTrainer
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DeepPhysTrainer(BaseTrainer):
|
| 18 |
+
|
| 19 |
+
def __init__(self, config, data_loader):
|
| 20 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.device = torch.device(config.DEVICE)
|
| 23 |
+
self.max_epoch_num = config.TRAIN.EPOCHS
|
| 24 |
+
self.model_dir = config.MODEL.MODEL_DIR
|
| 25 |
+
self.model_file_name = config.TRAIN.MODEL_FILE_NAME
|
| 26 |
+
self.batch_size = config.TRAIN.BATCH_SIZE
|
| 27 |
+
self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
|
| 28 |
+
self.config = config
|
| 29 |
+
self.min_valid_loss = None
|
| 30 |
+
self.best_epoch = 0
|
| 31 |
+
|
| 32 |
+
if config.TOOLBOX_MODE == "train_and_test":
|
| 33 |
+
self.model = DeepPhys(img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(self.device)
|
| 34 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 35 |
+
|
| 36 |
+
self.num_train_batches = len(data_loader["train"])
|
| 37 |
+
self.criterion = torch.nn.MSELoss()
|
| 38 |
+
self.optimizer = optim.AdamW(
|
| 39 |
+
self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
|
| 40 |
+
# See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
|
| 41 |
+
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 42 |
+
self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
|
| 43 |
+
elif config.TOOLBOX_MODE == "only_test":
|
| 44 |
+
self.model = DeepPhys(img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(self.device)
|
| 45 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError("DeepPhys trainer initialized in incorrect toolbox mode!")
|
| 48 |
+
|
| 49 |
+
def train(self, data_loader):
|
| 50 |
+
"""Training routine for model"""
|
| 51 |
+
if data_loader["train"] is None:
|
| 52 |
+
raise ValueError("No data for train")
|
| 53 |
+
|
| 54 |
+
mean_training_losses = []
|
| 55 |
+
mean_valid_losses = []
|
| 56 |
+
lrs = []
|
| 57 |
+
for epoch in range(self.max_epoch_num):
|
| 58 |
+
print('')
|
| 59 |
+
print(f"====Training Epoch: {epoch}====")
|
| 60 |
+
running_loss = 0.0
|
| 61 |
+
train_loss = []
|
| 62 |
+
self.model.train()
|
| 63 |
+
# Model Training
|
| 64 |
+
tbar = tqdm(data_loader["train"], ncols=80)
|
| 65 |
+
for idx, batch in enumerate(tbar):
|
| 66 |
+
tbar.set_description("Train epoch %s" % epoch)
|
| 67 |
+
data, labels = batch[0].to(
|
| 68 |
+
self.device), batch[1].to(self.device)
|
| 69 |
+
N, D, C, H, W = data.shape
|
| 70 |
+
data = data.view(N * D, C, H, W)
|
| 71 |
+
labels = labels.view(-1, 1)
|
| 72 |
+
self.optimizer.zero_grad()
|
| 73 |
+
pred_ppg = self.model(data)
|
| 74 |
+
loss = self.criterion(pred_ppg, labels)
|
| 75 |
+
loss.backward()
|
| 76 |
+
|
| 77 |
+
# Append the current learning rate to the list
|
| 78 |
+
lrs.append(self.scheduler.get_last_lr())
|
| 79 |
+
|
| 80 |
+
self.optimizer.step()
|
| 81 |
+
self.scheduler.step()
|
| 82 |
+
running_loss += loss.item()
|
| 83 |
+
if idx % 100 == 99: # print every 100 mini-batches
|
| 84 |
+
print(
|
| 85 |
+
f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
|
| 86 |
+
running_loss = 0.0
|
| 87 |
+
train_loss.append(loss.item())
|
| 88 |
+
tbar.set_postfix({"loss": loss.item(), "lr": self.optimizer.param_groups[0]["lr"]})
|
| 89 |
+
|
| 90 |
+
# Append the mean training loss for the epoch
|
| 91 |
+
mean_training_losses.append(np.mean(train_loss))
|
| 92 |
+
|
| 93 |
+
self.save_model(epoch)
|
| 94 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 95 |
+
valid_loss = self.valid(data_loader)
|
| 96 |
+
mean_valid_losses.append(valid_loss)
|
| 97 |
+
print('validation loss: ', valid_loss)
|
| 98 |
+
if self.min_valid_loss is None:
|
| 99 |
+
self.min_valid_loss = valid_loss
|
| 100 |
+
self.best_epoch = epoch
|
| 101 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 102 |
+
elif (valid_loss < self.min_valid_loss):
|
| 103 |
+
self.min_valid_loss = valid_loss
|
| 104 |
+
self.best_epoch = epoch
|
| 105 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 106 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 107 |
+
print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
|
| 108 |
+
if self.config.TRAIN.PLOT_LOSSES_AND_LR:
|
| 109 |
+
self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
|
| 110 |
+
|
| 111 |
+
def valid(self, data_loader):
|
| 112 |
+
""" Model evaluation on the validation dataset."""
|
| 113 |
+
if data_loader["valid"] is None:
|
| 114 |
+
raise ValueError("No data for valid")
|
| 115 |
+
|
| 116 |
+
print('')
|
| 117 |
+
print("===Validating===")
|
| 118 |
+
valid_loss = []
|
| 119 |
+
self.model.eval()
|
| 120 |
+
valid_step = 0
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
vbar = tqdm(data_loader["valid"], ncols=80)
|
| 123 |
+
for valid_idx, valid_batch in enumerate(vbar):
|
| 124 |
+
vbar.set_description("Validation")
|
| 125 |
+
data_valid, labels_valid = valid_batch[0].to(
|
| 126 |
+
self.device), valid_batch[1].to(self.device)
|
| 127 |
+
N, D, C, H, W = data_valid.shape
|
| 128 |
+
data_valid = data_valid.view(N * D, C, H, W)
|
| 129 |
+
labels_valid = labels_valid.view(-1, 1)
|
| 130 |
+
pred_ppg_valid = self.model(data_valid)
|
| 131 |
+
loss = self.criterion(pred_ppg_valid, labels_valid)
|
| 132 |
+
valid_loss.append(loss.item())
|
| 133 |
+
valid_step += 1
|
| 134 |
+
vbar.set_postfix(loss=loss.item())
|
| 135 |
+
valid_loss = np.asarray(valid_loss)
|
| 136 |
+
return np.mean(valid_loss)
|
| 137 |
+
|
| 138 |
+
def test(self, data_loader):
|
| 139 |
+
""" Model evaluation on the testing dataset."""
|
| 140 |
+
if data_loader["test"] is None:
|
| 141 |
+
raise ValueError("No data for test")
|
| 142 |
+
config = self.config
|
| 143 |
+
|
| 144 |
+
print('')
|
| 145 |
+
print("===Testing===")
|
| 146 |
+
|
| 147 |
+
# Change chunk length to be test chunk length
|
| 148 |
+
self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
|
| 149 |
+
|
| 150 |
+
predictions = dict()
|
| 151 |
+
labels = dict()
|
| 152 |
+
if self.config.TOOLBOX_MODE == "only_test":
|
| 153 |
+
if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
|
| 154 |
+
raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
|
| 155 |
+
self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=self.device))
|
| 156 |
+
print("Testing uses pretrained model!")
|
| 157 |
+
else:
|
| 158 |
+
if self.config.TEST.USE_LAST_EPOCH:
|
| 159 |
+
last_epoch_model_path = os.path.join(
|
| 160 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
|
| 161 |
+
print("Testing uses last epoch as non-pretrained model!")
|
| 162 |
+
print(last_epoch_model_path)
|
| 163 |
+
self.model.load_state_dict(torch.load(last_epoch_model_path))
|
| 164 |
+
else:
|
| 165 |
+
best_model_path = os.path.join(
|
| 166 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
|
| 167 |
+
print("Testing uses best epoch selected using model selection as non-pretrained model!")
|
| 168 |
+
print(best_model_path)
|
| 169 |
+
self.model.load_state_dict(torch.load(best_model_path))
|
| 170 |
+
|
| 171 |
+
self.model = self.model.to(self.config.DEVICE)
|
| 172 |
+
self.model.eval()
|
| 173 |
+
print("Running model evaluation on the testing dataset!")
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
|
| 176 |
+
batch_size = test_batch[0].shape[0]
|
| 177 |
+
data_test, labels_test = test_batch[0].to(
|
| 178 |
+
self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
|
| 179 |
+
N, D, C, H, W = data_test.shape
|
| 180 |
+
data_test = data_test.view(N * D, C, H, W)
|
| 181 |
+
labels_test = labels_test.view(-1, 1)
|
| 182 |
+
pred_ppg_test = self.model(data_test)
|
| 183 |
+
|
| 184 |
+
if self.config.TEST.OUTPUT_SAVE_DIR:
|
| 185 |
+
labels_test = labels_test.cpu()
|
| 186 |
+
pred_ppg_test = pred_ppg_test.cpu()
|
| 187 |
+
|
| 188 |
+
for idx in range(batch_size):
|
| 189 |
+
subj_index = test_batch[2][idx]
|
| 190 |
+
sort_index = int(test_batch[3][idx])
|
| 191 |
+
if subj_index not in predictions.keys():
|
| 192 |
+
predictions[subj_index] = dict()
|
| 193 |
+
labels[subj_index] = dict()
|
| 194 |
+
predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 195 |
+
labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 196 |
+
|
| 197 |
+
print('')
|
| 198 |
+
calculate_metrics(predictions, labels, self.config)
|
| 199 |
+
if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
|
| 200 |
+
self.save_test_outputs(predictions, labels, self.config)
|
| 201 |
+
|
| 202 |
+
def save_model(self, index):
|
| 203 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 204 |
+
if not os.path.exists(self.model_dir):
|
| 205 |
+
os.makedirs(self.model_dir)
|
| 206 |
+
model_path = os.path.join(
|
| 207 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
|
| 208 |
+
torch.save(self.model.state_dict(), model_path)
|
| 209 |
+
|
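One functional difference between DeepPhysTrainer.py and this backup copy is the map_location argument passed to torch.load: remapping the checkpoint to CPU lets weights trained on a GPU be loaded on a CPU-only host. A minimal sketch with a hypothetical checkpoint path:

import torch

checkpoint_path = "example_checkpoint.pth"  # hypothetical path, for illustration only
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
# model.load_state_dict(state_dict)  # model construction omitted from this sketch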
neural_methods/trainer/EfficientPhysTrainer.py
ADDED
|
@@ -0,0 +1,228 @@
| 1 |
+
"""Trainer for EfficientPhys."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
from evaluation.metrics import calculate_metrics
|
| 11 |
+
from neural_methods.loss.NegPearsonLoss import Neg_Pearson
|
| 12 |
+
from neural_methods.model.EfficientPhys import EfficientPhys
|
| 13 |
+
from neural_methods.trainer.BaseTrainer import BaseTrainer
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EfficientPhysTrainer(BaseTrainer):
|
| 18 |
+
|
| 19 |
+
def __init__(self, config, data_loader):
|
| 20 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.device = torch.device(config.DEVICE)
|
| 23 |
+
self.frame_depth = config.MODEL.EFFICIENTPHYS.FRAME_DEPTH
|
| 24 |
+
self.max_epoch_num = config.TRAIN.EPOCHS
|
| 25 |
+
self.model_dir = config.MODEL.MODEL_DIR
|
| 26 |
+
self.model_file_name = config.TRAIN.MODEL_FILE_NAME
|
| 27 |
+
self.batch_size = config.TRAIN.BATCH_SIZE
|
| 28 |
+
self.num_of_gpu = config.NUM_OF_GPU_TRAIN
|
| 29 |
+
self.base_len = self.num_of_gpu * self.frame_depth
|
| 30 |
+
self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
|
| 31 |
+
self.config = config
|
| 32 |
+
self.min_valid_loss = None
|
| 33 |
+
self.best_epoch = 0
|
| 34 |
+
|
| 35 |
+
if config.TOOLBOX_MODE == "train_and_test":
|
| 36 |
+
self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(
|
| 37 |
+
self.device)
|
| 38 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 39 |
+
|
| 40 |
+
self.num_train_batches = len(data_loader["train"])
|
| 41 |
+
self.criterion = torch.nn.MSELoss()
|
| 42 |
+
self.optimizer = optim.AdamW(
|
| 43 |
+
self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
|
| 44 |
+
# See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
|
| 45 |
+
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 46 |
+
self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
|
| 47 |
+
elif config.TOOLBOX_MODE == "only_test":
|
| 48 |
+
self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(
|
| 49 |
+
self.device)
|
| 50 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError("EfficientPhys trainer initialized in incorrect toolbox mode!")
|
| 53 |
+
|
| 54 |
+
def train(self, data_loader):
|
| 55 |
+
"""Training routine for model"""
|
| 56 |
+
if data_loader["train"] is None:
|
| 57 |
+
raise ValueError("No data for train")
|
| 58 |
+
|
| 59 |
+
mean_training_losses = []
|
| 60 |
+
mean_valid_losses = []
|
| 61 |
+
lrs = []
|
| 62 |
+
for epoch in range(self.max_epoch_num):
|
| 63 |
+
print('')
|
| 64 |
+
print(f"====Training Epoch: {epoch}====")
|
| 65 |
+
running_loss = 0.0
|
| 66 |
+
train_loss = []
|
| 67 |
+
self.model.train()
|
| 68 |
+
# Model Training
|
| 69 |
+
tbar = tqdm(data_loader["train"], ncols=80)
|
| 70 |
+
for idx, batch in enumerate(tbar):
|
| 71 |
+
tbar.set_description("Train epoch %s" % epoch)
|
| 72 |
+
data, labels = batch[0].to(
|
| 73 |
+
self.device), batch[1].to(self.device)
|
| 74 |
+
N, D, C, H, W = data.shape
|
| 75 |
+
data = data.view(N * D, C, H, W)
|
| 76 |
+
labels = labels.view(-1, 1)
|
| 77 |
+
data = data[:(N * D) // self.base_len * self.base_len]
|
| 78 |
+
# Add one more frame for EfficientPhys since it does torch.diff for the input
|
| 79 |
+
last_frame = torch.unsqueeze(data[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
|
| 80 |
+
data = torch.cat((data, last_frame), 0)
|
| 81 |
+
labels = labels[:(N * D) // self.base_len * self.base_len]
|
| 82 |
+
self.optimizer.zero_grad()
|
| 83 |
+
pred_ppg = self.model(data)
|
| 84 |
+
loss = self.criterion(pred_ppg, labels)
|
| 85 |
+
loss.backward()
|
| 86 |
+
|
| 87 |
+
# Append the current learning rate to the list
|
| 88 |
+
lrs.append(self.scheduler.get_last_lr())
|
| 89 |
+
|
| 90 |
+
self.optimizer.step()
|
| 91 |
+
self.scheduler.step()
|
| 92 |
+
running_loss += loss.item()
|
| 93 |
+
if idx % 100 == 99: # print every 100 mini-batches
|
| 94 |
+
print(
|
| 95 |
+
f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
|
| 96 |
+
running_loss = 0.0
|
| 97 |
+
train_loss.append(loss.item())
|
| 98 |
+
tbar.set_postfix(loss=loss.item())
|
| 99 |
+
|
| 100 |
+
# Append the mean training loss for the epoch
|
| 101 |
+
mean_training_losses.append(np.mean(train_loss))
|
| 102 |
+
|
| 103 |
+
self.save_model(epoch)
|
| 104 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 105 |
+
valid_loss = self.valid(data_loader)
|
| 106 |
+
mean_valid_losses.append(valid_loss)
|
| 107 |
+
print('validation loss: ', valid_loss)
|
| 108 |
+
if self.min_valid_loss is None:
|
| 109 |
+
self.min_valid_loss = valid_loss
|
| 110 |
+
self.best_epoch = epoch
|
| 111 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 112 |
+
elif (valid_loss < self.min_valid_loss):
|
| 113 |
+
self.min_valid_loss = valid_loss
|
| 114 |
+
self.best_epoch = epoch
|
| 115 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 116 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 117 |
+
print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
|
| 118 |
+
if self.config.TRAIN.PLOT_LOSSES_AND_LR:
|
| 119 |
+
self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
|
| 120 |
+
|
| 121 |
+
def valid(self, data_loader):
|
| 122 |
+
""" Model evaluation on the validation dataset."""
|
| 123 |
+
if data_loader["valid"] is None:
|
| 124 |
+
raise ValueError("No data for valid")
|
| 125 |
+
|
| 126 |
+
print('')
|
| 127 |
+
print("===Validating===")
|
| 128 |
+
valid_loss = []
|
| 129 |
+
self.model.eval()
|
| 130 |
+
valid_step = 0
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
vbar = tqdm(data_loader["valid"], ncols=80)
|
| 133 |
+
for valid_idx, valid_batch in enumerate(vbar):
|
| 134 |
+
vbar.set_description("Validation")
|
| 135 |
+
data_valid, labels_valid = valid_batch[0].to(
|
| 136 |
+
self.device), valid_batch[1].to(self.device)
|
| 137 |
+
N, D, C, H, W = data_valid.shape
|
| 138 |
+
data_valid = data_valid.view(N * D, C, H, W)
|
| 139 |
+
labels_valid = labels_valid.view(-1, 1)
|
| 140 |
+
data_valid = data_valid[:(N * D) // self.base_len * self.base_len]
|
| 141 |
+
# Add one more frame for EfficientPhys since it does torch.diff for the input
|
| 142 |
+
last_frame = torch.unsqueeze(data_valid[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
|
| 143 |
+
data_valid = torch.cat((data_valid, last_frame), 0)
|
| 144 |
+
labels_valid = labels_valid[:(N * D) // self.base_len * self.base_len]
|
| 145 |
+
pred_ppg_valid = self.model(data_valid)
|
| 146 |
+
loss = self.criterion(pred_ppg_valid, labels_valid)
|
| 147 |
+
valid_loss.append(loss.item())
|
| 148 |
+
valid_step += 1
|
| 149 |
+
vbar.set_postfix(loss=loss.item())
|
| 150 |
+
valid_loss = np.asarray(valid_loss)
|
| 151 |
+
return np.mean(valid_loss)
|
| 152 |
+
|
| 153 |
+
def test(self, data_loader):
|
| 154 |
+
""" Model evaluation on the testing dataset."""
|
| 155 |
+
if data_loader["test"] is None:
|
| 156 |
+
raise ValueError("No data for test")
|
| 157 |
+
|
| 158 |
+
print('')
|
| 159 |
+
print("===Testing===")
|
| 160 |
+
|
| 161 |
+
# Change chunk length to be test chunk length
|
| 162 |
+
self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
|
| 163 |
+
|
| 164 |
+
predictions = dict()
|
| 165 |
+
labels = dict()
|
| 166 |
+
|
| 167 |
+
if self.config.TOOLBOX_MODE == "only_test":
|
| 168 |
+
if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
|
| 169 |
+
raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
|
| 170 |
+
self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")))
|
| 171 |
+
print("Testing uses pretrained model!")
|
| 172 |
+
else:
|
| 173 |
+
if self.config.TEST.USE_LAST_EPOCH:
|
| 174 |
+
last_epoch_model_path = os.path.join(
|
| 175 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
|
| 176 |
+
print("Testing uses last epoch as non-pretrained model!")
|
| 177 |
+
print(last_epoch_model_path)
|
| 178 |
+
self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")))
|
| 179 |
+
else:
|
| 180 |
+
best_model_path = os.path.join(
|
| 181 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
|
| 182 |
+
print("Testing uses best epoch selected using model selection as non-pretrained model!")
|
| 183 |
+
print(best_model_path)
|
| 184 |
+
self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")))
|
| 185 |
+
|
| 186 |
+
self.model = self.model.to(self.config.DEVICE)
|
| 187 |
+
self.model.eval()
|
| 188 |
+
print("Running model evaluation on the testing dataset!")
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
|
| 191 |
+
batch_size = test_batch[0].shape[0]
|
| 192 |
+
data_test, labels_test = test_batch[0].to(
|
| 193 |
+
self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
|
| 194 |
+
N, D, C, H, W = data_test.shape
|
| 195 |
+
data_test = data_test.view(N * D, C, H, W)
|
| 196 |
+
labels_test = labels_test.view(-1, 1)
|
| 197 |
+
data_test = data_test[:(N * D) // self.base_len * self.base_len]
|
| 198 |
+
# Add one more frame for EfficientPhys since it does torch.diff for the input
|
| 199 |
+
last_frame = torch.unsqueeze(data_test[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
|
| 200 |
+
data_test = torch.cat((data_test, last_frame), 0)
|
| 201 |
+
labels_test = labels_test[:(N * D) // self.base_len * self.base_len]
|
| 202 |
+
pred_ppg_test = self.model(data_test)
|
| 203 |
+
|
| 204 |
+
if self.config.TEST.OUTPUT_SAVE_DIR:
|
| 205 |
+
labels_test = labels_test.cpu()
|
| 206 |
+
pred_ppg_test = pred_ppg_test.cpu()
|
| 207 |
+
|
| 208 |
+
for idx in range(batch_size):
|
| 209 |
+
subj_index = test_batch[2][idx]
|
| 210 |
+
sort_index = int(test_batch[3][idx])
|
| 211 |
+
if subj_index not in predictions.keys():
|
| 212 |
+
predictions[subj_index] = dict()
|
| 213 |
+
labels[subj_index] = dict()
|
| 214 |
+
predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 215 |
+
labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 216 |
+
|
| 217 |
+
print('')
|
| 218 |
+
calculate_metrics(predictions, labels, self.config)
|
| 219 |
+
if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
|
| 220 |
+
self.save_test_outputs(predictions, labels, self.config)
|
| 221 |
+
|
| 222 |
+
def save_model(self, index):
|
| 223 |
+
if not os.path.exists(self.model_dir):
|
| 224 |
+
os.makedirs(self.model_dir)
|
| 225 |
+
model_path = os.path.join(
|
| 226 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
|
| 227 |
+
torch.save(self.model.state_dict(), model_path)
|
| 228 |
+
print('Saved Model Path: ', model_path)
|
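EfficientPhysTrainer flattens each batch to (N*D, C, H, W), truncates it to a multiple of base_len = num_of_gpu * frame_depth, and then appends one repeated last frame per GPU because the model takes torch.diff over consecutive frames. The self-contained sketch below, with assumed shapes, reproduces just that length bookkeeping.

import torch

N, D, C, H, W = 2, 10, 3, 72, 72           # assumed batch and frame dimensions
num_of_gpu, frame_depth = 1, 10
base_len = num_of_gpu * frame_depth

data = torch.randn(N, D, C, H, W).view(N * D, C, H, W)
labels = torch.randn(N * D, 1)

usable = (N * D) // base_len * base_len     # largest multiple of base_len
data, labels = data[:usable], labels[:usable]
last_frame = data[-1:].repeat(num_of_gpu, 1, 1, 1)
data = torch.cat((data, last_frame), 0)     # one spare frame per GPU for torch.diff

print(data.shape, labels.shape)             # torch.Size([21, 3, 72, 72]) torch.Size([20, 1])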
neural_methods/trainer/EfficientPhysTrainer.py.backup
ADDED
|
@@ -0,0 +1,228 @@
| 1 |
+
"""Trainer for EfficientPhys."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
from evaluation.metrics import calculate_metrics
|
| 11 |
+
from neural_methods.loss.NegPearsonLoss import Neg_Pearson
|
| 12 |
+
from neural_methods.model.EfficientPhys import EfficientPhys
|
| 13 |
+
from neural_methods.trainer.BaseTrainer import BaseTrainer
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EfficientPhysTrainer(BaseTrainer):
|
| 18 |
+
|
| 19 |
+
def __init__(self, config, data_loader):
|
| 20 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.device = torch.device(config.DEVICE)
|
| 23 |
+
self.frame_depth = config.MODEL.EFFICIENTPHYS.FRAME_DEPTH
|
| 24 |
+
self.max_epoch_num = config.TRAIN.EPOCHS
|
| 25 |
+
self.model_dir = config.MODEL.MODEL_DIR
|
| 26 |
+
self.model_file_name = config.TRAIN.MODEL_FILE_NAME
|
| 27 |
+
self.batch_size = config.TRAIN.BATCH_SIZE
|
| 28 |
+
self.num_of_gpu = config.NUM_OF_GPU_TRAIN
|
| 29 |
+
self.base_len = self.num_of_gpu * self.frame_depth
|
| 30 |
+
self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
|
| 31 |
+
self.config = config
|
| 32 |
+
self.min_valid_loss = None
|
| 33 |
+
self.best_epoch = 0
|
| 34 |
+
|
| 35 |
+
if config.TOOLBOX_MODE == "train_and_test":
|
| 36 |
+
self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(
|
| 37 |
+
self.device)
|
| 38 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 39 |
+
|
| 40 |
+
self.num_train_batches = len(data_loader["train"])
|
| 41 |
+
self.criterion = torch.nn.MSELoss()
|
| 42 |
+
self.optimizer = optim.AdamW(
|
| 43 |
+
self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0)
|
| 44 |
+
# See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
|
| 45 |
+
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 46 |
+
self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
|
| 47 |
+
elif config.TOOLBOX_MODE == "only_test":
|
| 48 |
+
self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(
|
| 49 |
+
self.device)
|
| 50 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError("EfficientPhys trainer initialized in incorrect toolbox mode!")
|
| 53 |
+
|
| 54 |
+
def train(self, data_loader):
|
| 55 |
+
"""Training routine for model"""
|
| 56 |
+
if data_loader["train"] is None:
|
| 57 |
+
raise ValueError("No data for train")
|
| 58 |
+
|
| 59 |
+
mean_training_losses = []
|
| 60 |
+
mean_valid_losses = []
|
| 61 |
+
lrs = []
|
| 62 |
+
for epoch in range(self.max_epoch_num):
|
| 63 |
+
print('')
|
| 64 |
+
print(f"====Training Epoch: {epoch}====")
|
| 65 |
+
running_loss = 0.0
|
| 66 |
+
train_loss = []
|
| 67 |
+
self.model.train()
|
| 68 |
+
# Model Training
|
| 69 |
+
tbar = tqdm(data_loader["train"], ncols=80)
|
| 70 |
+
for idx, batch in enumerate(tbar):
|
| 71 |
+
tbar.set_description("Train epoch %s" % epoch)
|
| 72 |
+
data, labels = batch[0].to(
|
| 73 |
+
self.device), batch[1].to(self.device)
|
| 74 |
+
N, D, C, H, W = data.shape
|
| 75 |
+
data = data.view(N * D, C, H, W)
|
| 76 |
+
labels = labels.view(-1, 1)
|
| 77 |
+
data = data[:(N * D) // self.base_len * self.base_len]
|
| 78 |
+
# Add one more frame for EfficientPhys since it does torch.diff for the input
|
| 79 |
+
last_frame = torch.unsqueeze(data[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
|
| 80 |
+
data = torch.cat((data, last_frame), 0)
|
| 81 |
+
labels = labels[:(N * D) // self.base_len * self.base_len]
|
| 82 |
+
self.optimizer.zero_grad()
|
| 83 |
+
pred_ppg = self.model(data)
|
| 84 |
+
loss = self.criterion(pred_ppg, labels)
|
| 85 |
+
loss.backward()
|
| 86 |
+
|
| 87 |
+
# Append the current learning rate to the list
|
| 88 |
+
lrs.append(self.scheduler.get_last_lr())
|
| 89 |
+
|
| 90 |
+
self.optimizer.step()
|
| 91 |
+
self.scheduler.step()
|
| 92 |
+
running_loss += loss.item()
|
| 93 |
+
if idx % 100 == 99: # print every 100 mini-batches
|
| 94 |
+
print(
|
| 95 |
+
f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
|
| 96 |
+
running_loss = 0.0
|
| 97 |
+
train_loss.append(loss.item())
|
| 98 |
+
tbar.set_postfix(loss=loss.item())
|
| 99 |
+
|
| 100 |
+
# Append the mean training loss for the epoch
|
| 101 |
+
mean_training_losses.append(np.mean(train_loss))
|
| 102 |
+
|
| 103 |
+
self.save_model(epoch)
|
| 104 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 105 |
+
valid_loss = self.valid(data_loader)
|
| 106 |
+
mean_valid_losses.append(valid_loss)
|
| 107 |
+
print('validation loss: ', valid_loss)
|
| 108 |
+
if self.min_valid_loss is None:
|
| 109 |
+
self.min_valid_loss = valid_loss
|
| 110 |
+
self.best_epoch = epoch
|
| 111 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 112 |
+
elif (valid_loss < self.min_valid_loss):
|
| 113 |
+
self.min_valid_loss = valid_loss
|
| 114 |
+
self.best_epoch = epoch
|
| 115 |
+
print("Update best model! Best epoch: {}".format(self.best_epoch))
|
| 116 |
+
if not self.config.TEST.USE_LAST_EPOCH:
|
| 117 |
+
print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss))
|
| 118 |
+
if self.config.TRAIN.PLOT_LOSSES_AND_LR:
|
| 119 |
+
self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)
|
| 120 |
+
|
| 121 |
+
def valid(self, data_loader):
|
| 122 |
+
""" Model evaluation on the validation dataset."""
|
| 123 |
+
if data_loader["valid"] is None:
|
| 124 |
+
raise ValueError("No data for valid")
|
| 125 |
+
|
| 126 |
+
print('')
|
| 127 |
+
print("===Validating===")
|
| 128 |
+
valid_loss = []
|
| 129 |
+
self.model.eval()
|
| 130 |
+
valid_step = 0
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
vbar = tqdm(data_loader["valid"], ncols=80)
|
| 133 |
+
for valid_idx, valid_batch in enumerate(vbar):
|
| 134 |
+
vbar.set_description("Validation")
|
| 135 |
+
data_valid, labels_valid = valid_batch[0].to(
|
| 136 |
+
self.device), valid_batch[1].to(self.device)
|
| 137 |
+
N, D, C, H, W = data_valid.shape
|
| 138 |
+
data_valid = data_valid.view(N * D, C, H, W)
|
| 139 |
+
labels_valid = labels_valid.view(-1, 1)
|
| 140 |
+
data_valid = data_valid[:(N * D) // self.base_len * self.base_len]
|
| 141 |
+
# Add one more frame for EfficientPhys since it does torch.diff for the input
|
| 142 |
+
last_frame = torch.unsqueeze(data_valid[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
|
| 143 |
+
data_valid = torch.cat((data_valid, last_frame), 0)
|
| 144 |
+
labels_valid = labels_valid[:(N * D) // self.base_len * self.base_len]
|
| 145 |
+
pred_ppg_valid = self.model(data_valid)
|
| 146 |
+
loss = self.criterion(pred_ppg_valid, labels_valid)
|
| 147 |
+
valid_loss.append(loss.item())
|
| 148 |
+
valid_step += 1
|
| 149 |
+
vbar.set_postfix(loss=loss.item())
|
| 150 |
+
valid_loss = np.asarray(valid_loss)
|
| 151 |
+
return np.mean(valid_loss)
|
| 152 |
+
|
| 153 |
+
def test(self, data_loader):
|
| 154 |
+
""" Model evaluation on the testing dataset."""
|
| 155 |
+
if data_loader["test"] is None:
|
| 156 |
+
raise ValueError("No data for test")
|
| 157 |
+
|
| 158 |
+
print('')
|
| 159 |
+
print("===Testing===")
|
| 160 |
+
|
| 161 |
+
# Change chunk length to be test chunk length
|
| 162 |
+
self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
|
| 163 |
+
|
| 164 |
+
predictions = dict()
|
| 165 |
+
labels = dict()
|
| 166 |
+
|
| 167 |
+
if self.config.TOOLBOX_MODE == "only_test":
|
| 168 |
+
if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
|
| 169 |
+
raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
|
| 170 |
+
self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH))
|
| 171 |
+
print("Testing uses pretrained model!")
|
| 172 |
+
else:
|
| 173 |
+
if self.config.TEST.USE_LAST_EPOCH:
|
| 174 |
+
last_epoch_model_path = os.path.join(
|
| 175 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
|
| 176 |
+
print("Testing uses last epoch as non-pretrained model!")
|
| 177 |
+
print(last_epoch_model_path)
|
| 178 |
+
self.model.load_state_dict(torch.load(last_epoch_model_path))
|
| 179 |
+
else:
|
| 180 |
+
best_model_path = os.path.join(
|
| 181 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
|
| 182 |
+
print("Testing uses best epoch selected using model selection as non-pretrained model!")
|
| 183 |
+
print(best_model_path)
|
| 184 |
+
self.model.load_state_dict(torch.load(best_model_path))
|
| 185 |
+
|
| 186 |
+
self.model = self.model.to(self.config.DEVICE)
|
| 187 |
+
self.model.eval()
|
| 188 |
+
print("Running model evaluation on the testing dataset!")
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
|
| 191 |
+
batch_size = test_batch[0].shape[0]
|
| 192 |
+
data_test, labels_test = test_batch[0].to(
|
| 193 |
+
self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
|
| 194 |
+
N, D, C, H, W = data_test.shape
|
| 195 |
+
data_test = data_test.view(N * D, C, H, W)
|
| 196 |
+
labels_test = labels_test.view(-1, 1)
|
| 197 |
+
data_test = data_test[:(N * D) // self.base_len * self.base_len]
|
| 198 |
+
# Add one more frame for EfficientPhys since it does torch.diff for the input
|
| 199 |
+
last_frame = torch.unsqueeze(data_test[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1)
|
| 200 |
+
data_test = torch.cat((data_test, last_frame), 0)
|
| 201 |
+
labels_test = labels_test[:(N * D) // self.base_len * self.base_len]
|
| 202 |
+
pred_ppg_test = self.model(data_test)
|
| 203 |
+
|
| 204 |
+
if self.config.TEST.OUTPUT_SAVE_DIR:
|
| 205 |
+
labels_test = labels_test.cpu()
|
| 206 |
+
pred_ppg_test = pred_ppg_test.cpu()
|
| 207 |
+
|
| 208 |
+
for idx in range(batch_size):
|
| 209 |
+
subj_index = test_batch[2][idx]
|
| 210 |
+
sort_index = int(test_batch[3][idx])
|
| 211 |
+
if subj_index not in predictions.keys():
|
| 212 |
+
predictions[subj_index] = dict()
|
| 213 |
+
labels[subj_index] = dict()
|
| 214 |
+
predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 215 |
+
labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len]
|
| 216 |
+
|
| 217 |
+
print('')
|
| 218 |
+
calculate_metrics(predictions, labels, self.config)
|
| 219 |
+
if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs
|
| 220 |
+
self.save_test_outputs(predictions, labels, self.config)
|
| 221 |
+
|
| 222 |
+
def save_model(self, index):
|
| 223 |
+
if not os.path.exists(self.model_dir):
|
| 224 |
+
os.makedirs(self.model_dir)
|
| 225 |
+
model_path = os.path.join(
|
| 226 |
+
self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
|
| 227 |
+
torch.save(self.model.state_dict(), model_path)
|
| 228 |
+
print('Saved Model Path: ', model_path)
|
neural_methods/trainer/FactorizePhysTrainer.py
ADDED
|
@@ -0,0 +1,312 @@
"""
FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing
NeurIPS 2024
Jitesh Joshi, Sos S. Agaian, and Youngjun Cho
"""

import os
import numpy as np
import torch
import torch.optim as optim
from evaluation.metrics import calculate_metrics
from neural_methods.loss.NegPearsonLoss import Neg_Pearson
from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys
from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig
from neural_methods.trainer.BaseTrainer import BaseTrainer
from tqdm import tqdm


class FactorizePhysTrainer(BaseTrainer):

    def __init__(self, config, data_loader):
        """Inits parameters from args and the writer for TensorboardX."""
        super().__init__()
        self.max_epoch_num = config.TRAIN.EPOCHS
        self.model_dir = config.MODEL.MODEL_DIR
        self.model_file_name = config.TRAIN.MODEL_FILE_NAME
        self.batch_size = config.TRAIN.BATCH_SIZE
        self.num_of_gpu = config.NUM_OF_GPU_TRAIN
        self.dropout_rate = config.MODEL.DROP_RATE
        self.base_len = self.num_of_gpu
        self.config = config
        self.min_valid_loss = None
        self.best_epoch = 0

        if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0:
            dev_list = [int(d) for d in config.DEVICE.replace("cuda:", "").split(",")]
            self.device = torch.device(dev_list[0])  # currently toolbox only supports 1 GPU
            self.num_of_gpu = 1  # config.NUM_OF_GPU_TRAIN # set number of used GPUs
        else:
            self.device = torch.device("cpu")  # if no GPUs set device is CPU
            self.num_of_gpu = 0  # no GPUs used

        frames = self.config.MODEL.FactorizePhys.FRAME_NUM
        in_channels = self.config.MODEL.FactorizePhys.CHANNELS
        model_type = self.config.MODEL.FactorizePhys.TYPE
        model_type = model_type.lower()

        md_config = {}
        md_config["FRAME_NUM"] = self.config.MODEL.FactorizePhys.FRAME_NUM
        md_config["MD_TYPE"] = self.config.MODEL.FactorizePhys.MD_TYPE
        md_config["MD_FSAM"] = self.config.MODEL.FactorizePhys.MD_FSAM
        md_config["MD_TRANSFORM"] = self.config.MODEL.FactorizePhys.MD_TRANSFORM
        md_config["MD_S"] = self.config.MODEL.FactorizePhys.MD_S
        md_config["MD_R"] = self.config.MODEL.FactorizePhys.MD_R
        md_config["MD_STEPS"] = self.config.MODEL.FactorizePhys.MD_STEPS
        md_config["MD_INFERENCE"] = self.config.MODEL.FactorizePhys.MD_INFERENCE
        md_config["MD_RESIDUAL"] = self.config.MODEL.FactorizePhys.MD_RESIDUAL

        self.md_infer = self.config.MODEL.FactorizePhys.MD_INFERENCE
        self.use_fsam = self.config.MODEL.FactorizePhys.MD_FSAM

        if model_type == "standard":
            self.model = FactorizePhys(frames=frames, md_config=md_config, in_channels=in_channels,
                                       dropout=self.dropout_rate, device=self.device)  # [3, T, 72,72]
        elif model_type == "big":
            self.model = FactorizePhysBig(frames=frames, md_config=md_config, in_channels=in_channels,
                                          dropout=self.dropout_rate, device=self.device)  # [3, T, 144,144]
        else:
            print("Unexpected model type specified. Should be standard or big, but specified:", model_type)
            exit()

        if torch.cuda.device_count() > 0 and self.num_of_gpu > 0:  # distribute model across GPUs
            self.model = torch.nn.DataParallel(self.model, device_ids=[self.device])  # data parallel model
        else:
            self.model = torch.nn.DataParallel(self.model).to(self.device)

        if self.config.TOOLBOX_MODE == "train_and_test" or self.config.TOOLBOX_MODE == "only_train":
            self.num_train_batches = len(data_loader["train"])
            self.criterion = Neg_Pearson()
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=self.config.TRAIN.LR)
            # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
                self.optimizer, max_lr=self.config.TRAIN.LR, epochs=self.config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches)
        elif self.config.TOOLBOX_MODE == "only_test":
            pass
        else:
            raise ValueError("FactorizePhys trainer initialized in incorrect toolbox mode!")

    def train(self, data_loader):
        """Training routine for model"""
        if data_loader["train"] is None:
            raise ValueError("No data for train")

        mean_training_losses = []
        mean_valid_losses = []
        mean_appx_error = []
        lrs = []
        for epoch in range(self.max_epoch_num):
            print('')
            print(f"====Training Epoch: {epoch}====")
            running_loss = 0.0
            train_loss = []
            appx_error_list = []
            self.model.train()
            tbar = tqdm(data_loader["train"], ncols=80)
            for idx, batch in enumerate(tbar):
                tbar.set_description("Train epoch %s" % epoch)

                data = batch[0].to(self.device)
                labels = batch[1].to(self.device)

                if len(labels.shape) > 2:
                    labels = labels[..., 0]  # Compatibility with multi-signal labelled data
                labels = (labels - torch.mean(labels)) / torch.std(labels)  # normalize
                last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
                data = torch.cat((data, last_frame), 2)

                # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
                # labels = torch.cat((labels, last_sample), 0)
                # labels = torch.diff(labels, dim=0)
                # labels = labels / torch.std(labels)  # normalize
                # labels[torch.isnan(labels)] = 0

                self.optimizer.zero_grad()
                if self.model.training and self.use_fsam:
                    pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data)
                else:
                    pred_ppg, vox_embed = self.model(data)

                pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg)  # normalize

                loss = self.criterion(pred_ppg, labels)

                loss.backward()
                running_loss += loss.item()
                if idx % 100 == 99:  # print every 100 mini-batches
                    print(
                        f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
                train_loss.append(loss.item())
                if self.use_fsam:
                    appx_error_list.append(appx_error.item())

                # Append the current learning rate to the list
                lrs.append(self.scheduler.get_last_lr())

                self.optimizer.step()
                self.scheduler.step()

                if self.use_fsam:
                    tbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item())
                else:
                    tbar.set_postfix(loss=loss.item())

            # Append the mean training loss for the epoch
            mean_training_losses.append(np.mean(train_loss))
            if self.use_fsam:
                mean_appx_error.append(np.mean(appx_error_list))
                print("Mean train loss: {}, Mean appx error: {}".format(
                    np.mean(train_loss), np.mean(appx_error_list)))
            else:
                print("Mean train loss: {}".format(np.mean(train_loss)))

            self.save_model(epoch)
            if not self.config.TEST.USE_LAST_EPOCH:
                valid_loss = self.valid(data_loader)
                mean_valid_losses.append(valid_loss)
                print('validation loss: ', valid_loss)
                if self.min_valid_loss is None:
                    self.min_valid_loss = valid_loss
                    self.best_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.best_epoch))
                elif (valid_loss < self.min_valid_loss):
                    self.min_valid_loss = valid_loss
                    self.best_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.best_epoch))
        if not self.config.TEST.USE_LAST_EPOCH:
            print("best trained epoch: {}, min_val_loss: {}".format(
                self.best_epoch, self.min_valid_loss))
        if self.config.TRAIN.PLOT_LOSSES_AND_LR:
            self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)

    def valid(self, data_loader):
        """ Runs the model on valid sets."""
        if data_loader["valid"] is None:
            raise ValueError("No data for valid")

        print('')
        print(" ====Validating===")
        valid_loss = []
        self.model.eval()
        valid_step = 0
        with torch.no_grad():
            vbar = tqdm(data_loader["valid"], ncols=80)
            for valid_idx, valid_batch in enumerate(vbar):
                vbar.set_description("Validation")

                data, labels = valid_batch[0].to(self.device), valid_batch[1].to(self.device)
                if len(labels.shape) > 2:
                    labels = labels[..., 0]  # Compatibility with multi-signal labelled data
                labels = (labels - torch.mean(labels)) / torch.std(labels)  # normalize

                last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
                data = torch.cat((data, last_frame), 2)

                # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
                # labels = torch.cat((labels, last_sample), 0)
                # labels = torch.diff(labels, dim=0)
                # labels = labels / torch.std(labels)  # normalize
                # labels[torch.isnan(labels)] = 0

                if self.md_infer and self.use_fsam:
                    pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data)
                else:
                    pred_ppg, vox_embed = self.model(data)
                pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg)  # normalize
                loss = self.criterion(pred_ppg, labels)

                valid_loss.append(loss.item())
                valid_step += 1
                # vbar.set_postfix(loss=loss.item())
                if self.md_infer and self.use_fsam:
                    vbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item())
                else:
                    vbar.set_postfix(loss=loss.item())
        valid_loss = np.asarray(valid_loss)
        return np.mean(valid_loss)

    def test(self, data_loader):
        """ Runs the model on test sets."""
        if data_loader["test"] is None:
            raise ValueError("No data for test")

        print('')
        print("===Testing===")
        predictions = dict()
        labels = dict()

        if self.config.TOOLBOX_MODE == "only_test":
            if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
                raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
            self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")), strict=False)
            print("Testing uses pretrained model!")
            print(self.config.INFERENCE.MODEL_PATH)
        else:
            if self.config.TEST.USE_LAST_EPOCH:
                last_epoch_model_path = os.path.join(
                    self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
                print("Testing uses last epoch as non-pretrained model!")
                print(last_epoch_model_path)
                self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")), strict=False)
            else:
                best_model_path = os.path.join(
                    self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
                print("Testing uses best epoch selected using model selection as non-pretrained model!")
                print(best_model_path)
                self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")), strict=False)

        self.model = self.model.to(self.device)
        self.model.eval()
        print("Running model evaluation on the testing dataset!")
        with torch.no_grad():
            for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
                batch_size = test_batch[0].shape[0]
                data, labels_test = test_batch[0].to(self.device), test_batch[1].to(self.device)

                if len(labels_test.shape) > 2:
                    labels_test = labels_test[..., 0]  # Compatibility with multi-signal labelled data
                labels_test = (labels_test - torch.mean(labels_test)) / torch.std(labels_test)  # normalize

                last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1)
                data = torch.cat((data, last_frame), 2)

                # last_sample = torch.unsqueeze(labels_test[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1)
                # labels_test = torch.cat((labels_test, last_sample), 0)
                # labels_test = torch.diff(labels_test, dim=0)
                # labels_test = labels_test / torch.std(labels_test)  # normalize
                # labels_test[torch.isnan(labels_test)] = 0

                if self.md_infer and self.use_fsam:
                    pred_ppg_test, vox_embed, factorized_embed, appx_error = self.model(data)
                else:
                    pred_ppg_test, vox_embed = self.model(data)
                pred_ppg_test = (pred_ppg_test - torch.mean(pred_ppg_test)) / torch.std(pred_ppg_test)  # normalize

                if self.config.TEST.OUTPUT_SAVE_DIR:
                    labels_test = labels_test.cpu()
                    pred_ppg_test = pred_ppg_test.cpu()

                for idx in range(batch_size):
                    subj_index = test_batch[2][idx]
                    sort_index = int(test_batch[3][idx])
                    if subj_index not in predictions.keys():
                        predictions[subj_index] = dict()
                        labels[subj_index] = dict()
                    predictions[subj_index][sort_index] = pred_ppg_test[idx]
                    labels[subj_index][sort_index] = labels_test[idx]

        print('')
        calculate_metrics(predictions, labels, self.config)
        if self.config.TEST.OUTPUT_SAVE_DIR:  # saving test outputs
            self.save_test_outputs(predictions, labels, self.config)

    def save_model(self, index):
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        model_path = os.path.join(
            self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
        torch.save(self.model.state_dict(), model_path)
        print('Saved Model Path: ', model_path)
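The trainer above optimizes the Neg_Pearson criterion imported from neural_methods.loss.NegPearsonLoss. As orientation only, a negative Pearson correlation loss of this kind can be sketched as below; this is a simplified stand-in assuming (batch, time) inputs, not the toolbox's exact implementation:

# Simplified stand-in (assumption): loss is 1 minus the Pearson correlation, averaged over the batch.
import torch

def neg_pearson(pred, label, eps=1e-8):
    pred = pred - pred.mean(dim=-1, keepdim=True)
    label = label - label.mean(dim=-1, keepdim=True)
    r = (pred * label).sum(dim=-1) / (pred.norm(dim=-1) * label.norm(dim=-1) + eps)
    return (1.0 - r).mean()  # 0 when perfectly correlated, 2 when anti-correlated

loss = neg_pearson(torch.randn(4, 160), torch.randn(4, 160))

Because the correlation is scale- and offset-invariant, the per-chunk z-normalization of both pred_ppg and labels in the trainer mainly keeps the optimization numerically well behaved rather than changing what the loss measures.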
neural_methods/trainer/FactorizePhysTrainer.py.backup
ADDED
@@ -0,0 +1,312 @@
(Duplicate of neural_methods/trainer/FactorizePhysTrainer.py above. The only difference is in test(): the three load_state_dict(torch.load(...)) calls pass map_location=self.device instead of map_location=torch.device("cpu").)
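That map_location difference is the whole story of the backup file: checkpoints saved on GPU must be remapped when the Space runs on CPU-only hardware. A minimal, hedged sketch of device-agnostic checkpoint loading of the kind both variants rely on (load_checkpoint is an illustrative helper, not a toolbox API):

# Illustrative helper (assumption, not part of the commit): remap a checkpoint
# to whatever device is actually available before loading it.
import torch

def load_checkpoint(model, path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.load(path, map_location=device)
    # strict=False tolerates key mismatches such as DataParallel "module." prefixes
    model.load_state_dict(state, strict=False)
    return model.to(device)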
neural_methods/trainer/PhysFormerTrainer.py
ADDED
@@ -0,0 +1,273 @@
"""Trainer for Physformer.

Based on open-source code from the original PhysFormer authors below:
https://github.com/ZitongYu/PhysFormer/blob/main/train_Physformer_160_VIPL.py

We also thank the PhysBench authors for their open-source code based on the code
of the original authors. Their code below provided a better reference for tuning loss
parameters of interest and utilizing RMSE as a validation loss:
https://github.com/KegangWangCCNU/PhysBench/blob/main/benchmark_addition/PhysFormer_pure.ipynb

"""

import os
import numpy as np
import math
import torch
import torch.optim as optim
from evaluation.metrics import calculate_metrics
from neural_methods.loss.PhysNetNegPearsonLoss import Neg_Pearson
from neural_methods.loss.PhysFormerLossComputer import TorchLossComputer
from neural_methods.model.PhysFormer import ViT_ST_ST_Compact3_TDC_gra_sharp
from neural_methods.trainer.BaseTrainer import BaseTrainer
from tqdm import tqdm
from scipy.signal import welch

class PhysFormerTrainer(BaseTrainer):

    def __init__(self, config, data_loader):
        """Inits parameters from args and the writer for TensorboardX."""
        super().__init__()
        self.device = torch.device(config.DEVICE)
        self.max_epoch_num = config.TRAIN.EPOCHS
        self.model_dir = config.MODEL.MODEL_DIR
        self.dropout_rate = config.MODEL.DROP_RATE
        self.patch_size = config.MODEL.PHYSFORMER.PATCH_SIZE
        self.dim = config.MODEL.PHYSFORMER.DIM
        self.ff_dim = config.MODEL.PHYSFORMER.FF_DIM
        self.num_heads = config.MODEL.PHYSFORMER.NUM_HEADS
        self.num_layers = config.MODEL.PHYSFORMER.NUM_LAYERS
        self.theta = config.MODEL.PHYSFORMER.THETA
        self.model_file_name = config.TRAIN.MODEL_FILE_NAME
        self.batch_size = config.TRAIN.BATCH_SIZE
        self.num_of_gpu = config.NUM_OF_GPU_TRAIN
        self.frame_rate = config.TRAIN.DATA.FS
        self.config = config
        self.min_valid_loss = None
        self.best_epoch = 0

        if config.TOOLBOX_MODE == "train_and_test":
            self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
            self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
                image_size=(self.chunk_len, config.TRAIN.DATA.PREPROCESS.RESIZE.H, config.TRAIN.DATA.PREPROCESS.RESIZE.W),
                patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
                dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))

            self.num_train_batches = len(data_loader["train"])
            self.criterion_reg = torch.nn.MSELoss()
            self.criterion_L1loss = torch.nn.L1Loss()
            self.criterion_class = torch.nn.CrossEntropyLoss()
            self.criterion_Pearson = Neg_Pearson()
            self.optimizer = optim.Adam(self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0.00005)
            # TODO: In both the PhysFormer repo's training example and other implementations of a PhysFormer trainer,
            # a step_size that doesn't end up changing the LR always seems to be used. This seems to defeat the point
            # of using StepLR in the first place. Consider investigating and using another approach (e.g., OneCycleLR).
            self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5)
        elif config.TOOLBOX_MODE == "only_test":
            self.chunk_len = config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
            self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
                image_size=(self.chunk_len, config.TRAIN.DATA.PREPROCESS.RESIZE.H, config.TRAIN.DATA.PREPROCESS.RESIZE.W),
                patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
                dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
        else:
            raise ValueError("Physformer trainer initialized in incorrect toolbox mode!")


    def train(self, data_loader):
        """Training routine for model"""
        if data_loader["train"] is None:
            raise ValueError("No data for train")

        # a --> Pearson loss; b --> frequency loss
        a_start = 1.0
        b_start = 1.0
        exp_a = 0.5  # Unused
        exp_b = 1.0

        # TODO: Expand tracking and subsequent plotting of these losses for PhysFormer
        mean_training_losses = []
        mean_valid_losses = []
        lrs = []

        for epoch in range(self.max_epoch_num):
            print('')
            print(f"====Training Epoch: {epoch}====")
            loss_rPPG_avg = []
            loss_peak_avg = []
            loss_kl_avg_test = []
            loss_hr_mae = []

            self.model.train()
            tbar = tqdm(data_loader["train"], ncols=80)
            for idx, batch in enumerate(tbar):
                hr = torch.tensor([self.get_hr(i) for i in batch[1]]).float().to(self.device)
                data, label = batch[0].float().to(self.device), batch[1].float().to(self.device)

                self.optimizer.zero_grad()

                gra_sharp = 2.0
                rPPG, _, _, _ = self.model(data, gra_sharp)
                rPPG = (rPPG - torch.mean(rPPG, axis=-1).view(-1, 1)) / torch.std(rPPG, axis=-1).view(-1, 1)  # normalize
                loss_rPPG = self.criterion_Pearson(rPPG, label)

                fre_loss = 0.0
                kl_loss = 0.0
                train_mae = 0.0
                for bb in range(data.shape[0]):
                    loss_distribution_kl, \
                    fre_loss_temp, \
                    train_mae_temp = TorchLossComputer.cross_entropy_power_spectrum_DLDL_softmax2(
                        rPPG[bb],
                        hr[bb],
                        self.frame_rate,
                        std=1.0
                    )
                    fre_loss = fre_loss + fre_loss_temp
                    kl_loss = kl_loss + loss_distribution_kl
                    train_mae = train_mae + train_mae_temp
                fre_loss /= data.shape[0]
                kl_loss /= data.shape[0]
                train_mae /= data.shape[0]

                if epoch > 10:
                    a = 0.05
                    b = 5.0
                else:
                    a = a_start
                    # exp ascend
                    b = b_start * math.pow(exp_b, epoch / 10.0)

                loss = a * loss_rPPG + b * (fre_loss + kl_loss)
                loss.backward()
                self.optimizer.step()

                n = data.size(0)
                loss_rPPG_avg.append(float(loss_rPPG.data))
                loss_peak_avg.append(float(fre_loss.data))
                loss_kl_avg_test.append(float(kl_loss.data))
                loss_hr_mae.append(float(train_mae))
                if idx % 100 == 99:  # print every 100 mini-batches
                    print(f'\nepoch:{epoch}, batch:{idx + 1}, total:{len(data_loader["train"]) // self.batch_size}, '
                          f'lr:0.0001, sharp:{gra_sharp:.3f}, a:{a:.3f}, NegPearson:{np.mean(loss_rPPG_avg[-2000:]):.4f}, '
                          f'\nb:{b:.3f}, kl:{np.mean(loss_kl_avg_test[-2000:]):.3f}, fre_CEloss:{np.mean(loss_peak_avg[-2000:]):.3f}, '
                          f'hr_mae:{np.mean(loss_hr_mae[-2000:]):.3f}')

            # Append the current learning rate to the list
            lrs.append(self.scheduler.get_last_lr())
            # Append the mean training loss for the epoch
            mean_training_losses.append(np.mean(loss_rPPG_avg))
            self.save_model(epoch)
            self.scheduler.step()
            self.model.eval()

            if not self.config.TEST.USE_LAST_EPOCH:
                valid_loss = self.valid(data_loader)
                mean_valid_losses.append(valid_loss)
                print(f'Validation RMSE:{valid_loss:.3f}, batch:{idx+1}')
                if self.min_valid_loss is None:
                    self.min_valid_loss = valid_loss
                    self.best_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.best_epoch))
                elif (valid_loss < self.min_valid_loss):
                    self.min_valid_loss = valid_loss
                    self.best_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.best_epoch))
        if not self.config.TEST.USE_LAST_EPOCH:
            print("best trained epoch: {}, min_val_loss: {}".format(
                self.best_epoch, self.min_valid_loss))
        if self.config.TRAIN.PLOT_LOSSES_AND_LR:
            self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)

    def valid(self, data_loader):
        """ Runs the model on valid sets."""
        if data_loader["valid"] is None:
            raise ValueError("No data for valid")

        print('')
        print(" ====Validating===")
        self.optimizer.zero_grad()
        with torch.no_grad():
            hrs = []
            vbar = tqdm(data_loader["valid"], ncols=80)
            for val_idx, val_batch in enumerate(vbar):
                data, label = val_batch[0].float().to(self.device), val_batch[1].float().to(self.device)
                gra_sharp = 2.0
                rPPG, _, _, _ = self.model(data, gra_sharp)
                rPPG = (rPPG - torch.mean(rPPG, axis=-1).view(-1, 1)) / torch.std(rPPG).view(-1, 1)
                for _1, _2 in zip(rPPG, label):
                    hrs.append((self.get_hr(_1.cpu().detach().numpy()), self.get_hr(_2.cpu().detach().numpy())))
            RMSE = np.mean([(i - j) ** 2 for i, j in hrs]) ** 0.5
        return RMSE

    def test(self, data_loader):
        """ Runs the model on test sets."""
        if data_loader["test"] is None:
            raise ValueError("No data for test")

        print('')
        print("===Testing===")

        # Change chunk length to be test chunk length
        self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH

        predictions = dict()
        labels = dict()

        if self.config.TOOLBOX_MODE == "only_test":
            if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
                raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
            self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")))
            print("Testing uses pretrained model!")
            print(self.config.INFERENCE.MODEL_PATH)
        else:
            if self.config.TEST.USE_LAST_EPOCH:
                last_epoch_model_path = os.path.join(
                    self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
                print("Testing uses last epoch as non-pretrained model!")
                print(last_epoch_model_path)
                self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")))
            else:
                best_model_path = os.path.join(
                    self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
                print("Testing uses best epoch selected using model selection as non-pretrained model!")
                print(best_model_path)
                self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")))

        self.model = self.model.to(self.config.DEVICE)
        self.model.eval()
        print("Running model evaluation on the testing dataset!")
        with torch.no_grad():
            for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
                batch_size = test_batch[0].shape[0]
                data, label = test_batch[0].to(
                    self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
                gra_sharp = 2.0
                pred_ppg_test, _, _, _ = self.model(data, gra_sharp)
                for idx in range(batch_size):
                    subj_index = test_batch[2][idx]
                    sort_index = int(test_batch[3][idx])
                    if subj_index not in predictions.keys():
                        predictions[subj_index] = dict()
                        labels[subj_index] = dict()
                    predictions[subj_index][sort_index] = pred_ppg_test[idx]
                    labels[subj_index][sort_index] = label[idx]

        print('')
        calculate_metrics(predictions, labels, self.config)
        if self.config.TEST.OUTPUT_SAVE_DIR:  # saving test outputs
            self.save_test_outputs(predictions, labels, self.config)

    def save_model(self, index):
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        model_path = os.path.join(
            self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
        torch.save(self.model.state_dict(), model_path)
        print('Saved Model Path: ', model_path)

    # HR calculation based on ground truth label
    def get_hr(self, y, sr=30, min=30, max=180):
        p, q = welch(y, sr, nfft=1e5 / sr, nperseg=np.min((len(y) - 1, 256)))
        return p[(p > min / 60) & (p < max / 60)][np.argmax(q[(p > min / 60) & (p < max / 60)])] * 60
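get_hr above reads the heart rate off a Welch power spectrum of a PPG waveform, restricted to the physiological band. A simplified sketch of the same idea, with an illustrative estimate_hr helper whose parameters are assumptions rather than the trainer's exact settings:

# Simplified sketch (assumption): dominant spectral peak in 0.5-3 Hz, reported in BPM.
import numpy as np
from scipy.signal import welch

def estimate_hr(sig, fs=30.0, lo_bpm=30.0, hi_bpm=180.0):
    f, pxx = welch(sig, fs, nperseg=min(len(sig) - 1, 256))
    band = (f > lo_bpm / 60.0) & (f < hi_bpm / 60.0)
    return float(f[band][np.argmax(pxx[band])] * 60.0)  # peak frequency converted to BPM

# A 1.2 Hz sine sampled at 30 Hz should come out at roughly 70 BPM given the bin width.
print(estimate_hr(np.sin(2 * np.pi * 1.2 * np.arange(300) / 30.0)))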
neural_methods/trainer/PhysFormerTrainer.py.backup
ADDED
@@ -0,0 +1,273 @@
| 1 |
+
"""Trainer for Physformer.
|
| 2 |
+
|
| 3 |
+
Based on open-source code from the original PhysFormer authors below:
|
| 4 |
+
https://github.com/ZitongYu/PhysFormer/blob/main/train_Physformer_160_VIPL.py
|
| 5 |
+
|
| 6 |
+
We also thank the PhysBench authors for their open-source code based on the code
|
| 7 |
+
of the original authors. Their code below provided a better reference for tuning loss
|
| 8 |
+
parameters of interest and utilizing RSME as a validation loss:
|
| 9 |
+
https://github.com/KegangWangCCNU/PhysBench/blob/main/benchmark_addition/PhysFormer_pure.ipynb
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import numpy as np
|
| 15 |
+
import math
|
| 16 |
+
import torch
|
| 17 |
+
import torch.optim as optim
|
| 18 |
+
from evaluation.metrics import calculate_metrics
|
| 19 |
+
from neural_methods.loss.PhysNetNegPearsonLoss import Neg_Pearson
|
| 20 |
+
from neural_methods.loss.PhysFormerLossComputer import TorchLossComputer
|
| 21 |
+
from neural_methods.model.PhysFormer import ViT_ST_ST_Compact3_TDC_gra_sharp
|
| 22 |
+
from neural_methods.trainer.BaseTrainer import BaseTrainer
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
from scipy.signal import welch
|
| 25 |
+
|
| 26 |
+
class PhysFormerTrainer(BaseTrainer):
|
| 27 |
+
|
| 28 |
+
def __init__(self, config, data_loader):
|
| 29 |
+
"""Inits parameters from args and the writer for TensorboardX."""
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.device = torch.device(config.DEVICE)
|
| 32 |
+
self.max_epoch_num = config.TRAIN.EPOCHS
|
| 33 |
+
self.model_dir = config.MODEL.MODEL_DIR
|
| 34 |
+
self.dropout_rate = config.MODEL.DROP_RATE
|
| 35 |
+
self.patch_size = config.MODEL.PHYSFORMER.PATCH_SIZE
|
| 36 |
+
self.dim = config.MODEL.PHYSFORMER.DIM
|
| 37 |
+
self.ff_dim = config.MODEL.PHYSFORMER.FF_DIM
|
| 38 |
+
self.num_heads = config.MODEL.PHYSFORMER.NUM_HEADS
|
| 39 |
+
self.num_layers = config.MODEL.PHYSFORMER.NUM_LAYERS
|
| 40 |
+
self.theta = config.MODEL.PHYSFORMER.THETA
|
| 41 |
+
self.model_file_name = config.TRAIN.MODEL_FILE_NAME
|
| 42 |
+
self.batch_size = config.TRAIN.BATCH_SIZE
|
| 43 |
+
self.num_of_gpu = config.NUM_OF_GPU_TRAIN
|
| 44 |
+
self.frame_rate = config.TRAIN.DATA.FS
|
| 45 |
+
self.config = config
|
| 46 |
+
self.min_valid_loss = None
|
| 47 |
+
self.best_epoch = 0
|
| 48 |
+
|
| 49 |
+
if config.TOOLBOX_MODE == "train_and_test":
|
| 50 |
+
self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH
|
| 51 |
+
self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
|
| 52 |
+
image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W),
|
| 53 |
+
patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
|
| 54 |
+
dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
|
| 55 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 56 |
+
|
| 57 |
+
self.num_train_batches = len(data_loader["train"])
|
| 58 |
+
self.criterion_reg = torch.nn.MSELoss()
|
| 59 |
+
self.criterion_L1loss = torch.nn.L1Loss()
|
| 60 |
+
self.criterion_class = torch.nn.CrossEntropyLoss()
|
| 61 |
+
self.criterion_Pearson = Neg_Pearson()
|
| 62 |
+
self.optimizer = optim.Adam(self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0.00005)
|
| 63 |
+
# TODO: In both the PhysFormer repo's training example and other implementations of a PhysFormer trainer,
|
| 64 |
+
# a step_size that doesn't end up changing the LR always seems to be used. This seems to defeat the point
|
| 65 |
+
# of using StepLR in the first place. Consider investigating and using another approach (e.g., OneCycleLR).
|
| 66 |
+
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5)
|
| 67 |
+
elif config.TOOLBOX_MODE == "only_test":
|
| 68 |
+
self.chunk_len = config.TEST.DATA.PREPROCESS.CHUNK_LENGTH
|
| 69 |
+
self.model = ViT_ST_ST_Compact3_TDC_gra_sharp(
|
| 70 |
+
image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W),
|
| 71 |
+
patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers,
|
| 72 |
+
dropout_rate=self.dropout_rate, theta=self.theta).to(self.device)
|
| 73 |
+
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN)))
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError("Physformer trainer initialized in incorrect toolbox mode!")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def train(self, data_loader):
|
| 79 |
+
"""Training routine for model"""
|
| 80 |
+
if data_loader["train"] is None:
|
| 81 |
+
raise ValueError("No data for train")
|
| 82 |
+
|
| 83 |
+
# a --> Pearson loss; b --> frequency loss
|
| 84 |
+
a_start = 1.0
|
| 85 |
+
b_start = 1.0
|
| 86 |
+
exp_a = 0.5 # Unused
|
| 87 |
+
exp_b = 1.0
|
| 88 |
+
|
| 89 |
+
# TODO: Expand tracking and subsequent plotting of these losses for PhysFormer
|
| 90 |
+
mean_training_losses = []
|
| 91 |
+
mean_valid_losses = []
|
| 92 |
+
lrs = []
|
| 93 |
+
|
| 94 |
+
        for epoch in range(self.max_epoch_num):
            print('')
            print(f"====Training Epoch: {epoch}====")
            loss_rPPG_avg = []
            loss_peak_avg = []
            loss_kl_avg_test = []
            loss_hr_mae = []

            self.model.train()
            tbar = tqdm(data_loader["train"], ncols=80)
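            # Each batch yields video chunks and their synchronized PPG labels; the per-chunk
            # ground-truth HR (in BPM) is derived from the label wave with get_hr() and is used
            # to supervise the frequency-domain losses below.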
            for idx, batch in enumerate(tbar):
                hr = torch.tensor([self.get_hr(i) for i in batch[1]]).float().to(self.device)
                data, label = batch[0].float().to(self.device), batch[1].float().to(self.device)

                self.optimizer.zero_grad()

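                # gra_sharp is PhysFormer's attention-sharpness (temperature-like) factor; this
                # trainer keeps it fixed at 2.0 for training, validation, and testing.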
                gra_sharp = 2.0
                rPPG, _, _, _ = self.model(data, gra_sharp)
                rPPG = (rPPG - torch.mean(rPPG, axis=-1).view(-1, 1)) / torch.std(rPPG, axis=-1).view(-1, 1)  # normalize
                loss_rPPG = self.criterion_Pearson(rPPG, label)

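                # Frequency-domain supervision, computed per sample: the power spectrum of the
                # predicted rPPG is compared against the ground-truth HR via a label-distribution
                # (KL) term and a frequency cross-entropy term; the per-sample HR MAE is kept for logging.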
                fre_loss = 0.0
                kl_loss = 0.0
                train_mae = 0.0
                for bb in range(data.shape[0]):
                    loss_distribution_kl, \
                    fre_loss_temp, \
                    train_mae_temp = TorchLossComputer.cross_entropy_power_spectrum_DLDL_softmax2(
                        rPPG[bb],
                        hr[bb],
                        self.frame_rate,
                        std=1.0
                    )
                    fre_loss = fre_loss + fre_loss_temp
                    kl_loss = kl_loss + loss_distribution_kl
                    train_mae = train_mae + train_mae_temp
                fre_loss /= data.shape[0]
                kl_loss /= data.shape[0]
                train_mae /= data.shape[0]

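                # Loss-weighting schedule: while epoch <= 10 the temporal and frequency terms are
                # weighted equally (a=1.0, b=1.0, since exp_b=1.0 keeps b constant); from epoch 11
                # onward the frequency terms dominate (a=0.05, b=5.0).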
                if epoch > 10:
                    a = 0.05
                    b = 5.0
                else:
                    a = a_start
                    # exp ascend
                    b = b_start * math.pow(exp_b, epoch / 10.0)

                loss = a * loss_rPPG + b * (fre_loss + kl_loss)
                loss.backward()
                self.optimizer.step()

                n = data.size(0)
                loss_rPPG_avg.append(float(loss_rPPG.data))
                loss_peak_avg.append(float(fre_loss.data))
                loss_kl_avg_test.append(float(kl_loss.data))
                loss_hr_mae.append(float(train_mae))
                if idx % 100 == 99:  # print every 100 mini-batches
                    print(f'\nepoch:{epoch}, batch:{idx + 1}, total:{len(data_loader["train"]) // self.batch_size}, '
                          f'lr:0.0001, sharp:{gra_sharp:.3f}, a:{a:.3f}, NegPearson:{np.mean(loss_rPPG_avg[-2000:]):.4f}, '
                          f'\nb:{b:.3f}, kl:{np.mean(loss_kl_avg_test[-2000:]):.3f}, fre_CEloss:{np.mean(loss_peak_avg[-2000:]):.3f}, '
                          f'hr_mae:{np.mean(loss_hr_mae[-2000:]):.3f}')

            # Append the current learning rate to the list
            lrs.append(self.scheduler.get_last_lr())
            # Append the mean training loss for the epoch
            mean_training_losses.append(np.mean(loss_rPPG_avg))
            self.save_model(epoch)
            self.scheduler.step()
            self.model.eval()

            if not self.config.TEST.USE_LAST_EPOCH:
                valid_loss = self.valid(data_loader)
                mean_valid_losses.append(valid_loss)
                print(f'Validation RMSE:{valid_loss:.3f}, batch:{idx + 1}')
                if self.min_valid_loss is None:
                    self.min_valid_loss = valid_loss
                    self.best_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.best_epoch))
                elif (valid_loss < self.min_valid_loss):
                    self.min_valid_loss = valid_loss
                    self.best_epoch = epoch
                    print("Update best model! Best epoch: {}".format(self.best_epoch))
        if not self.config.TEST.USE_LAST_EPOCH:
            print("best trained epoch: {}, min_val_loss: {}".format(
                self.best_epoch, self.min_valid_loss))
        if self.config.TRAIN.PLOT_LOSSES_AND_LR:
            self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config)

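    # Validation estimates HR from every predicted and ground-truth chunk via get_hr() and
    # returns the RMSE between them (in BPM); it is used above for best-epoch selection.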
    def valid(self, data_loader):
        """ Runs the model on valid sets."""
        if data_loader["valid"] is None:
            raise ValueError("No data for valid")

        print('')
        print("====Validating====")
        self.optimizer.zero_grad()
        with torch.no_grad():
            hrs = []
            vbar = tqdm(data_loader["valid"], ncols=80)
            for val_idx, val_batch in enumerate(vbar):
                data, label = val_batch[0].float().to(self.device), val_batch[1].float().to(self.device)
                gra_sharp = 2.0
                rPPG, _, _, _ = self.model(data, gra_sharp)
                # Normalize per sample, matching the normalization used in train()
                rPPG = (rPPG - torch.mean(rPPG, axis=-1).view(-1, 1)) / torch.std(rPPG, axis=-1).view(-1, 1)
                for _1, _2 in zip(rPPG, label):
                    hrs.append((self.get_hr(_1.cpu().detach().numpy()), self.get_hr(_2.cpu().detach().numpy())))
        RMSE = np.mean([(i - j) ** 2 for i, j in hrs]) ** 0.5
        return RMSE

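    # Testing loads the requested checkpoint, runs chunk-wise inference, and hands the collected
    # predictions and labels to calculate_metrics() for evaluation.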
    def test(self, data_loader):
        """ Runs the model on test sets."""
        if data_loader["test"] is None:
            raise ValueError("No data for test")

        print('')
        print("===Testing===")

        # Change chunk length to be test chunk length
        self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH

        predictions = dict()
        labels = dict()

        if self.config.TOOLBOX_MODE == "only_test":
            if not os.path.exists(self.config.INFERENCE.MODEL_PATH):
                raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.")
            self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH))
            print("Testing uses pretrained model!")
            print(self.config.INFERENCE.MODEL_PATH)
        else:
            if self.config.TEST.USE_LAST_EPOCH:
                last_epoch_model_path = os.path.join(
                    self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth')
                print("Testing uses last epoch as non-pretrained model!")
                print(last_epoch_model_path)
                self.model.load_state_dict(torch.load(last_epoch_model_path))
            else:
                best_model_path = os.path.join(
                    self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth')
                print("Testing uses best epoch selected using model selection as non-pretrained model!")
                print(best_model_path)
                self.model.load_state_dict(torch.load(best_model_path))

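        # Inference: outputs are stored as predictions[subject][chunk_index] so that downstream
        # metric computation can re-order and concatenate the chunks of each subject.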
        self.model = self.model.to(self.config.DEVICE)
        self.model.eval()
        print("Running model evaluation on the testing dataset!")
        with torch.no_grad():
            for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)):
                batch_size = test_batch[0].shape[0]
                data, label = test_batch[0].to(
                    self.config.DEVICE), test_batch[1].to(self.config.DEVICE)
                gra_sharp = 2.0
                pred_ppg_test, _, _, _ = self.model(data, gra_sharp)
                for idx in range(batch_size):
                    subj_index = test_batch[2][idx]
                    sort_index = int(test_batch[3][idx])
                    if subj_index not in predictions.keys():
                        predictions[subj_index] = dict()
                        labels[subj_index] = dict()
                    predictions[subj_index][sort_index] = pred_ppg_test[idx]
                    labels[subj_index][sort_index] = label[idx]

        print('')
        calculate_metrics(predictions, labels, self.config)
        if self.config.TEST.OUTPUT_SAVE_DIR:  # saving test outputs
            self.save_test_outputs(predictions, labels, self.config)

    def save_model(self, index):
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        model_path = os.path.join(
            self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
        torch.save(self.model.state_dict(), model_path)
        print('Saved Model Path: ', model_path)

    # HR calculation based on ground truth label
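    # get_hr() takes the Welch power spectral density of the signal and returns the frequency of
    # the strongest peak within the [min, max] BPM band, converted to BPM (a peak at 1.2 Hz maps
    # to 72 BPM, for example).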
    def get_hr(self, y, sr=30, min=30, max=180):
        p, q = welch(y, sr, nfft=1e5/sr, nperseg=np.min((len(y)-1, 256)))
        return p[(p > min / 60) & (p < max / 60)][np.argmax(q[(p > min / 60) & (p < max / 60)])] * 60