diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a1579f0d03ce4e1ca6fc94b2fe42abcf4d8c738b --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +*.so +*.egg +*.egg-info/ +dist/ +build/ +.pyenv/ +.venv/ +logs/ +analysis/ +*.log +.DS_Store diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbf7edc46ffa516a0dfee8f3652a685e001650d --- /dev/null +++ b/app.py @@ -0,0 +1,2559 @@ +import os +os.environ.setdefault("OPENCV_AVFOUNDATION_SKIP_AUTH", "1") +os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") + +import warnings +warnings.filterwarnings("ignore") + +import re +import io +import json +import tempfile +import time +import threading +import uuid +import importlib.util +from pathlib import Path +from collections import deque +from typing import Optional, Tuple, Dict, List + +import numpy as np +import pandas as pd +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from scipy import signal +from scipy.signal import find_peaks, welch, get_window + +# .mat support (SCAMPS) +try: + from scipy.io import loadmat + MAT_SUPPORT = True +except Exception: + MAT_SUPPORT = False + +# Matplotlib headless backend +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec + +import gradio as gr + +class PhysMambaattention_viz: + """Simplified Grad-CAM for PhysMamba.""" + + def __init__(self, model: nn.Module, device: torch.device): + self.model = model + self.device = device + self.activations = None + self.gradients = None + self.hook_handles = [] + + def _get_target_layer(self): + """Auto-detect the best layer for visualization, excluding Mamba/SSM layers.""" + + # Strategy 1: Look for temporal difference convolution + for name, module in reversed(list(self.model.named_modules())): + if ('tdc' in name.lower() or 'temporal_diff' in name.lower()) and not ('mamba' in name.lower() or 'ssm' in name.lower()): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + print(f"[attention_viz] Selected TDC layer: {name} ({type(module).__name__})") + return name, module + + # Strategy 2: Look for 3D convolutions (temporal-spatial) + for name, module in reversed(list(self.model.named_modules())): + if isinstance(module, nn.Conv3d) and not ('mamba' in name.lower() or 'ssm' in name.lower()): + print(f"[attention_viz] Selected Conv3d layer: {name}") + return name, module + + # Strategy 3: Look for 2D convolutions (spatial) + for name, module in reversed(list(self.model.named_modules())): + if isinstance(module, nn.Conv2d) and not ('mamba' in name.lower() or 'ssm' in name.lower()): + print(f"[attention_viz] Selected Conv2d layer: {name}") + return name, module + + # Strategy 4: Look for any convolution (last resort) + for name, module in reversed(list(self.model.named_modules())): + if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + print(f"[attention_viz] Selected Conv layer (fallback): {name} ({type(module).__name__})") + return name, module + + # Strategy 5: Print all available layers and pick the last non-Mamba one + print("\n[attention_viz] Available layers:") + suitable_layers = [] + for name, module in self.model.named_modules(): + layer_type = type(module).__name__ + if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): + is_mamba = 'mamba' in name.lower() or 'ssm' in name.lower() + print(f" - {name}: {layer_type} {'(MAMBA - SKIP)' if is_mamba else ''}") + 
if not is_mamba: + suitable_layers.append((name, module)) + + if suitable_layers: + name, module = suitable_layers[-1] + print(f"\n[attention_viz] Selected last suitable layer: {name} ({type(module).__name__})") + return name, module + + raise ValueError("No suitable layer found for Grad-CAM (all layers are Mamba/SSM)") + + def _forward_hook(self, module, input, output): + self.activations = output.detach() + + def _backward_hook(self, module, grad_input, grad_output): + self.gradients = grad_output[0].detach() + + def register_hooks(self, target_layer_name: Optional[str] = None): + """Register hooks on target layer""" + self._remove_hooks() + + if target_layer_name and target_layer_name.strip(): + # Manual selection + try: + target_module = dict(self.model.named_modules())[target_layer_name.strip()] + print(f"Using manually specified layer: {target_layer_name}") + except KeyError: + print(f"⚠ Layer '{target_layer_name}' not found, falling back to auto-detection") + target_layer_name, target_module = self._get_target_layer() + else: + # Auto-detection + target_layer_name, target_module = self._get_target_layer() + + fwd_handle = target_module.register_forward_hook(self._forward_hook) + bwd_handle = target_module.register_full_backward_hook(self._backward_hook) + + self.hook_handles = [fwd_handle, bwd_handle] + + def _remove_hooks(self): + for handle in self.hook_handles: + handle.remove() + self.hook_handles = [] + self.activations = None + self.gradients = None + + def generate(self, input_tensor: torch.Tensor) -> np.ndarray: + """Generate Grad-CAM heatmap with improved gradient handling.""" + # Set model to eval but keep gradient computation + self.model.eval() + + # Ensure tensor requires grad + input_tensor = input_tensor.requires_grad_(True) + + # Forward pass + output = self.model(input_tensor) + + # Handle different output types + if isinstance(output, dict): + # Try common keys + for key in ('pred', 'output', 'bvp', 'logits', 'out'): + if key in output and isinstance(output[key], torch.Tensor): + output = output[key] + break + + # Get scalar for backward + if output.numel() == 1: + target = output + else: + # Use mean of output + target = output.mean() + + print(f"[attention_viz] Output shape: {output.shape}, Target: {target.item():.4f}") + + # Backward pass + self.model.zero_grad() + target.backward(retain_graph=False) + + # Check if we got gradients + if self.activations is None: + print("⚠ No activations captured!") + return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1])) + + if self.gradients is None: + print("⚠ No gradients captured!") + return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1])) + + print(f"[attention_viz] Activations shape: {self.activations.shape}") + print(f"[attention_viz] Gradients shape: {self.gradients.shape}") + + activations = self.activations + gradients = self.gradients + + # Compute weights (global average pooling of gradients) + if gradients.dim() == 5: # [B, C, T, H, W] + weights = gradients.mean(dim=[2, 3, 4], keepdim=True) + elif gradients.dim() == 4: # [B, C, H, W] + weights = gradients.mean(dim=[2, 3], keepdim=True) + elif gradients.dim() == 3: # [B, C, T] + weights = gradients.mean(dim=2, keepdim=True).unsqueeze(-1).unsqueeze(-1) + else: + print(f"⚠ Unexpected gradient dimensions: {gradients.dim()}") + return np.zeros((input_tensor.shape[-2], input_tensor.shape[-1])) + + # Weighted combination + cam = (weights * activations).sum(dim=1, keepdim=True) + + # If 5D, average over time + if cam.dim() == 5: + cam = cam.mean(dim=2) 
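+        # the temporal mean above collapses the CAM from [B, 1, T, H, W] to [B, 1, H, W], leaving a purely spatial map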
+ + # ReLU + cam = F.relu(cam) + + # Convert to numpy + cam = cam.squeeze().cpu().detach().numpy() + + # Normalize + if cam.max() > cam.min(): + cam = (cam - cam.min()) / (cam.max() - cam.min()) + + print(f"[attention_viz] Final heatmap shape: {cam.shape}, range: [{cam.min():.3f}, {cam.max():.3f}]") + + return cam + + def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray: + """Overlay heatmap on frame.""" + h, w = frame.shape[:2] + heatmap_resized = cv2.resize(heatmap, (w, h)) + heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8) + heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET) + overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0) + return overlay + + def cleanup(self): + self._remove_hooks() + + +def apply_diff_normalized(frames: List[np.ndarray]) -> np.ndarray: + """Apply DiffNormalized preprocessing from PhysMamba paper.""" + if len(frames) < 2: + return np.zeros((len(frames), *frames[0].shape), dtype=np.float32) + + diff_frames = [] + + for i in range(len(frames)): + if i == 0: + diff_frames.append(np.zeros_like(frames[0], dtype=np.float32)) + else: + curr = frames[i].astype(np.float32) + prev = frames[i-1].astype(np.float32) + denominator = curr + prev + 1e-8 + diff = (curr - prev) / denominator + diff_frames.append(diff) + + diff_array = np.stack(diff_frames) + std = diff_array.std() + if std > 0: + diff_array = diff_array / std + + return diff_array + + +def preprocess_for_physmamba(frames: List[np.ndarray], + target_frames: int = 128, + target_size: int = 128) -> torch.Tensor: + """Complete preprocessing pipeline for PhysMamba model.""" + if len(frames) < target_frames: + frames = frames + [frames[-1]] * (target_frames - len(frames)) + elif len(frames) > target_frames: + indices = np.linspace(0, len(frames)-1, target_frames).astype(int) + frames = [frames[i] for i in indices] + + frames_rgb = [f[..., ::-1].copy() for f in frames] + frames_resized = [cv2.resize(f, (target_size, target_size)) for f in frames_rgb] + frames_diff = apply_diff_normalized(frames_resized) + frames_transposed = np.transpose(frames_diff, (3, 0, 1, 2)) + frames_batched = np.expand_dims(frames_transposed, axis=0) + tensor = torch.from_numpy(frames_batched.astype(np.float32)) + + return tensor + +HERE = Path(__file__).parent +MODEL_DIR = HERE / "final_model_release" +LOG_DIR = HERE / "logs" +ANALYSIS_DIR = HERE / "analysis" +for d in [MODEL_DIR, LOG_DIR, ANALYSIS_DIR]: + d.mkdir(exist_ok=True) + +DEVICE = ( + torch.device("cuda") if torch.cuda.is_available() + else torch.device("mps") if torch.backends.mps.is_available() + else torch.device("cpu") +) + +FACE_CASCADE = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml") + +DEFAULT_SIZE = 128 # input H=W to model +DEFAULT_T = 128 # clip length +DEFAULT_STRIDE = 8 # inference hop +DISPLAY_FPS = 10 + +VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'] + +_HR_SMOOTH = None +REST_HR_TARGET = 72.0 +REST_HR_RANGE = (55.0, 95.0) +MAX_JUMP_BPM = 8.0 + +# Recognized GT files for subject folders +GT_FILENAMES = {"ground_truth.txt", "gtdump.txt", "gt.txt"} +GT_EXTS = {".txt", ".csv", ".json"} + +def _as_path(maybe) -> Optional[str]: + """Return a filesystem path from Gradio values (str, dict, tempfile objects).""" + if maybe is None: + return None + if isinstance(maybe, str): + return maybe + if isinstance(maybe, dict): + return maybe.get("name") or maybe.get("path") + name = getattr(maybe, "name", None) # tempfile-like 
object + if isinstance(name, str): + return name + try: + return str(maybe) + except Exception: + return None + +def _import_from_file(py_path: Path): + spec = importlib.util.spec_from_file_location(py_path.stem, str(py_path)) + if not spec or not spec.loader: + raise ImportError(f"Cannot import module from {py_path}") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + +def _looks_like_video(p: Path) -> bool: + if p.suffix.lower() == ".mat": + return True + return p.suffix.lower() in VIDEO_EXTENSIONS + +class SimpleActivationAttention: + """Lightweight attention visualization without gradients.""" + + def __init__(self, model: nn.Module, device: torch.device): + self.model = model + self.device = device + self.activations = None + self.hook_handle = None + + def _activation_hook(self, module, input, output): + """Capture activations during forward pass.""" + self.activations = output.detach() + + def register_hook(self): + """Register hook on a suitable layer.""" + # Find the last convolutional layer before Mamba + target = None + target_name = None + + for name, module in self.model.named_modules(): + if isinstance(module, (nn.Conv2d, nn.Conv3d)) and 'mamba' not in name.lower() and 'ssm' not in name.lower(): + target = module + target_name = name + + if target is None: + print("⚠ [attention_viz] No suitable conv layer found, attention disabled") + return + + self.hook_handle = target.register_forward_hook(self._activation_hook) + print(f"✓ [attention_viz] Hook registered on {target_name} ({type(target).__name__})") + + def generate(self, clip_tensor: torch.Tensor) -> Optional[np.ndarray]: + """Generate attention map from activations (call after forward pass).""" + try: + if self.activations is None: + return None + + # Process activations to create spatial attention + act = self.activations + + # Handle different tensor shapes + if act.dim() == 5: # [B, C, T, H, W] + # Average over time and channels + attention = act.mean(dim=[1, 2]) # -> [B, H, W] + elif act.dim() == 4: # [B, C, H, W] + attention = act.mean(dim=1) # -> [B, H, W] + else: + print(f"⚠ [attention_viz] Unexpected activation shape: {act.shape}") + return None + + # Convert to numpy + attention = attention.squeeze().cpu().numpy() + + # Normalize to [0, 1] + if attention.max() > attention.min(): + attention = (attention - attention.min()) / (attention.max() - attention.min()) + + return attention + + except Exception as e: + print(f"⚠ [attention_viz] Generation failed: {e}") + return None + + def visualize(self, heatmap: np.ndarray, frame: np.ndarray, alpha: float = 0.4) -> np.ndarray: + """Overlay heatmap on frame.""" + h, w = frame.shape[:2] + heatmap_resized = cv2.resize(heatmap, (w, h)) + heatmap_uint8 = (heatmap_resized * 255).astype(np.uint8) + heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET) + overlay = cv2.addWeighted(frame, 1-alpha, heatmap_colored, alpha, 0) + return overlay + + def cleanup(self): + if self.hook_handle is not None: + self.hook_handle.remove() + +class VideoReader: + """ + Unified frame reader: + • Regular videos via cv2.VideoCapture + • .mat 'videos' (e.g., SCAMPS): expects array like (T,H,W,3) or (H,W,T,3) or (T,H,W) + Returns frames as BGR uint8. 
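+      Usage sketch (illustrative only):
+          r = VideoReader("clip.mp4")        # or a SCAMPS-style .mat file
+          ok, frame = r.read()               # (True, HxWx3 BGR uint8) until frames run out
+          fs = r.fps(fallback=30)            # .mat inputs fall back to the given FPS
+          r.release()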
+ """ + def __init__(self, path: str): + self.path = str(path) + self._cap = None + self._mat = None + self._idx = 0 + self._len = 0 + self._shape = None + + if self.path.lower().endswith(".mat") and MAT_SUPPORT: + self._open_mat(self.path) + else: + self._open_cv(self.path) + + def _open_cv(self, path: str): + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise RuntimeError("Cannot open video") + self._cap = cap + self._len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) + + def _open_mat(self, path: str): + try: + md = loadmat(path) + # Common keys in SCAMPS-like dumps + for key in ("video", "frames", "vid", "data"): + if key in md and isinstance(md[key], np.ndarray): + arr = md[key] + break + else: + arr = next((v for v in md.values() if isinstance(v, np.ndarray)), None) + if arr is None: + raise RuntimeError("No ndarray found in .mat") + + a = np.asarray(arr) + # Normalize to (T,H,W,3) + if a.ndim == 4: + if a.shape[-1] == 3: + if a.shape[0] < a.shape[2]: # (T,H,W,3) heuristic + v = a + else: # (H,W,T,3) -> (T,H,W,3) + v = np.transpose(a, (2, 0, 1, 3)) + else: + v = a[..., :1] # take first channel + elif a.ndim == 3: + if a.shape[0] < a.shape[2]: # (T,H,W) + v = a + else: # (H,W,T) -> (T,H,W) + v = np.transpose(a, (2, 0, 1)) + v = v[..., None] + else: + raise RuntimeError(f"Unsupported .mat video shape: {a.shape}") + + v = np.ascontiguousarray(v) + if v.shape[-1] == 1: + v = np.repeat(v, 3, axis=-1) + v = v.astype(np.uint8) + self._mat = v + self._len = v.shape[0] + self._shape = v.shape[1:3] + except Exception as e: + raise RuntimeError(f"Failed to open .mat video: {e}") + + def read(self): + """Return (ret, frame[BGR]) like cv2.VideoCapture.read().""" + if self._mat is not None: + if self._idx >= self._len: + return False, None + frame = self._mat[self._idx] + self._idx += 1 + return True, frame + else: + return self._cap.read() + + def fps(self, fallback: int = 30) -> int: + if self._mat is not None: + return fallback # .mat typically lacks FPS; caller can override + f = self._cap.get(cv2.CAP_PROP_FPS) + return int(f) if f and f > 0 else fallback + + def length(self) -> int: + return self._len + + def release(self): + if self._cap is not None: + self._cap.release() + +def roi_candidates(face: Tuple[int, int, int, int], frame: np.ndarray) -> Dict[str, np.ndarray]: + x, y, w, h = face + # forehead + fh = frame[int(y + 0.10 * h):int(y + 0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)] + # cheeks + ck = frame[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)] + # full face + ff = frame[y:y + h, x:x + w] + return {"forehead": fh, "cheeks": ck, "face": ff} + +def roi_quality_score(patch: Optional[np.ndarray], fs: int = 30) -> float: + if patch is None or patch.size == 0: + return -1e9 + g = patch[..., 1].astype(np.float32) / 255.0 # green channel + g = cv2.resize(g, (64, 64)).mean(axis=1) # crude spatial pooling + g = g - g.mean() + b, a = signal.butter(4, [0.7 / (fs / 2), 3.5 / (fs / 2)], btype="band") + try: + y = signal.filtfilt(b, a, g, method="gust") + except Exception: + y = g + return float((y ** 2).mean()) + +def pick_auto_roi(face: Tuple[int, int, int, int], + frame: np.ndarray, + attn: Optional[np.ndarray] = None) -> Tuple[np.ndarray, str]: + """Simple ROI selection.""" + cands = roi_candidates(face, frame) + scores = {k: roi_quality_score(v) for k, v in cands.items()} + + if attn is not None and attn.size: + H, W = frame.shape[:2] + try: + attn_resized = cv2.resize(attn, (W, H)) + x, y, w, h = face + fh_attn = attn_resized[int(y + 0.10 * h):int(y + 
0.30 * h), int(x + 0.25 * w):int(x + 0.75 * w)].mean() if attn_resized.size > 0 else 0.0 + ck_attn = attn_resized[int(y + 0.55 * h):int(y + 0.85 * h), int(x + 0.15 * w):int(x + 0.85 * w)].mean() if attn_resized.size > 0 else 0.0 + ff_attn = attn_resized[y:y+h, x:x+w].mean() if attn_resized.size > 0 else 0.0 + scores['forehead'] += fh_attn * 0.2 + scores['cheeks'] += ck_attn * 0.2 + scores['face'] += ff_attn * 0.2 + except Exception: + pass + + best = max(scores, key=scores.get) + return cands[best], best + +def discover_subjects(root_dir: Path) -> List[Tuple[str, Optional[str]]]: + """ + Walk root_dir; for each subject folder (or single-folder dataset), return (video_path, gt_path or None). + Heuristics: + - Subject folder: any directory containing at least one video-like file (.mat or common video). + - If multiple videos, pick the largest by size. + - GT file: prefer known names in GT_FILENAMES, else any .txt/.csv/.json not named readme. + """ + pairs: List[Tuple[str, Optional[str]]] = [] + if not root_dir.exists(): + return pairs + + def pick_pair(folder: Path) -> Optional[Tuple[str, Optional[str]]]: + vids = [p for p in folder.rglob("*") if p.is_file() and _looks_like_video(p)] + if not vids: + return None + vids.sort(key=lambda p: p.stat().st_size if p.exists() else 0, reverse=True) + video = vids[0] + + gt: Optional[Path] = None + for p in folder.rglob("*"): + if p.is_file() and p.name.lower() in GT_FILENAMES: + gt = p + break + if gt is None: + cands = [ + p for p in folder.rglob("*") + if p.is_file() and p.suffix.lower() in GT_EXTS and "readme" not in p.name.lower() + ] + if cands: + gt = cands[0] + return str(video), (str(gt) if gt else None) + + subs = [d for d in root_dir.iterdir() if d.is_dir()] + if subs: + for sub in subs: + pair = pick_pair(sub) + if pair: + pairs.append(pair) + else: + pair = pick_pair(root_dir) # the root itself might be a single-subject folder + if pair: + pairs.append(pair) + + # Deduplicate + seen = set() + uniq: List[Tuple[str, Optional[str]]] = [] + for v, g in pairs: + key = (v, g or "") + if key not in seen: + seen.add(key) + uniq.append((v, g)) + return uniq + +def find_physmamba_builder(repo_root: Path, model_file: str = "", model_class: str = "PhysMamba"): + import inspect + + if model_file: + model_path = (repo_root / model_file).resolve() + if model_path.exists(): + try: + mod = _import_from_file(model_path) + if hasattr(mod, model_class): + return getattr(mod, model_class) + except Exception: + pass + + search_dirs = [ + repo_root / "neural_methods" / "model", + repo_root / "neural_methods", + repo_root + ] + + name_pattern = re.compile(r"mamba", re.IGNORECASE) + + for base_dir in search_dirs: + if not base_dir.exists(): + continue + + for py_file in base_dir.rglob("*.py"): + if "__pycache__" in str(py_file) or "mamba_ssm" in str(py_file): + continue + + try: + mod = _import_from_file(py_file) + for name, obj in inspect.getmembers(mod): + if callable(obj) and name_pattern.search(name) and "ssm" not in name.lower(): + if inspect.isclass(obj) and issubclass(obj, nn.Module): + return obj + except Exception: + continue + + raise ImportError(f"Could not find PhysMamba model class") + +def load_physmamba_model(ckpt_path: Path, device: torch.device, + model_file: str = "", model_class: str = "PhysMamba"): + + repo_root = Path(".").resolve() + Builder = find_physmamba_builder(repo_root, model_file, model_class) + + import inspect + ctor_trials = [ + {}, + {"d_model": 96}, + {"dim": 96}, + {"d_model": 96, "frames": 128, "img_size": 128, "in_chans": 3}, 
+ {"frames": 128, "img_size": 128, "in_chans": 3}, + {"frame_depth": 3}, + ] + + model = None + for kwargs in ctor_trials: + try: + candidate = Builder(**kwargs) if inspect.isclass(Builder) else Builder(**kwargs) + if isinstance(candidate, nn.Module): + model = candidate + break + except Exception: + continue + + if model is None: + raise RuntimeError("Could not construct PhysMamba model") + + try: + checkpoint = torch.load(str(ckpt_path), map_location="cpu") + state_dict = checkpoint.get("state_dict", checkpoint) + + try: + model.load_state_dict(state_dict, strict=False) + except Exception: + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + model.load_state_dict(state_dict, strict=False) + except Exception: + pass + + model.to(device).eval() + + try: + with torch.no_grad(): + _ = model(torch.zeros(1, 3, 8, 128, 128, device=device)) + except Exception: + pass + + # Disable attention visualization since model forward pass is incompatible + attention_viz = None + + return model, attention_viz + +def bandpass_filter(x: np.ndarray, fs: int = 30, low: float = 0.7, high: float = 3.5, order: int = 4) -> np.ndarray: + """ + Stable band-pass with edge-safety and parameter clipping. + """ + x = np.asarray(x, dtype=float) + n = int(fs * 2) + if x.size < max(n, 8): # need at least ~2s + return x + + nyq = 0.5 * fs + lo = max(low / nyq, 1e-6) + hi = min(high / nyq, 0.999_999) + if not (0.0 < lo < hi < 1.0): + return x + + try: + b, a = signal.butter(order, [lo, hi], btype="band") + # padlen must be < len(x); reduce when short + padlen = min(3 * max(len(a), len(b)), max(0, x.size - 1)) + return signal.filtfilt(b, a, x, padlen=padlen) + except Exception: + return x + +def hr_from_welch(x: np.ndarray, fs: int = 30, lo: float = 0.7, hi: float = 3.5) -> float: + """ + HR (BPM) via Welch PSD peak in [lo, hi] Hz. + """ + x = np.asarray(x, dtype=float) + if x.size < int(fs * 4.0): # need ~4s for a usable PSD + return 0.0 + try: + # nperseg tuned for short windows while avoiding tiny segments + nper = int(min(max(64, fs * 2), min(512, x.size))) + f, pxx = welch(x, fs=fs, window=get_window("hann", nper), nperseg=nper, detrend="constant") + + mask = (f >= lo) & (f <= hi) + if not np.any(mask): + return 0.0 + f_band = f[mask] + p_band = pxx[mask] + + if p_band.size == 0 or np.all(~np.isfinite(p_band)): + return 0.0 + + fpk = float(f_band[np.argmax(p_band)]) + bpm = fpk * 60.0 + # clip to plausible range + return float(np.clip(bpm, 30.0, 220.0)) + except Exception: + return 0.0 + +def compute_rmssd(x: np.ndarray, fs: int = 30) -> float: + """ + HRV RMSSD from peaks; robust to short/flat segments. + """ + x = np.asarray(x, dtype=float) + if x.size < int(fs * 5.0): + return 0.0 + try: + # peak distance ~ 0.5s minimum (avoid double counting) + peaks, _ = find_peaks(x, distance=max(1, int(0.5 * fs))) + if len(peaks) < 3: + return 0.0 + rr = np.diff(peaks) / fs * 1000.0 # ms + if rr.size < 2: + return 0.0 + return float(np.sqrt(np.mean(np.diff(rr) ** 2))) + except Exception: + return 0.0 + +def postprocess_bvp(pred: np.ndarray, fs: int = 30) -> Tuple[np.ndarray, float]: + """ + Filters BVP to HR band + returns smoothed HR (BPM) with gentle pull toward resting band. + Signature unchanged to avoid breaking callers. 
+ """ + global _HR_SMOOTH + + y = np.asarray(pred, dtype=float) + if y.size == 0: + return y, 0.0 + + # 1) band-limit + y_filt = bandpass_filter(y, fs=fs, low=0.7, high=3.5, order=4) + + # 2) HR estimate + hr = hr_from_welch(y_filt, fs=fs, lo=0.7, hi=3.5) + + # 3) gentle attraction to resting band (if way off) + if hr > 0: + lo, hi = REST_HR_RANGE + if hr < lo or hr > hi: + dist = abs(hr - REST_HR_TARGET) + # farther away -> stronger pull + alpha = float(np.clip(0.25 + 0.02 * dist, 0.25, 0.65)) + hr = alpha * hr + (1.0 - alpha) * REST_HR_TARGET + + # 4) temporal smoothing to limit frame-to-frame jumps + if hr > 0: + if _HR_SMOOTH is None: + _HR_SMOOTH = hr + else: + step = float(np.clip(hr - _HR_SMOOTH, -MAX_JUMP_BPM, MAX_JUMP_BPM)) + _HR_SMOOTH = _HR_SMOOTH + 0.6 * step + hr = float(_HR_SMOOTH) + + return y_filt, float(hr) + +def draw_face_and_roi(frame_bgr: np.ndarray, + face_bbox: Optional[Tuple[int, int, int, int]], + roi_bbox: Optional[Tuple[int, int, int, int]], + label: str = "ROI") -> np.ndarray: + """ + Draw face (green) and ROI (cyan) rectangles on a copy of the frame. + """ + vis = frame_bgr.copy() + if face_bbox is not None: + x, y, w, h = face_bbox + cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 230, 0), 2) + cv2.putText(vis, "FACE", (x, max(20, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 230, 0), 2) + if roi_bbox is not None: + rx, ry, rw, rh = roi_bbox + cv2.rectangle(vis, (rx, ry), (rx + rw, ry + rh), (255, 220, 0), 2) + cv2.putText(vis, label, (rx, max(20, ry - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 220, 0), 2) + return vis + +def roi_bbox_from_face(face_bbox: Tuple[int, int, int, int], + roi_type: str, + frame_shape: Tuple[int, int, int]) -> Tuple[int, int, int, int]: + """ + Compute the ROI rectangle (x,y,w,h) from a face bbox and your ROI rule. + Matches your crop_roi geometry. + """ + x, y, w, h = face_bbox + H, W = frame_shape[:2] + if roi_type == "forehead": + rx = int(x + 0.25 * w); rw = int(0.5 * w) + ry = int(y + 0.10 * h); rh = int(0.20 * h) + elif roi_type == "cheeks": + rx = int(x + 0.15 * w); rw = int(0.70 * w) + ry = int(y + 0.55 * h); rh = int(0.30 * h) + else: + rx, ry, rw, rh = x, y, w, h + + rx2 = min(W, rx + rw) + ry2 = min(H, ry + rh) + rx = max(0, rx); ry = max(0, ry) + if rx2 <= rx or ry2 <= ry: + return (0, 0, 0, 0) + return (rx, ry, rx2 - rx, ry2 - ry) + +def render_preprocessed_roi(chw: np.ndarray) -> np.ndarray: + """ + Visualize the model input (C,H,W, normalized). Returns HxWx3 uint8 BGR. + Assumes chw = (3, H, W) with zero-mean, unit-var normalization per-frame. 
+ """ + if chw is None or chw.ndim != 3 or chw.shape[0] != 3: + return np.zeros((128, 128, 3), dtype=np.uint8) + + # Undo channel-first & normalization to a viewable image + img = chw.copy() + # Re-normalize to 0..1 by min-max of the tensor to "show" contrast + vmin, vmax = float(img.min()), float(img.max()) + if vmax <= vmin + 1e-6: + img = np.zeros_like(img) + else: + img = (img - vmin) / (vmax - vmin) + + img = (img.transpose(1, 2, 0)[:, :, ::-1] * 255.0).clip(0, 255).astype(np.uint8) # RGB->BGR + return img + +def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]: + if gt_len <= 1: + return None + if gt_fs and gt_fs > 0: + return np.arange(gt_len, dtype=float) / float(gt_fs) + return None # will fall back to length-matching overlay + +def plot_signals_with_gt(time_axis: np.ndarray, + raw_signal: np.ndarray, + post_signal: np.ndarray, + fs: int, + out_path: str, + gt_time: Optional[np.ndarray] = None, + gt_bvp: Optional[np.ndarray] = None, + title: str = "rPPG Signals (Pred vs GT)") -> str: + """ + Save a 3-pane plot: (1) predicted raw, (2) predicted post, (3) overlay Pred vs GT (normalized). + If GT is provided, it is resampled to the prediction time grid and lag-aligned + (±5 s) using cross-correlation. The overlay includes Pearson r, lag, and HR stats. + """ + import numpy as _np + import matplotlib.pyplot as _plt + from matplotlib.gridspec import GridSpec as _GridSpec + + def z(x): + x = _np.asarray(x, dtype=float) + if x.size == 0: + return x + m = float(_np.nanmean(x)) + s = float(_np.nanstd(x)) + 1e-8 + return (x - m) / s + + def _bandpass(x, fs_local, lo=0.7, hi=3.5, order=4): + try: + return bandpass_filter(_np.asarray(x, float), fs=fs_local, low=lo, high=hi, order=order) + except Exception: + return _np.asarray(x, float) + + def _welch_hr(x, fs_local): + try: + return float(hr_from_welch(_np.asarray(x, float), fs=fs_local, lo=0.7, hi=3.5)) + except Exception: + return 0.0 + + def _safe_interp(x_t, y, t_new): + """Monotonic-safe 1D interpolation with clipping to valid domain.""" + x_t = _np.asarray(x_t, dtype=float).ravel() + y = _np.asarray(y, dtype=float).ravel() + t_new = _np.asarray(t_new, dtype=float).ravel() + + if x_t.size < 2 or y.size != x_t.size: + # Fallback: length-based resize to t_new length + if y.size == 0 or t_new.size == 0: + return _np.zeros_like(t_new) + idx = _np.linspace(0, y.size - 1, num=t_new.size) + return _np.interp(_np.arange(t_new.size), idx, y) + + # Enforce strictly increasing time (dedup if needed) + order = _np.argsort(x_t) + x_t = x_t[order] + y = y[order] + mask = _np.concatenate(([True], _np.diff(x_t) > 0)) + x_t = x_t[mask] + y = y[mask] + # Clip t_new to the valid domain to avoid edge extrapolation artifacts + t_clip = _np.clip(t_new, x_t[0], x_t[-1]) + return _np.interp(t_clip, x_t, y) + + def _best_lag(pred, gt, fs_local, max_lag_s=5.0): + """Return lag (sec) that best aligns GT to Pred using cross-correlation on z-scored signals.""" + x = z(pred); y = z(gt) + if x.size < 8 or y.size < 8: + return 0.0 + n = int(min(len(x), len(y))) + x = x[:n]; y = y[:n] + max_lag = int(max(1, min(n - 1, round(max_lag_s * fs_local)))) + # valid lags: negative means GT should be shifted left (advance) relative to Pred + lags = _np.arange(-max_lag, max_lag + 1) + # compute correlation for each lag + best_corr = -_np.inf + best_lag = 0 + for L in lags: + if L < 0: + xx = x[-L:n] + yy = y[0:n+L] + elif L > 0: + xx = x[0:n-L] + yy = y[L:n] + else: + xx = x + yy = y + if xx.size < 8 or yy.size < 8: + continue + c = _np.corrcoef(xx, yy)[0, 1] + if 
_np.isfinite(c) and c > best_corr: + best_corr = c + best_lag = L + return float(best_lag / float(fs_local)) + + def _apply_lag(y, lag_sec, fs_local): + """Shift y by lag_sec (positive => delay GT) using sample roll; edges set to NaN.""" + y = _np.asarray(y, float) + if y.size == 0 or fs_local <= 0: + return y + shift = int(round(lag_sec * fs_local)) + if shift == 0: + return y + out = _np.empty_like(y) + out[:] = _np.nan + if shift > 0: + # delay: move content right + out[shift:] = y[:-shift] + else: + # advance: move content left + out[:shift] = y[-shift:] + return out + + t = _np.asarray(time_axis, dtype=float) + raw = _np.asarray(raw_signal, dtype=float) + post = _np.asarray(post_signal, dtype=float) + + # guard + if t.size == 0: + t = _np.arange(post.size, dtype=float) / max(fs, 1) + + have_gt = gt_bvp is not None and _np.asarray(gt_bvp).size > 0 + gt_on_pred = None + lag_sec = 0.0 + pearson_r = _np.nan + hr_pred = _welch_hr(_bandpass(post, fs), fs) + hr_gt = 0.0 + + if have_gt: + gt = _np.asarray(gt_bvp, dtype=float).ravel() + if gt_time is not None and _np.asarray(gt_time).size == gt.size: + gt_t = _np.asarray(gt_time, dtype=float).ravel() + gt_on_pred = _safe_interp(gt_t, gt, t) + else: + # No time vector: try length-based mapping to pred time grid + gt_on_pred = _safe_interp(_np.linspace(0, t[-1] if t.size else (gt.size - 1) / max(fs, 1), num=gt.size), + gt, t) + + # Band-limit both before correlation/HR + pred_bp = _bandpass(post, fs) + gt_bp = _bandpass(gt_on_pred, fs) + + # Estimate best lag (sec) of GT relative to Pred + lag_sec = _best_lag(pred_bp, gt_bp, fs_local=fs, max_lag_s=5.0) + + # Apply lag to GT for visualization and correlation + gt_aligned = _apply_lag(gt_on_pred, lag_sec, fs_local=fs) + + # Compute Pearson r on overlapping valid samples + valid = _np.isfinite(gt_aligned) & _np.isfinite(pred_bp) + if valid.sum() >= 16: + pearson_r = float(_np.corrcoef(z(pred_bp[valid]), z(gt_aligned[valid]))[0, 1]) + else: + pearson_r = _np.nan + + hr_gt = _welch_hr(gt_bp[_np.isfinite(gt_bp)], fs) + + + _plt.figure(figsize=(13, 6), dpi=110) + gs = _GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], wspace=0.25, hspace=0.35) + + # (1) Raw Pred + ax1 = _plt.subplot(gs[0, 0]) + ax1.plot(t, raw - (raw.mean() if raw.size else 0.0), linewidth=1.5) + ax1.set_title(f"Predicted (Raw) — fs={fs} Hz") + ax1.set_xlabel("Time (s)"); ax1.set_ylabel("Amplitude") + ax1.grid(True, alpha=0.3) + + # (2) Post Pred + ax2 = _plt.subplot(gs[0, 1]) + ax2.plot(t, post - (post.mean() if post.size else 0.0), linewidth=1.5) + ax2.set_title("Predicted (Post-processed)") + ax2.set_xlabel("Time (s)"); ax2.set_ylabel("Amplitude") + ax2.grid(True, alpha=0.3) + + # (3) Overlay Pred vs GT (z-scored) OR just post + ax3 = _plt.subplot(gs[1, :]) + ax3.plot(t, z(post), label="Pred (post)", linewidth=1.6) + + if have_gt and gt_on_pred is not None: + gt_bp = _bandpass(gt_on_pred, fs) + gt_aligned = _apply_lag(gt_bp, lag_sec, fs_local=fs) + ax3.plot(t, z(gt_aligned), label=f"GT (aligned {lag_sec:+.2f}s)", linewidth=1.2, alpha=0.9) + + # metrics box + txt = [ + f"HR_pred: {hr_pred:.1f} BPM", + f"HR_gt: {hr_gt:.1f} BPM", + f"Pearson r: {pearson_r:.3f}" if _np.isfinite(pearson_r) else "Pearson r: --", + f"Lag: {lag_sec:+.2f} s" + ] + ax3.text(0.01, 0.98, "\n".join(txt), transform=ax3.transAxes, + va="top", ha="left", fontsize=9, + bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="0.8", alpha=0.9)) + ax3.set_title("Pred vs GT (z-score overlay, lag-aligned)") + else: + ax3.set_title("Pred vs GT (no GT provided)") + + 
ax3.set_xlabel("Time (s)"); ax3.set_ylabel("z") + ax3.grid(True, alpha=0.3) + ax3.legend(loc="upper right") + + _plt.suptitle(title, fontweight="bold") + _plt.tight_layout(rect=[0, 0.02, 1, 0.97]) + _plt.savefig(out_path, bbox_inches="tight") + _plt.close() + return out_path + +def detect_face(frame: np.ndarray) -> Optional[Tuple[int, int, int, int]]: + """ + Robust single-face detector with a few practical guards: + - converts to gray safely + - equalizes histogram (helps underexposure) + - tries multiple scales / minNeighbors + - returns the largest face bbox (x,y,w,h) or None + """ + if frame is None or frame.size == 0: + return None + + try: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + except Exception: + # If color conversion fails, assume already gray + gray = frame.copy() if frame.ndim == 2 else cv2.cvtColor(frame[..., :3], cv2.COLOR_BGR2GRAY) + + # Light preproc to improve Haar performance + gray = cv2.equalizeHist(gray) + + faces_all = [] + # Try a couple of parameter combos to be more forgiving + params = [ + dict(scaleFactor=1.05, minNeighbors=3), + dict(scaleFactor=1.10, minNeighbors=4), + dict(scaleFactor=1.20, minNeighbors=5), + ] + for p in params: + try: + faces = FACE_CASCADE.detectMultiScale(gray, **p) + if faces is not None and len(faces) > 0: + faces_all.extend([tuple(map(int, f)) for f in faces]) + except Exception: + continue + + if not faces_all: + return None + + # Return the largest (by area) + return max(faces_all, key=lambda f: f[2] * f[3]) + +def crop_roi(face_bbox: Tuple[int, int, int, int], roi_type: str, frame: np.ndarray) -> Optional[np.ndarray]: + """ + Crop ROI from the frame based on a face bbox and the selected roi_type. + Returns the cropped BGR ROI or None if invalid. + """ + if face_bbox is None or frame is None or frame.size == 0: + return None + + x, y, w, h = map(int, face_bbox) + H, W = frame.shape[:2] + + if roi_type == "forehead": + rx = int(x + 0.25 * w); rw = int(0.50 * w) + ry = int(y + 0.10 * h); rh = int(0.20 * h) + elif roi_type == "cheeks": + rx = int(x + 0.15 * w); rw = int(0.70 * w) + ry = int(y + 0.55 * h); rh = int(0.30 * h) + else: + rx, ry, rw, rh = x, y, w, h + + # clamp in-bounds + rx = max(0, rx); ry = max(0, ry) + rx2 = min(W, rx + rw); ry2 = min(H, ry + rh) + + if rx2 <= rx or ry2 <= ry: + return None + + roi = frame[ry:ry2, rx:rx2] + # Avoid empty or 1-pixel slivers + if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4: + return None + return roi + +def crop_roi_with_bbox(face_bbox: Tuple[int, int, int, int], + roi_type: str, + frame: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int,int,int]]]: + if face_bbox is None or frame is None or frame.size == 0: + return None, None + + x, y, w, h = map(int, face_bbox) + H, W = frame.shape[:2] + + if roi_type == "forehead": + rx = int(x + 0.25 * w); rw = int(0.50 * w) + ry = int(y + 0.10 * h); rh = int(0.20 * h) + elif roi_type == "cheeks": + rx = int(x + 0.15 * w); rw = int(0.70 * w) + ry = int(y + 0.55 * h); rh = int(0.30 * h) + else: + rx, ry, rw, rh = x, y, w, h + + rx = max(0, rx); ry = max(0, ry) + rx2 = min(W, rx + rw); ry2 = min(H, ry + rh) + if rx2 <= rx or ry2 <= ry: + return None, None + + roi = frame[ry:ry2, rx:rx2] + if roi.size == 0 or roi.shape[0] < 4 or roi.shape[1] < 4: + return None, None + + return roi, (rx, ry, rx2 - rx, ry2 - ry) + +def normalize_frame(face_bgr: np.ndarray, size: int) -> np.ndarray: + """ + PhysMamba-compatible normalization with DiffNormalized support. + Returns (3, size, size). 
+ """ + if face_bgr is None or face_bgr.size == 0: + return np.zeros((3, size, size), dtype=np.float32) + + try: + face = cv2.resize(face_bgr, (size, size), interpolation=cv2.INTER_AREA) + except Exception: + face = cv2.resize(face_bgr, (size, size)) + + face = face.astype(np.float32) / 255.0 + + # Per-frame standardization + mean = face.mean(axis=(0, 1), keepdims=True) + std = face.std(axis=(0, 1), keepdims=True) + 1e-6 + face = (face - mean) / std + + # BGR -> RGB and HWC -> CHW + chw = face[..., ::-1].transpose(2, 0, 1).astype(np.float32, copy=False) + return chw + +def extract_attention_map(model, clip_tensor: torch.Tensor, + attention_viz) -> Optional[np.ndarray]: + """Attention visualization disabled - model architecture incompatible.""" + return None + +def create_attention_overlay(frame: np.ndarray, attention: Optional[np.ndarray], + attention_viz: Optional[SimpleActivationAttention] = None) -> np.ndarray: + """Create attention heatmap overlay.""" + # Attention disabled - return original frame + return frame + +def occlusion_saliency(roi_bgr, model, fs, patch=16, stride=12): + H, W = roi_bgr.shape[:2] + base_bvp = forward_bvp(model, torch.from_numpy(normalize_frame(roi_bgr, DEFAULT_SIZE)) + .unsqueeze(0).unsqueeze(2).to(DEVICE)) # fake T=1 path if needed + base_power = hr_from_welch(bandpass_filter(base_bvp, fs=fs), fs=fs) + + heat = np.zeros((H, W), np.float32) + for y in range(0, H - patch + 1, stride): + for x in range(0, W - patch + 1, stride): + tmp = roi_bgr.copy() + tmp[y:y+patch, x:x+patch] = 127 # occlude + bvp = forward_bvp(model, torch.from_numpy(normalize_frame(tmp, DEFAULT_SIZE)) + .unsqueeze(0).unsqueeze(2).to(DEVICE)) + power = hr_from_welch(bandpass_filter(bvp, fs=fs), fs=fs) + drop = max(0.0, base_power - power) + heat[y:y+patch, x:x+patch] += drop + heat -= heat.min() + if heat.max() > 1e-8: heat /= heat.max() + return heat + +def _call_model_try_orders(model: nn.Module, clip_tensor: torch.Tensor): + """ + Try common 5D layouts: + [B, C, T, H, W] then [B, T, C, H, W]. + """ + last_err = None + try: + return model(clip_tensor) + except Exception as e: + last_err = e + try: + return model(clip_tensor.permute(0, 2, 1, 3, 4).contiguous()) + except Exception as e: + last_err = e + raise last_err + +def forward_bvp(model: nn.Module, clip_tensor: torch.Tensor) -> np.ndarray: + """ + Forward and extract a 1D time-like BVP vector with length T_clip. + Tolerant to dict/list/tuple heads and odd shapes. 
+ """ + T_clip = int(clip_tensor.shape[2]) # intended time length for [B,C,T,H,W] + with torch.no_grad(): + out = _call_model_try_orders(model, clip_tensor) + + # unwrap common containers + if isinstance(out, dict): + for key in ("bvp", "ppg", "signal", "pred", "y", "out", "logits"): + if key in out and isinstance(out[key], torch.Tensor): + out = out[key] + break + + if isinstance(out, (list, tuple)): + tensors = [t for t in out if isinstance(t, torch.Tensor)] + if not tensors: + return np.zeros(T_clip, dtype=np.float32) + + def score(t: torch.Tensor): + has_T = 1 if T_clip in t.shape else 0 + return (has_T, t.numel()) + + out = max(tensors, key=score) + + if not isinstance(out, torch.Tensor): + return np.zeros(T_clip, dtype=np.float32) + + out = out.detach().cpu().float() + + # ---- 1D + if out.ndim == 1: + v = out + if v.shape[0] == T_clip: + return v.numpy() + if v.numel() == 1: + return np.full(T_clip, float(v.item()), dtype=np.float32) + return np.resize(v.numpy(), T_clip).astype(np.float32) + + # ---- 2D + if out.ndim == 2: + B, K = out.shape + if B == 1: + v = out[0] + return (v.numpy() if v.shape[0] == T_clip else np.resize(v.numpy(), T_clip).astype(np.float32)) + if B == T_clip: + return out[:, 0].numpy() + if K == T_clip: + return out[0, :].numpy() + return np.resize(out.flatten().numpy(), T_clip).astype(np.float32) + + # ---- 3D + if out.ndim == 3: + B, D1, D2 = out.shape[0], out.shape[1], out.shape[2] + if D1 == T_clip: # [B, T, C] + return out[0, :, 0].numpy() + if D2 == T_clip: # [B, C, T] + return out[0, 0, :].numpy() + v = out.mean(dim=tuple(range(1, out.ndim))).squeeze(0) + return np.resize(v.numpy(), T_clip).astype(np.float32) + + # ---- 4D + if out.ndim == 4: + B, A, H, W = out.shape + if A == T_clip: # [B, T, H, W] + return out[0].mean(dim=(1, 2)).numpy() + v = out[0].mean(dim=(1, 2)) + return np.resize(v.numpy(), T_clip).astype(np.float32) + + # ---- 5D+ + if out.ndim >= 5: + shape = list(out.shape) + try: + t_idx = next(i for i, s in enumerate(shape) if (i != 0 and s == T_clip)) + except StopIteration: + pooled = out[0].mean(dim=tuple(i for i in range(1, out.ndim) if i not in (-1,))) + v = pooled.flatten() + return np.resize(v.numpy(), T_clip).astype(np.float32) + + axes = list(range(out.ndim)) + perm = [0, t_idx] + [i for i in axes[1:] if i != t_idx] + o2 = out.permute(*perm) # [B, T, ...] + pooled = o2.mean(dim=tuple(range(2, o2.ndim))) # -> [B, T] + return pooled[0].numpy() + + # fallback: constant vector with mean value + val = float(out.mean().item()) if out.numel() else 0.0 + return np.full(T_clip, val, dtype=np.float32) + +def _fallback_bvp_from_means(means, fs: int) -> np.ndarray: + """ + Classical rPPG from green-channel means when the model yields nothing. + Detrend -> bandpass -> z-normalize. + """ + if means is None: + return np.array([], dtype=np.float32) + + x = np.asarray(means, dtype=np.float32) + if x.size == 0: + return np.array([], dtype=np.float32) + + x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + + try: + x = signal.detrend(x, type="linear") + except Exception: + pass + + y = bandpass_filter(x, fs=fs, low=0.7, high=3.5, order=4) + + std = float(np.std(y)) + 1e-6 + return (y / std).astype(np.float32) + +def _to_floats(s: str) -> List[float]: + """ + Extract all real numbers from free-form text, including scientific notation. + Gracefully ignores 'nan', 'inf', units, and comments. 
+ """ + if not isinstance(s, str) or not s: + return [] + + s = re.sub(r"(#|//|;).*?$", "", s, flags=re.MULTILINE) + + s = s.replace(",", " ").replace(";", " ") + + toks = re.findall(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?", s) + out: List[float] = [] + for t in toks: + try: + v = float(t) + if np.isfinite(v): + out.append(v) + except Exception: + continue + return out + +def parse_ground_truth_file(gt_path: str) -> Tuple[np.ndarray, float, float]: + """ + Parse ground-truth files from: + • UBFC/UBFC-rPPG style TXT: 3 lines => PPG, HR, timestep(s) + • Generic TXT: one column (or free-form) numeric sequence + • JSON: keys like 'ppg' / 'bvp' (optionally nested), 'hr', 'fs' + • CSV: columns named BVP/PPG/Signal + optional HR + optional Time + • MAT: array under keys ['ppg','bvp','signal','wave']; optionally 'fs'/'hr' + • NPY: np.load() 1-D array; optionally sidecar .json with fs/hr (same stem) + + Returns: + bvp : np.ndarray (float) — may be empty + hr : float (mean HR if available or estimated) + fs_hint : float — sampling rate if derivable (0.0 if unknown) + """ + if not gt_path or not os.path.exists(gt_path): + return np.array([]), 0.0, 0.0 + + p = Path(gt_path) + ext = p.suffix.lower() + + def _fs_from_time_vector(tv: np.ndarray) -> float: + tv = np.asarray(tv, dtype=float) + if tv.ndim != 1 or tv.size < 2: + return 0.0 + diffs = np.diff(tv) + diffs = diffs[np.isfinite(diffs) & (diffs > 0)] + return (1.0 / float(np.median(diffs))) if diffs.size else 0.0 + + + def _hr_from_bvp(bvp: np.ndarray, fs_hint: float) -> float: + if bvp is None or bvp.size == 0: + return 0.0 + fs_use = fs_hint if (fs_hint and fs_hint > 0) else 30.0 + bp = bandpass_filter(bvp.astype(float), fs=fs_use) + return hr_from_welch(bp, fs=fs_use) + + if p.name.lower() == "ground_truth.txt" or (ext == ".txt" and p.read_text(errors="ignore").count("\n") >= 2): + try: + lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines() if ln.strip()] + ppg_vals = _to_floats(lines[0]) if len(lines) >= 1 else [] + hr_vals = _to_floats(lines[1]) if len(lines) >= 2 else [] + t_vals = _to_floats(lines[2]) if len(lines) >= 3 else [] + + bvp = np.asarray(ppg_vals, dtype=float) if ppg_vals else np.array([], dtype=float) + + hr = float(np.nanmean(hr_vals)) if hr_vals else 0.0 + fs_hint = 0.0 + if t_vals: + if len(t_vals) == 1: + dt = float(t_vals[0]) + if dt > 0: + fs_hint = 1.0 / dt + else: + fs_hint = _fs_from_time_vector(np.asarray(t_vals, dtype=float)) + + if hr == 0.0 and bvp.size: + hr = _hr_from_bvp(bvp, fs_hint) + return bvp, hr, fs_hint + except Exception: + # Fall through to generic handlers + pass + + + if ext == ".txt": + try: + nums = _to_floats(p.read_text(encoding="utf-8", errors="ignore")) + bvp = np.asarray(nums, dtype=float) if nums else np.array([], dtype=float) + hr = _hr_from_bvp(bvp, fs_hint=0.0) if bvp.size else 0.0 + return bvp, hr, 0.0 + except Exception: + return np.array([]), 0.0, 0.0 + + + if ext == ".json": + try: + data = json.loads(p.read_text(encoding="utf-8", errors="ignore")) + # Try several paths for BVP array + bvp = None + + def _seek(obj, keys): + for k in keys: + if isinstance(obj, dict) and k in obj: + return obj[k] + return None + + # Direct top-level + bvp = _seek(data, ("ppg", "bvp", "signal", "wave")) + # Common nested containers + if bvp is None: + for container_key in ("FullPackage", "package", "data", "gt", "ground_truth"): + if container_key in data: + cand = _seek(data[container_key], ("ppg", "bvp", "signal", "wave")) + if cand is not None: + bvp = cand + 
break + + if bvp is not None: + bvp = np.asarray(bvp, dtype=float).ravel() + else: + bvp = np.array([], dtype=float) + + # fs / hr (accept scalar or array) + fs_hint = 0.0 + if "fs" in data and isinstance(data["fs"], (int, float)) and data["fs"] > 0: + fs_hint = float(data["fs"]) + + hr = 0.0 + if "hr" in data: + v = data["hr"] + hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v) + + if hr == 0.0 and bvp.size: + hr = _hr_from_bvp(bvp, fs_hint) + return bvp, hr, fs_hint + except Exception: + return np.array([]), 0.0, 0.0 + + + if ext == ".csv": + try: + df = pd.read_csv(p) + # Normalize column names + cols = {str(c).strip().lower(): c for c in df.columns} + + def _first_match(names): + for nm in names: + if nm in cols: + return cols[nm] + return None + + bvp_col = _first_match(["bvp", "ppg", "wave", "signal", "bvp_signal", "ppg_signal"]) + hr_col = _first_match(["hr", "heart_rate", "hr_bpm", "bpm"]) + t_col = _first_match(["time", "t", "timestamp", "sec", "seconds", "time_s"]) + + bvp = np.asarray(df[bvp_col].values, dtype=float) if bvp_col else np.array([], dtype=float) + + fs_hint = 0.0 + if t_col is not None and len(df[t_col].values) >= 2: + fs_hint = _fs_from_time_vector(np.asarray(df[t_col].values, dtype=float)) + + hr = float(np.nanmean(df[hr_col].values)) if (hr_col and df[hr_col].notna().any()) else 0.0 + if hr == 0.0 and bvp.size: + hr = _hr_from_bvp(bvp, fs_hint) + return bvp, hr, fs_hint + except Exception: + return np.array([]), 0.0, 0.0 + + + if ext == ".mat": + try: + md = loadmat(str(p)) + # look for most likely array + arr = None + for key in ("ppg", "bvp", "signal", "wave"): + if key in md and isinstance(md[key], np.ndarray): + arr = md[key] + break + if arr is None: + # fallback: first 1-D array + for v in md.values(): + if isinstance(v, np.ndarray) and v.ndim == 1: + arr = v + break + bvp = np.asarray(arr, dtype=float).ravel() if arr is not None else np.array([], dtype=float) + + fs_hint = 0.0 + for k in ("fs", "Fs", "sampling_rate", "sr"): + if k in md: + try: + fs_hint = float(np.ravel(md[k])[0]) + break + except Exception: + pass + + hr = 0.0 + if "hr" in md: + try: + hr = float(np.nanmean(np.ravel(md["hr"]))) + except Exception: + hr = 0.0 + + if hr == 0.0 and bvp.size: + hr = _hr_from_bvp(bvp, fs_hint) + return bvp, hr, fs_hint + except Exception: + return np.array([]), 0.0, 0.0 + + # ================= NPY ================= + if ext == ".npy": + try: + bvp = np.asarray(np.load(str(p)), dtype=float).ravel() + fs_hint, hr = 0.0, 0.0 + # optional sidecar JSON (same stem) with fs/hr + sidecar = p.with_suffix(".json") + if sidecar.exists(): + try: + meta = json.loads(sidecar.read_text(encoding="utf-8", errors="ignore")) + if isinstance(meta.get("fs", None), (int, float)) and meta["fs"] > 0: + fs_hint = float(meta["fs"]) + if "hr" in meta: + v = meta["hr"] + hr = float(np.nanmean(v)) if isinstance(v, (list, tuple, np.ndarray)) else float(v) + except Exception: + pass + if hr == 0.0 and bvp.size: + hr = _hr_from_bvp(bvp, fs_hint) + return bvp, hr, fs_hint + except Exception: + return np.array([]), 0.0, 0.0 + + # Fallback (unsupported extension) + return np.array([]), 0.0, 0.0 + +def scan_models() -> List[str]: + if not MODEL_DIR.exists(): + return [] + + models = [] + for f in sorted(MODEL_DIR.iterdir()): + if f.suffix.lower() == '.pth': + models.append(f.name) + + return models + +_GLOBAL_CONTROLS: Dict[str, Dict] = {} + +def ensure_controls(control_id: str) -> Tuple[str, Dict]: + # Use a stable default so Pause/Resume/Stop work for the 
current run + if not control_id: + control_id = "default-session" + if control_id not in _GLOBAL_CONTROLS: + _GLOBAL_CONTROLS[control_id] = { + 'pause': threading.Event(), + 'stop': threading.Event() + } + return control_id, _GLOBAL_CONTROLS[control_id] + +def process_video_file( + video_path: str, + gt_file: Optional[str], + model_name: str, + fps_input: int, + max_seconds: int, + roi_type: str, + control_id: str +): + """ + Enhanced video processing with Grad-CAM attention visualization. + """ + global _HR_SMOOTH + _HR_SMOOTH = None + + def _gt_time_axis(gt_len: int, gt_fs: float) -> Optional[np.ndarray]: + if gt_len <= 1: + return None + if gt_fs and gt_fs > 0: + return np.arange(gt_len, dtype=float) / float(gt_fs) + return None + + control_id, controls = ensure_controls(control_id) + controls['stop'].clear() + + if not model_name: + yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None) + return + + if isinstance(model_name, int): + model_name = str(model_name) + + model_path = MODEL_DIR / model_name + if not model_path.exists(): + yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None) + return + + try: + model, attention_viz = load_physmamba_model(model_path, DEVICE) + except Exception as e: + yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None) + return + + + gt_bvp, gt_hr, gt_fs = parse_ground_truth_file(gt_file) if gt_file else (np.array([]), 0.0, 0.0) + + if not video_path or not os.path.exists(video_path): + yield ("ERROR: Video not found", None, None, None, None, None, None, None, None, None) + return + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + yield ("ERROR: Cannot open video", None, None, None, None, None, None, None, None, None) + return + + fps = int(fps_input) if fps_input else int(cap.get(cv2.CAP_PROP_FPS) or 30) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) + max_frames = int(max_seconds * fps) if max_seconds and max_seconds > 0 else total_frames + MAX_SIGNAL_LENGTH = max_frames if max_frames > 0 else fps * 600 + + frames_chw = deque(maxlen=DEFAULT_T) + raw_g_means = deque(maxlen=DEFAULT_T) + bvp_stream: List[float] = [] + frame_idx = 0 + last_infer = -1 + last_bpm = 0.0 + last_rmssd = 0.0 + last_attention = None + + start_time = time.time() + next_display = time.time() + + tmpdir = Path(tempfile.gettempdir()) + frame_path = tmpdir / "frame.jpg" + attention_path = tmpdir / "attention.jpg" + signal_path = tmpdir / "signal.png" + raw_path = tmpdir / "raw_signal.png" + post_path = tmpdir / "post_signal.png" + + yield ("Starting… reading video frames", None, f"{gt_hr:.1f}" if gt_hr > 0 else "--", + None, None, None, None, None, None, None) + + while True: + if controls['stop'].is_set(): + break + + while controls['pause'].is_set(): + time.sleep(0.2) + if controls['stop'].is_set(): + break + + ret, frame = cap.read() + if not ret or (max_frames > 0 and frame_idx >= max_frames): + break + + frame_idx += 1 + + face = detect_face(frame) + vis_frame = frame.copy() + + if face is not None: + x, y, w, h = face + cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 2) + + if roi_type == "auto": + roi, _ = pick_auto_roi(face, frame, attn=last_attention) + else: + roi = crop_roi(face, roi_type, frame) + + if roi is not None and roi.size > 0: + try: + g_mean = float(roi[..., 1].astype(np.float32).mean()) + raw_g_means.append(g_mean) + except Exception: + pass + + face_norm = normalize_frame(roi, DEFAULT_SIZE) + 
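+                    # clip buffer: inference fires once DEFAULT_T frames are queued and at least DEFAULT_STRIDE frames have passed since the last run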
frames_chw.append(face_norm) + + if len(frames_chw) == DEFAULT_T and (frame_idx - last_infer) >= DEFAULT_STRIDE: + try: + clip = np.stack(list(frames_chw), axis=1).astype(np.float32) + except Exception as e: + print(f"[infer] clip stack failed: {e}") + clip = None + + bvp_out = None + if clip is not None: + clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE) + + try: + raw = forward_bvp(model, clip_t) + if isinstance(raw, np.ndarray): + raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False) + bvp_out = raw if raw.size > 0 else None + else: + bvp_out = None + except Exception as e: + print(f"[infer] forward_bvp error: {e}") + bvp_out = None + + # UPDATED: Generate attention with Grad-CAM + try: + last_attention = extract_attention_map(model, clip_t, attention_viz) + except Exception as e: + print(f"⚠ Attention generation error: {e}") + last_attention = None + + if bvp_out is None or bvp_out.size == 0: + gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0) + fb = _fallback_bvp_from_means(gbuf, fs=fps) + if isinstance(fb, np.ndarray) and fb.size > 0: + bvp_out = fb + else: + print("[infer] fallback produced empty output") + + if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0: + tail = min(DEFAULT_STRIDE, bvp_out.size) + bvp_stream.extend(bvp_out[-tail:].tolist()) + if len(bvp_stream) > MAX_SIGNAL_LENGTH: + bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:] + + if len(bvp_stream) >= int(5 * fps): + seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32) + _, last_bpm = postprocess_bvp(seg, fs=fps) + last_rmssd = compute_rmssd(seg, fs=fps) + + if frame_idx % (DEFAULT_STRIDE * 2) == 0: + print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}") + else: + print("[infer] no usable bvp_out after fallback") + + last_infer = frame_idx + + else: + cv2.putText(vis_frame, "No face detected", (20, 40), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2) + + if last_bpm > 0: + color = (0, 255, 0) if 55 <= last_bpm <= 100 else (0, 165, 255) + cv2.rectangle(vis_frame, (10, 10), (360, 65), (0, 0, 0), -1) + cv2.putText(vis_frame, f"HR: {last_bpm:.1f} BPM", (20, 48), + cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2) + + vis_attention = create_attention_overlay(frame, last_attention, attention_viz) + + cv2.imwrite(str(frame_path), vis_frame) + cv2.imwrite(str(attention_path), vis_attention) + + now = time.time() + if now >= next_display: + if len(bvp_stream) >= max(10, int(1 * fps)): + try: + signal_array = np.array(bvp_stream, dtype=float) + time_axis = np.arange(len(signal_array)) / fps + + raw_signal = bandpass_filter(signal_array, fs=fps) + post_signal, _ = postprocess_bvp(signal_array, fs=fps) + + raw_vis = raw_signal - np.mean(raw_signal) + post_vis = post_signal - np.mean(post_signal) + + plt.figure(figsize=(10, 4), dpi=100) + plt.plot(time_axis, raw_vis, linewidth=1.6) + plt.xlabel('Time (s)'); plt.ylabel('Amplitude') + plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM') + plt.grid(True, alpha=0.3) + plt.tight_layout(); plt.savefig(str(raw_path), dpi=100, bbox_inches='tight'); plt.close() + + plt.figure(figsize=(10, 4), dpi=100) + plt.plot(time_axis, post_vis, linewidth=1.6) + plt.xlabel('Time (s)'); plt.ylabel('Amplitude') + plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM') + plt.grid(True, alpha=0.3) + plt.tight_layout(); plt.savefig(str(post_path), dpi=100, bbox_inches='tight'); plt.close() + + if gt_bvp is not None and gt_bvp.size > 0: + try: + gt_time = _gt_time_axis(len(gt_bvp), gt_fs) + plot_signals_with_gt( + 
time_axis=time_axis, + raw_signal=raw_signal, + post_signal=post_signal, + fs=fps, + out_path=str(signal_path), + gt_time=gt_time, + gt_bvp=gt_bvp, + title=f"Pred vs GT — HR: {last_bpm:.1f} BPM" + ) + except Exception: + fig = plt.figure(figsize=(12, 5), dpi=100) + gs = GridSpec(1, 2, figure=fig, wspace=0.3) + ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6) + ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3) + ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6) + ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3) + plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold') + plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all') + else: + fig = plt.figure(figsize=(12, 5), dpi=100) + gs = GridSpec(1, 2, figure=fig, wspace=0.3) + ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(time_axis, raw_vis, linewidth=1.6) + ax1.set_title('Raw Signal'); ax1.set_xlabel('Time (s)'); ax1.set_ylabel('Amplitude'); ax1.grid(True, alpha=0.3) + ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(time_axis, post_vis, linewidth=1.6) + ax2.set_title('Post-processed Signal'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Amplitude'); ax2.grid(True, alpha=0.3) + plt.suptitle(f'rPPG Signals - HR: {last_bpm:.1f} BPM', fontsize=14, fontweight='bold') + plt.savefig(str(signal_path), dpi=100, bbox_inches='tight'); plt.close('all') + except Exception: + pass + + elapsed = now - start_time + status = f"Frame {frame_idx}/{total_frames} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM" + + yield ( + status, + f"{last_bpm:.1f}" if last_bpm > 0 else None, + f"{gt_hr:.1f}" if gt_hr > 0 else "--", + f"{last_rmssd:.1f}" if last_rmssd > 0 else None, + str(frame_path), + str(attention_path), + str(signal_path) if signal_path.exists() else None, + str(raw_path) if raw_path.exists() else None, + str(post_path) if post_path.exists() else None, + None + ) + + next_display = now + (1.0 / DISPLAY_FPS) + + cap.release() + + # Cleanup Grad-CAM + if attention_viz: attention_viz.cleanup() + + csv_path = None + if bvp_stream: + csv_path = Path(tempfile.gettempdir()) / "bvp_output.csv" + time_array = np.arange(len(bvp_stream)) / fps + signal_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps) + try: + pd.DataFrame({'time_s': time_array, 'bvp': signal_final}).to_csv(csv_path, index=False) + except Exception: + csv_path = None + + try: + final_overlay = Path(tempfile.gettempdir()) / "signal_final_overlay.png" + if gt_bvp is not None and gt_bvp.size > 0: + gt_time = _gt_time_axis(len(gt_bvp), gt_fs) + plot_signals_with_gt( + time_axis=time_array, + raw_signal=bandpass_filter(np.array(bvp_stream, dtype=float), fs=fps), + post_signal=signal_final, + fs=fps, + out_path=str(final_overlay), + gt_time=gt_time, + gt_bvp=gt_bvp, + title=f"Final Pred vs GT — HR: {last_bpm:.1f} BPM" + ) + if final_overlay.exists(): + signal_path = final_overlay + except Exception: + pass + + elapsed = time.time() - start_time + final_status = f"Complete | {frame_idx} frames | {elapsed:.1f}s | HR {last_bpm:.1f} BPM" + + yield ( + final_status, + f"{last_bpm:.1f}" if last_bpm > 0 else None, + f"{gt_hr:.1f}" if gt_hr > 0 else "--", + f"{last_rmssd:.1f}" if last_rmssd > 0 else None, + str(frame_path), + str(attention_path), + str(signal_path) if signal_path.exists() else None, + str(raw_path) if raw_path.exists() else None, + str(post_path) if 
post_path.exists() else None, + str(csv_path) if csv_path else None + ) + +def process_live_webcam( + model_name: str, + fps_input: int, + roi_type: str, + control_id: str +): + """Stream live webcam with Grad-CAM attention visualization.""" + global _HR_SMOOTH + _HR_SMOOTH = None + + def _perf_heartbeat(frame_idx, t0, bvp_len, frames_chw_len, fps): + if frame_idx == 1: + print(f"[run] device={DEVICE} target_fps={fps}") + if frame_idx % 60 == 0: + elapsed = time.time() - t0 + cur_fps = frame_idx / max(elapsed, 1e-6) + print(f"[perf] frames={frame_idx} ({cur_fps:.1f} FPS) " + f"clip={frames_chw_len}/{DEFAULT_T} bvp_len={bvp_len}") + + control_id, controls = ensure_controls(control_id) + controls['stop'].clear() + + if not model_name: + yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None) + return + + model_path = MODEL_DIR / model_name + if not model_path.exists(): + yield ("ERROR: Model not found", None, None, None, None, None, None, None, None, None) + return + + try: + model, attention_viz = load_physmamba_model(model_path, DEVICE) + except Exception as e: + yield (f"ERROR loading model: {str(e)}", None, None, None, None, None, None, None, None, None) + return + + + cap = None + for camera_id in [0, 1]: + for backend in [cv2.CAP_AVFOUNDATION, cv2.CAP_ANY]: + try: + test_cap = cv2.VideoCapture(camera_id, backend) + if test_cap.isOpened(): + test_cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) + test_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) + ret, frame = test_cap.read() + if ret and frame is not None: + cap = test_cap + break + test_cap.release() + except Exception: + pass + if cap is not None: + break + + if cap is None: + yield ("ERROR: Cannot access webcam", None, None, None, None, None, None, None, None, None) + return + + fps = int(fps_input) if fps_input else 30 + + frames_chw = deque(maxlen=DEFAULT_T) + raw_g_means = deque(maxlen=DEFAULT_T) + bvp_stream: List[float] = [] + frame_idx = 0 + last_infer = -1 + last_bpm = 0.0 + last_rmssd = 0.0 + last_attention = None + + t0 = time.time() + next_display = time.time() + DISPLAY_INTERVAL = 0.5 + + frame_path = Path(tempfile.gettempdir()) / "live_frame.jpg" + attention_path = Path(tempfile.gettempdir()) / "live_attention.jpg" + signal_path = Path(tempfile.gettempdir()) / "live_signal.png" + raw_path = Path(tempfile.gettempdir()) / "live_raw.png" + post_path = Path(tempfile.gettempdir()) / "live_post.png" + + MAX_SIGNAL_LENGTH = fps * 60 + + yield ("Starting… waiting for frames", None, "--", None, + None, None, None, None, None, None) + + while True: + if controls['stop'].is_set(): + break + + while controls['pause'].is_set(): + time.sleep(0.2) + if controls['stop'].is_set(): + break + + ret, frame = cap.read() + if not ret: + time.sleep(0.05) + continue + + frame_idx += 1 + _perf_heartbeat(frame_idx, t0, len(bvp_stream), len(frames_chw), fps) + + face = detect_face(frame) + vis_frame = frame.copy() + + if face is not None: + x, y, w, h = face + cv2.rectangle(vis_frame, (x, y), (x + w, y + h), (0, 255, 0), 3) + cv2.putText(vis_frame, "FACE", (x, max(20, y - 10)), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + + if roi_type == "auto": + roi, _ = pick_auto_roi(face, frame, attn=last_attention) + else: + roi = crop_roi(face, roi_type, frame) + + if roi is not None and roi.size > 0: + try: + g_mean = float(roi[..., 1].astype(np.float32).mean()) + raw_g_means.append(g_mean) + except Exception: + pass + + chw = normalize_frame(roi, DEFAULT_SIZE) + frames_chw.append(chw) + + if len(frames_chw) == DEFAULT_T and 
(frame_idx - last_infer) >= DEFAULT_STRIDE: + try: + clip = np.stack(list(frames_chw), axis=1).astype(np.float32) + except Exception as e: + print(f"[infer] clip stack failed: {e}") + clip = None + + bvp_out = None + if clip is not None: + clip_t = torch.from_numpy(clip).unsqueeze(0).to(DEVICE) + try: + raw = forward_bvp(model, clip_t) + if isinstance(raw, np.ndarray): + raw = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False) + bvp_out = raw if raw.size > 0 else None + else: + bvp_out = None + except Exception as e: + print(f"[infer] forward_bvp error: {e}") + bvp_out = None + + try: + last_attention = extract_attention_map(model, clip_t, attention_viz) + except Exception as e: + print(f" Attention generation error: {e}") + last_attention = None + + if bvp_out is None or bvp_out.size == 0: + gbuf = np.nan_to_num(np.asarray(list(raw_g_means), dtype=np.float32), nan=0.0) + fb = _fallback_bvp_from_means(gbuf, fs=fps) + if isinstance(fb, np.ndarray) and fb.size > 0: + bvp_out = fb + else: + print("[infer] fallback produced empty output") + + if isinstance(bvp_out, np.ndarray) and bvp_out.size > 0: + tail = min(DEFAULT_STRIDE, bvp_out.size) + bvp_stream.extend(bvp_out[-tail:].tolist()) + if len(bvp_stream) > MAX_SIGNAL_LENGTH: + bvp_stream = bvp_stream[-MAX_SIGNAL_LENGTH:] + + if len(bvp_stream) >= int(5 * fps): + seg = np.asarray(bvp_stream[-int(10 * fps):], dtype=np.float32) + _, last_bpm = postprocess_bvp(seg, fs=fps) + last_rmssd = compute_rmssd(seg, fs=fps) + + if frame_idx % (DEFAULT_STRIDE * 2) == 0: + print(f"[infer] appended {tail}, bvp_len={len(bvp_stream)}") + else: + print("[infer] no usable bvp_out after fallback") + + last_infer = frame_idx + + else: + cv2.putText(vis_frame, "No face detected", (20, 40), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (30, 200, 255), 2) + + cv2.putText(vis_frame, f"Fill: {len(frames_chw)}/{DEFAULT_T}", (20, 25), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) + cv2.putText(vis_frame, f"BVP: {len(bvp_stream)}", (20, 45), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) + + if last_bpm > 0: + color = (50, 255, 50) if 55 <= last_bpm <= 100 else (50, 50, 255) if last_bpm > 100 else (255, 200, 50) + overlay = vis_frame.copy() + cv2.rectangle(overlay, (10, 10), (450, 100), (0, 0, 0), -1) + vis_frame = cv2.addWeighted(vis_frame, 0.6, overlay, 0.4, 0) + cv2.putText(vis_frame, f"HR: {last_bpm:.0f} BPM", (20, 80), + cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 3) + cv2.circle(vis_frame, (30, 30), 10, (0, 255, 0), -1) + cv2.putText(vis_frame, "LIVE", (50, 38), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) + else: + cv2.putText(vis_frame, "Collecting…", (20, 80), + cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2) + + vis_attention = create_attention_overlay(frame, last_attention, attention_viz) + + cv2.imwrite(str(frame_path), vis_frame) + cv2.imwrite(str(attention_path), vis_attention) + + now = time.time() + if now >= next_display: + if len(bvp_stream) >= max(10, int(1 * fps)): + try: + sig = np.array(bvp_stream, dtype=float) + t_axis = np.arange(len(sig)) / fps + raw_sig = bandpass_filter(sig, fs=fps) + post_sig, _ = postprocess_bvp(sig, fs=fps) + + rv = raw_sig - np.mean(raw_sig) + pv = post_sig - np.mean(post_sig) + + plt.figure(figsize=(10, 4), dpi=100) + plt.plot(t_axis, rv, linewidth=2) + plt.xlabel('Time (s)'); plt.ylabel('Amplitude') + plt.title(f'Raw Signal - HR: {last_bpm:.1f} BPM') + plt.grid(True, alpha=0.3) + plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) + plt.tight_layout(); plt.savefig(str(raw_path)); plt.close() + + 
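# ---------------------------------------------------------------------------
# Sketches of the two summary statistics computed above. `postprocess_bvp` and
# `compute_rmssd` are defined earlier in app.py (not shown in this hunk); the
# helpers below only illustrate the standard approach, and the band limits,
# nperseg and peak spacing are assumptions rather than the app's exact values.
import numpy as np
from scipy.signal import welch, find_peaks

def hr_from_bvp_sketch(bvp, fs):
    """Heart rate in BPM from the dominant Welch-PSD peak in the 0.7-3 Hz band."""
    bvp = np.asarray(bvp, dtype=np.float64)
    if bvp.size < int(4 * fs):
        return 0.0
    f, pxx = welch(bvp, fs=fs, nperseg=min(len(bvp), int(8 * fs)))
    band = (f >= 0.7) & (f <= 3.0)                # 42-180 BPM
    return float(f[band][np.argmax(pxx[band])] * 60.0) if np.any(band) else 0.0

def rmssd_sketch(bvp, fs):
    """RMSSD in ms: RMS of successive differences between inter-beat intervals."""
    bvp = np.asarray(bvp, dtype=np.float64)
    peaks, _ = find_peaks(bvp, distance=max(1, int(0.4 * fs)))  # beats >= 0.4 s apart
    if peaks.size < 3:
        return 0.0
    ibi_ms = np.diff(peaks) / fs * 1000.0
    return float(np.sqrt(np.mean(np.diff(ibi_ms) ** 2)))
# ---------------------------------------------------------------------------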
plt.figure(figsize=(10, 4), dpi=100) + plt.plot(t_axis, pv, linewidth=2) + plt.xlabel('Time (s)'); plt.ylabel('Amplitude') + plt.title(f'Post-processed Signal - HR: {last_bpm:.1f} BPM') + plt.grid(True, alpha=0.3) + plt.xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) + plt.tight_layout(); plt.savefig(str(post_path)); plt.close() + + fig = plt.figure(figsize=(12, 5), dpi=100) + from matplotlib.gridspec import GridSpec + gs = GridSpec(1, 2, figure=fig, wspace=0.3) + ax1 = fig.add_subplot(gs[0, 0]); ax1.plot(t_axis, rv, linewidth=2) + ax1.set_title('Raw Signal'); ax1.grid(True, alpha=0.3) + ax1.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) + ax2 = fig.add_subplot(gs[0, 1]); ax2.plot(t_axis, pv, linewidth=2) + ax2.set_title('Post-processed'); ax2.grid(True, alpha=0.3) + ax2.set_xlim([max(0, t_axis[-1] - 20), t_axis[-1]]) + plt.suptitle(f'LIVE rPPG - HR: {last_bpm:.1f} BPM') + plt.savefig(str(signal_path)); plt.close('all') + except Exception: + pass + + elapsed = now - t0 + status = (f"Frame {frame_idx} | Fill {len(frames_chw)}/{DEFAULT_T} | " + f"BVP {len(bvp_stream)} | HR {last_bpm:.1f} BPM | " + f"Time {int(elapsed)}s") + + yield ( + status, + f"{last_bpm:.1f}" if last_bpm > 0 else None, + "--", + f"{last_rmssd:.1f}" if last_rmssd > 0 else None, + str(frame_path), + str(attention_path), + str(signal_path) if signal_path.exists() else None, + str(raw_path) if raw_path.exists() else None, + str(post_path) if post_path.exists() else None, + None + ) + next_display = now + DISPLAY_INTERVAL + + time.sleep(0.01) + + cap.release() + + # Cleanup Grad-CAM + if attention_viz: attention_viz.cleanup() + + csv_path = None + if bvp_stream: + csv_path = Path(tempfile.gettempdir()) / "live_bvp.csv" + t = np.arange(len(bvp_stream)) / fps + sig_final, _ = postprocess_bvp(np.array(bvp_stream), fs=fps) + pd.DataFrame({"time_s": t, "bvp": sig_final}).to_csv(csv_path, index=False) + + elapsed = time.time() - t0 + final_status = f"Session ended | Frames {frame_idx} | Time {elapsed:.1f}s | HR {last_bpm:.1f} BPM" + yield ( + final_status, + f"{last_bpm:.1f}" if last_bpm > 0 else None, + "--", + f"{last_rmssd:.1f}" if last_rmssd > 0 else None, + str(frame_path), + str(attention_path), + str(signal_path) if signal_path.exists() else None, + str(raw_path) if raw_path.exists() else None, + str(post_path) if post_path.exists() else None, + str(csv_path) if csv_path else None + ) + +def process_stream( + input_source: str, + video_path: Optional[str], + gt_file: Optional[str], + model_name: str, + fps_input: int, + max_seconds: int, + roi_type: str, + control_id: str +): + if input_source == "Live Webcam": + yield from process_live_webcam(model_name, fps_input, roi_type, control_id) + else: + yield from process_video_file(video_path, gt_file, model_name, fps_input, + max_seconds, roi_type, control_id) + +def pause_processing(control_id: str) -> str: + _, controls = ensure_controls(control_id) + controls['pause'].set() + return "Paused" + +def resume_processing(control_id: str) -> str: + _, controls = ensure_controls(control_id) + controls['pause'].clear() + return "Resumed" + +def stop_processing(control_id: str) -> str: + _, controls = ensure_controls(control_id) + controls['stop'].set() + controls['pause'].clear() + return "Stopped" + +def reset_ui(): + return ("Ready", None, None, None, None, None, None, None, None, None) + +def handle_folder_upload(files): + if not files: + return None, None, "No files uploaded" + + if not isinstance(files, list): + files = [files] + + video_path = None + for file_obj in files: + 
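# ---------------------------------------------------------------------------
# Usage sketch for the `handle_folder_upload` handler being defined here: given
# the list of paths Gradio produces for a directory upload, it returns
# (video_path, gt_path, status). The paths below are hypothetical and assume
# ".avi" is listed in VIDEO_EXTENSIONS (defined earlier in app.py).
#
#   files = ["/tmp/subject1/vid.avi",
#            "/tmp/subject1/ground_truth.txt",
#            "/tmp/subject1/README.md"]
#   video, gt, status = handle_folder_upload(files)
#   # video  -> "/tmp/subject1/vid.avi"           (first file with a video suffix)
#   # gt     -> "/tmp/subject1/ground_truth.txt"  (matches a known GT filename)
#   # status -> "Video: vid.avi | GT: ground_truth.txt"
# ---------------------------------------------------------------------------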
file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name) + if file_path.suffix.lower() in VIDEO_EXTENSIONS: + video_path = str(file_path) + break + + gt_path = None + gt_patterns = ['gtdump.txt', 'ground_truth.txt', 'gt.txt'] + + for file_obj in files: + file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name) + if file_path.name.lower() in [p.lower() for p in gt_patterns]: + gt_path = str(file_path) + break + + if not gt_path: + for file_obj in files: + file_path = Path(file_obj) if isinstance(file_obj, str) else Path(file_obj.name) + if file_path.suffix.lower() in ['.txt', '.json', '.csv']: + if 'readme' not in file_path.name.lower(): + gt_path = str(file_path) + break + + status = [] + if video_path: + status.append(f"Video: {Path(video_path).name}") + else: + status.append("No video found") + + if gt_path: + status.append(f"GT: {Path(gt_path).name}") + else: + status.append("No ground truth") + + return video_path, gt_path, " | ".join(status) + +with gr.Blocks(title="rPPG Analysis with Attention", theme=gr.themes.Soft()) as demo: + gr.Markdown("# rPPG Analysis Tool with Attention Visualization") + + with gr.Row(): + input_source = gr.Radio( + choices=["Video File", "Subject Folder", "Live Webcam"], + value="Live Webcam", + label="Input Source" + ) + + with gr.Row(): + target_layer_input = gr.Textbox( + label="Grad-CAM Target Layer (optional - leave empty for auto)", + placeholder="e.g., backbone.conv3, encoder.layer2", + value="" + ) + + with gr.Row(visible=False) as video_inputs: + with gr.Column(): + video_upload = gr.Video(label="Upload Video", sources=["upload"]) + with gr.Column(): + gt_upload = gr.File(label="Ground Truth (optional)", + file_types=[".txt", ".csv", ".json"]) + + with gr.Row(visible=False) as folder_inputs: + with gr.Column(): + folder_upload = gr.File( + label="Upload Subject Folder", + file_count="directory", + file_types=None, + type="filepath" + ) + folder_status = gr.Textbox(label="Folder Status", interactive=False) + + with gr.Row(visible=True) as webcam_inputs: + gr.Markdown("### Live Webcam - Click Run to start") + + def toggle_input_source(source): + return [ + gr.update(visible=(source == "Video File")), + gr.update(visible=(source == "Subject Folder")), + gr.update(visible=(source == "Live Webcam")) + ] + + input_source.change( + toggle_input_source, + inputs=[input_source], + outputs=[video_inputs, folder_inputs, webcam_inputs] + ) + + folder_video = gr.State(None) + folder_gt = gr.State(None) + + folder_upload.upload( + handle_folder_upload, + inputs=[folder_upload], + outputs=[folder_video, folder_gt, folder_status] + ) + + with gr.Row(): + with gr.Column(scale=3): + model_dropdown = gr.Dropdown( + choices=scan_models(), + value=scan_models()[0] if scan_models() else None, + label="PhysMamba Model", + interactive=True + ) + with gr.Column(scale=1): + refresh_models_btn = gr.Button("Refresh", variant="secondary") + + refresh_models_btn.click( + lambda: gr.update(choices=scan_models()), + inputs=None, + outputs=[model_dropdown] + ) + + with gr.Row(): + fps_slider = gr.Slider( + minimum=10, maximum=120, value=30, step=5, + label="FPS" + ) + max_seconds_slider = gr.Slider( + minimum=10, maximum=600, value=180, step=10, + label="Max Duration (s)" + ) + + with gr.Row(): + roi_dropdown = gr.Dropdown(choices=["auto","forehead","cheeks","face"], value="auto", label="ROI") + + control_state = gr.State(value="") + placeholder_state = gr.State(value=None) + + with gr.Row(): + run_btn = gr.Button("Run", 
variant="primary") + pause_btn = gr.Button("Pause", variant="secondary") + resume_btn = gr.Button("Resume", variant="secondary") + stop_btn = gr.Button("Stop", variant="stop") + + status_text = gr.Textbox(label="Status", lines=2, value="Ready") + + with gr.Row(): + hr_output = gr.Textbox(label="HR (BPM)", interactive=False) + gt_hr_output = gr.Textbox(label="GT HR (BPM)", interactive=False) + rmssd_output = gr.Textbox(label="HRV RMSSD (ms)", interactive=False) + + with gr.Row(): + with gr.Column(): + frame_output = gr.Image(label="Video Feed", type="filepath") + with gr.Column(): + attention_output = gr.Image(label="Attention Map", type="filepath") + + with gr.Row(): + signal_output = gr.Image(label="Signal Comparison", type="filepath") + + with gr.Row(): + with gr.Column(): + raw_signal_output = gr.Image(label="Raw Signal", type="filepath") + with gr.Column(): + post_signal_output = gr.Image(label="Post-processed Signal", type="filepath") + + with gr.Row(): + csv_output = gr.File(label="Download CSV") + + pause_btn.click( + pause_processing, + inputs=[control_state], + outputs=[status_text] + ) + + resume_btn.click( + resume_processing, + inputs=[control_state], + outputs=[status_text] + ) + + stop_btn.click( + stop_processing, + inputs=[control_state], + outputs=[status_text] + ).then( + reset_ui, + inputs=None, + outputs=[status_text, hr_output, gt_hr_output, rmssd_output, + frame_output, attention_output, signal_output, + raw_signal_output, post_signal_output, csv_output] + ) + + def run_processing(input_source, video_upload, gt_upload, folder_video, folder_gt, + model_name, fps, max_sec, roi, ctrl_id): + """Fixed version that handles model_name type conversion.""" + + if isinstance(model_name, int): + model_name = str(model_name) + + if not model_name: + yield ("ERROR: No model selected", None, None, None, None, None, None, None, None, None) + return + + if input_source == "Video File": + video_path = _as_path(video_upload) + gt_file = _as_path(gt_upload) + elif input_source == "Subject Folder": + video_path = _as_path(folder_video) + gt_file = _as_path(folder_gt) + else: # Live Webcam + video_path, gt_file = None, None + + yield from process_stream( + input_source, video_path, gt_file, + model_name, fps, max_sec, roi, ctrl_id + ) + + + run_btn.click( + fn=run_processing, + inputs=[ + input_source, + video_upload, + gt_upload, + folder_video, + folder_gt, + model_dropdown, + fps_slider, + max_seconds_slider, + roi_dropdown, + control_state + ], + outputs=[ + status_text, + hr_output, + gt_hr_output, + rmssd_output, + frame_output, + attention_output, + signal_output, + raw_signal_output, + post_signal_output, + csv_output + ] + ) + + + run_btn.click( + fn=run_processing, + inputs=[ + input_source, + video_upload, + folder_video, + folder_gt, + model_dropdown, + fps_slider, + max_seconds_slider, + roi_dropdown, + control_state + ], + outputs=[ + status_text, + hr_output, + gt_hr_output, + rmssd_output, + frame_output, + attention_output, + signal_output, + raw_signal_output, + post_signal_output, + csv_output + ] + ) + +if __name__ == "__main__": + demo.queue(max_size=10).launch( + server_name="127.0.0.1", + server_port=7861, + share=False, + show_error=True, + inbrowser=True + ) \ No newline at end of file diff --git a/final_model_release/PURE_PhysMamba_DiffNormalized.pth b/final_model_release/PURE_PhysMamba_DiffNormalized.pth new file mode 100644 index 0000000000000000000000000000000000000000..6f12a9811f91a7a40f7d11f85bcfa1b8101cf5ab --- /dev/null +++ 
b/final_model_release/PURE_PhysMamba_DiffNormalized.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2da6b9bc7e8728c20743c8f359eb43eefd2a96665098a29302fed2b09bc6b7d7 +size 3013798 diff --git a/final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth b/final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth new file mode 100644 index 0000000000000000000000000000000000000000..5b17c4507b1663ff793e2873ad3c5a306ac38bd0 --- /dev/null +++ b/final_model_release/UBFC-rPPG_PhysMamba_DiffNormalized.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd18f227ff18f5140f471dc78e61426d9344ecc43677b35ec16b3199f1dccd19 +size 3013798 diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd223a52bb43ab0d5ebef78bbc299b1ac6c54590 --- /dev/null +++ b/mamba_ssm/__init__.py @@ -0,0 +1,20 @@ +import torch.nn.functional as F + +import torch, torch.nn as nn, torch.nn.functional as F +print('[mamba_ssm shim] Loaded CPU shim') +class Mamba(nn.Module): + def __init__(self, d_model, d_state=16, d_conv=4, expand=2, conv_bias=True, **kwargs): + super().__init__() + h = d_model * expand + self.in_proj = nn.Linear(d_model, 2*h) + self.dw = nn.Conv1d(h, h, d_conv, padding=d_conv-1, groups=h, bias=conv_bias) + self.mix = nn.Conv1d(h, h, 1) + self.out = nn.Linear(h, d_model) + self.d_model = d_model + def forward(self, x): + B,L,C = x.shape + u,v = self.in_proj(x).chunk(2, dim=-1) + y = F.silu(u) * torch.sigmoid(v) + y = self.dw(y.transpose(1,2))[...,:L] + y = F.silu(self.mix(y)).transpose(1,2) + return self.out(y) diff --git a/mamba_ssm/models/__init__.py b/mamba_ssm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py new file mode 100644 index 0000000000000000000000000000000000000000..383f773f1f700cd53176e51327a5d8dc58158da0 --- /dev/null +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -0,0 +1,233 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. + +import math +from functools import partial + +from collections import namedtuple + +import torch +import torch.nn as nn + +from mamba_ssm.modules.mamba_simple import Mamba, Block +from mamba_ssm.utils.generation import GenerationMixin +from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf + +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +def create_block( + d_model, + ssm_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, +): + if ssm_cfg is None: + ssm_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block = Block( + d_model, + mixer_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module, + n_layer, + initializer_range=0.02, # Now only used for embedding layer. 
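# ---------------------------------------------------------------------------
# Usage sketch for the CPU shim defined above in mamba_ssm/__init__.py: it keeps
# the (batch, length, d_model) interface of the real Mamba block so the
# PhysMamba model can be instantiated on machines without the CUDA
# selective-scan kernels. It is a gated depthwise-convolution stand-in, not a
# state-space model, so its numerics differ from the real implementation.
import torch
from mamba_ssm import Mamba   # in this repo this resolves to the shim above

block = Mamba(d_model=64)     # d_state / d_conv / expand keep their defaults
x = torch.randn(2, 96, 64)    # (batch, sequence length, d_model)
y = block(x)
print(y.shape)                # torch.Size([2, 96, 64])
# ---------------------------------------------------------------------------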
+ rescale_prenorm_residual=True, + n_residuals_per_layer=1, # Change to 2 if we have MLP +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class MixerModel(nn.Module): + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + ssm_cfg=None, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + initializer_cfg=None, + fused_add_norm=False, + residual_in_fp32=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + + self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) + + # We change the order of residual and layer norm: + # Instead of LN -> Attn / MLP -> Add, we do: + # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and + # the main branch (output of MLP / Mixer). The model definition is unchanged. + # This is for performance reason: we can fuse add + layer_norm. 
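# ---------------------------------------------------------------------------
# Minimal sketch (not part of this module) of why the reordering described in
# the comment above is safe: carrying (hidden_states, residual) pairs through
# "Add -> LN -> Mixer" blocks and finishing with one extra add + norm computes
# the same function as the usual "LN -> Mixer -> Add" pre-norm stack; it only
# regroups the operations so that the add and the norm can later be fused.
import torch
import torch.nn as nn

dim, n_layer = 8, 3
norms  = [nn.LayerNorm(dim) for _ in range(n_layer)]
mixers = [nn.Linear(dim, dim) for _ in range(n_layer)]   # stand-in for Mamba mixers
norm_f = nn.LayerNorm(dim)
x = torch.randn(2, 5, dim)

ref = x                                   # standard pre-norm: x <- x + Mixer(LN(x))
for ln, mix in zip(norms, mixers):
    ref = ref + mix(ln(ref))
ref = norm_f(ref)

hidden, residual = x, None                # formulation used by MixerModel / Block
for ln, mix in zip(norms, mixers):
    residual = hidden + residual if residual is not None else hidden
    hidden = mix(ln(residual))
residual = hidden + residual              # final add, then final norm
out = norm_f(residual)

print(torch.allclose(ref, out, atol=1e-6))   # True: same math, different grouping
# ---------------------------------------------------------------------------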
+ self.fused_add_norm = fused_add_norm + if self.fused_add_norm: + if layer_norm_fn is None or rms_norm_fn is None: + raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") + + self.layers = nn.ModuleList( + [ + create_block( + d_model, + ssm_cfg=ssm_cfg, + norm_epsilon=norm_epsilon, + rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, + fused_add_norm=fused_add_norm, + layer_idx=i, + **factory_kwargs, + ) + for i in range(n_layer) + ] + ) + + self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( + d_model, eps=norm_epsilon, **factory_kwargs + ) + + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids, inference_params=None): + hidden_states = self.embedding(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + return hidden_states + + +class MambaLMHeadModel(nn.Module, GenerationMixin): + + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + initializer_cfg=None, + pad_vocab_size_multiple: int = 1, + device=None, + dtype=None, + **backbone_kwargs, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.backbone = MixerModel( + d_model=d_model, + n_layer=n_layer, + vocab_size=vocab_size, + initializer_cfg=initializer_cfg, + **backbone_kwargs, + **factory_kwargs, + ) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + self.tie_weights() + + def tie_weights(self): + self.lm_head.weight = self.backbone.embedding.weight + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + """ + "position_ids" is just to be compatible with Transformer generation. We don't use it. 
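# ---------------------------------------------------------------------------
# Usage sketch for MambaLMHeadModel above (assuming the real mamba_ssm CUDA
# package is installed; the CPU shim earlier in this diff does not implement
# the full selective scan, so it will not reproduce these numerics):
#
#   import torch
#   model = MambaLMHeadModel(d_model=256, n_layer=4, vocab_size=1000).cuda()
#   input_ids = torch.randint(0, 1000, (2, 128), device="cuda")
#   logits = model(input_ids).logits        # shape (2, 128, 1000)
# ---------------------------------------------------------------------------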
+ num_last_tokens: if > 0, only return the logits for the last n tokens + """ + hidden_states = self.backbone(input_ids, inference_params=inference_params) + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + lm_logits = self.lm_head(hidden_states) + CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) + return CausalLMOutput(logits=lm_logits) + + @classmethod + def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): + config = load_config_hf(pretrained_model_name) + model = cls(**config, device=device, dtype=dtype, **kwargs) + model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) + return model diff --git a/mamba_ssm/modules/__init__.py b/mamba_ssm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1dd8f808b50d632bbd22f0648d4cb8939cb1e1 --- /dev/null +++ b/mamba_ssm/modules/mamba_simple.py @@ -0,0 +1,418 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None + +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj +except ImportError: + selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None + +try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +class Mamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + bimamba=True, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.bimamba = bimamba + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.activation = "silu" + self.act = nn.SiLU() + + self.x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + 
nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + self.dt_proj.bias._no_reinit = True + + # S4D real initialization + # NOTE: why plus 1? + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D._no_weight_decay = True + + # bidirectional + # forked from https://github.com/hustvl/Vim + if self.bimamba: + A_b = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_b_log = torch.log(A_b) # Keep A_b_log in fp32 + self.A_b_log = nn.Parameter(A_b_log) + self.A_b_log._no_weight_decay = True + + self.conv1d_b = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.x_proj_b = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D_b._no_weight_decay = True + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + def forward(self, hidden_states, inference_params=None, T=1): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch, seqlen, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out + + # We do matmul and transpose BLH -> HBL at the same time + # NOTE: same as in_proj(hidden_states) but memory-efficient with the following operations + xz = rearrange( + self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), + "d (b l) -> b d l", + l=seqlen, + ) + if self.in_proj.bias is not None: + xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + # In the backward pass we write dx and dz next to each other to avoid torch.cat + if self.use_fast_path and inference_params is None: # Doesn't support outputting the states + if self.bimamba: + A_b = -torch.exp(self.A_b_log.float()) + out = mamba_inner_fn_no_out_proj( + xz, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + 
delta_softplus=True, + ) + out_b = mamba_inner_fn_no_out_proj( + xz.flip([-1]), + self.conv1d_b.weight, + self.conv1d_b.bias, + self.x_proj_b.weight, + self.dt_proj_b.weight, + A_b, + None, + None, + self.D_b.float(), + delta_bias=self.dt_proj_b.bias.float(), + delta_softplus=True, + ) + out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) + else: + out = mamba_inner_fn( + xz, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias, + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + else: + x, z = xz.chunk(2, dim=1) + # Compute short convolution + if conv_state is not None: + conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + assert self.activation in ["silu", "swish"] + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b d l -> b l d") + out = self.out_proj(y) + return out + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y 
* self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=self.dt_proj.weight.device, + dtype=self.dt_proj.weight.dtype, + # dtype=torch.float32, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + + +class Block(nn.Module): + def __init__( + self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(dim) + self.norm = norm_cls(dim) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). 
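# ---------------------------------------------------------------------------
# Minimal numeric sketch of the selective-state update performed in `step()`
# above when the fused `selective_state_update` kernel is unavailable:
# discretize A and B with the input-dependent step dt, advance the state, then
# read it out through C and add the D skip. Shapes follow the comments in the
# code ((B, d_inner) tokens, (d_inner, d_state) state); the values are toys.
import torch
import torch.nn.functional as F

B, d_inner, d_state = 2, 4, 3
x   = torch.randn(B, d_inner)             # current token's features
z   = torch.randn(B, d_inner)             # gate branch
dt  = F.softplus(torch.randn(B, d_inner)) # per-feature step size (> 0)
A   = -torch.rand(d_inner, d_state)       # negative real poles, as in A = -exp(A_log)
Bm  = torch.randn(B, d_state)             # input-dependent B
C   = torch.randn(B, d_state)             # input-dependent C
D   = torch.ones(d_inner)                 # skip connection
state = torch.zeros(B, d_inner, d_state)  # ssm_state carried between decode steps

dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))   # discretized transition
dB = torch.einsum("bd,bn->bdn", dt, Bm)             # discretized input matrix
state = state * dA + x.unsqueeze(-1) * dB           # recurrent state update
y = torch.einsum("bdn,bn->bd", state, C) + D * x    # readout plus skip
y = y * F.silu(z)                                   # gated output, as in the code
print(y.shape)                                      # torch.Size([2, 4])
# ---------------------------------------------------------------------------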
+ residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/mamba_ssm/ops/__init__.py b/mamba_ssm/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da226f864cade0e641c1d0cc9bc07413069b8d9c --- /dev/null +++ b/mamba_ssm/ops/__init__.py @@ -0,0 +1,7 @@ + +import torch +def selective_scan(*args, **kwargs): + for a in args: + if isinstance(a, torch.Tensor): + return a + return torch.tensor(0.0) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..79be971bf7a9ec1eaa8d8af3d8bd75eead52b868 --- /dev/null +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -0,0 +1,17 @@ + +import torch +def selective_scan_fn(*args, **kwargs): + for a in args: + if isinstance(a, torch.Tensor): + return a + return torch.tensor(0.0) +def mamba_inner_fn(*args, **kwargs): + for a in args: + if isinstance(a, torch.Tensor): + return a + return torch.tensor(0.0) +def bimamba_inner_fn(*args, **kwargs): + for a in args: + if isinstance(a, torch.Tensor): + return a + return torch.tensor(0.0) diff --git a/mamba_ssm/ops/triton/__init__.py b/mamba_ssm/ops/triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mamba_ssm/ops/triton/layernorm.py b/mamba_ssm/ops/triton/layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8df9d042a34b6584196f218f5ffeeb104799bd5e --- /dev/null +++ b/mamba_ssm/ops/triton/layernorm.py @@ -0,0 +1,636 @@ +# Copyright (c) 2023, Tri Dao. +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
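# ---------------------------------------------------------------------------
# Usage sketch for the public helpers defined later in this file (a CUDA device
# with Triton is assumed; shapes and dtypes are illustrative). With prenorm=True
# the fused kernel returns both the normalized activations and the updated
# residual stream, which is what Block.forward in mamba_simple.py consumes.
import torch
from mamba_ssm.ops.triton.layernorm import rms_norm_fn

x        = torch.randn(4, 128, 256, device="cuda", dtype=torch.float16)
residual = torch.randn(4, 128, 256, device="cuda", dtype=torch.float32)
weight   = torch.ones(256, device="cuda", dtype=torch.float16)
out, new_residual = rms_norm_fn(
    x, weight, bias=None,
    residual=residual, prenorm=True, residual_in_fp32=True, eps=1e-5,
)
print(out.shape, new_residual.dtype)   # torch.Size([4, 128, 256]) torch.float32
# ---------------------------------------------------------------------------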
+ +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
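# One Triton program instance handles one row of X (the launch grid used in
# _layer_norm_fwd below is (M,)), and a single block of BLOCK_N columns covers
# the whole feature dimension (BLOCK_N is the next power of two >= N), so
# columns beyond N are masked out. The residual, if present, is added before
# the statistics are computed, and the mean is only stored for true LayerNorm
# (RMSNorm skips it).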
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not 
None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert 
weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + 
weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + 
assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py new file mode 100644 index 0000000000000000000000000000000000000000..fa95de73f173292914c5f00fbe9426937d00e502 --- /dev/null +++ b/mamba_ssm/ops/triton/selective_state_update.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023, Tri Dao. + +"""We want triton==2.1.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + # Matrix dimensions + batch, dim, dstate, + # Strides + stride_state_batch, stride_state_dim, stride_state_dstate, + stride_x_batch, stride_x_dim, + stride_dt_batch, stride_dt_dim, + stride_dt_bias_dim, + stride_A_dim, stride_A_dstate, + stride_B_batch, stride_B_dstate, + stride_C_batch, stride_C_dstate, + stride_D_dim, + stride_z_batch, stride_z_dim, + stride_out_batch, stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + state_ptr += pid_b * stride_state_batch + x_ptr += pid_b * stride_x_batch + dt_ptr += pid_b * stride_dt_batch + B_ptr += pid_b * stride_B_batch + C_ptr += pid_b * stride_C_batch + if HAS_Z: + z_ptr += pid_b * stride_z_batch + out_ptr += pid_b * stride_out_batch + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + 
C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.log(1.0 + tl.exp(dt)) + A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) + x: (batch, dim) + dt: (batch, dim) + A: (dim, dstate) + B: (batch, dstate) + C: (batch, dstate) + D: (dim,) + z: (batch, dim) + dt_bias: (dim,) + Return: + out: (batch, dim) + """ + batch, dim, dstate = state.shape + assert x.shape == (batch, dim) + assert dt.shape == x.shape + assert A.shape == (dim, dstate) + assert B.shape == (batch, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (dim,) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (dim,) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) + z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. 
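+    # The hand-tuned pairs below shrink BLOCK_SIZE_M as dstate grows, so the per-program +    # (BLOCK_SIZE_M x BLOCK_SIZE_DSTATE) tile stays small; only the largest states (dstate > 128) +    # get 8 warps instead of 4.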
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 + else ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else + ((4, 8)))))) + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, x, dt, dt_bias, A, B, C, D, z, out, + batch, dim, dstate, + state.stride(0), state.stride(1), state.stride(2), + x.stride(0), x.stride(1), + dt.stride(0), dt.stride(1), + dt_bias.stride(0) if dt_bias is not None else 0, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + C.stride(0), C.stride(1), + D.stride(0) if D is not None else 0, + z_strides[0], z_strides[1], + out.stride(0), out.stride(1), + dt_softplus, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + return out + + +def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) + x: (batch, dim) + dt: (batch, dim) + A: (dim, dstate) + B: (batch, dstate) + C: (batch, dstate) + D: (dim,) + z: (batch, dim) + dt_bias: (dim,) + Return: + out: (batch, dim) + """ + batch, dim, dstate = state.shape + assert x.shape == (batch, dim) + assert dt.shape == x.shape + assert A.shape == (dim, dstate) + assert B.shape == (batch, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (dim,) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (dim,) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) + dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) + state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate + out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + return (out if z is None else out * F.silu(z)).to(x.dtype) diff --git a/mamba_ssm/utils/__init__.py b/mamba_ssm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mamba_ssm/utils/generation.py b/mamba_ssm/utils/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..9d766b29ac28a388a7d77b22aa2cb1eda733c0f4 --- /dev/null +++ b/mamba_ssm/utils/generation.py @@ -0,0 +1,377 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. 
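+# Illustrative usage of the GenerationMixin defined below (the model object is assumed to be a +# Mamba language-model wrapper that inherits from it; the names here are for illustration only): +#   out = model.generate(input_ids, max_length=128, top_k=50, top_p=0.9, temperature=0.8, +#                        cg=True, return_dict_in_generate=True, output_scores=True) +#   sequences, scores = out.sequences, out.scores  # cg=True enables CUDA-graph decoding on GPU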
+import gc +import time +from collections import namedtuple +from dataclasses import dataclass, field +from functools import partial +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor +from torch.profiler import ProfilerActivity, profile, record_function +from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput + + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + max_seqlen: int + max_batch_size: int + seqlen_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + lengths_per_sample: Optional[Tensor] = None + + def reset(self, max_seqlen, max_batch_size): + self.max_seqlen = max_seqlen + self.max_batch_size = max_batch_size + self.seqlen_offset = 0 + if self.lengths_per_sample is not None: + self.lengths_per_sample.zero_() + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf. Done in-place.""" + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(indices_to_remove, float("-Inf")) + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf. Done in-place.""" + if top_p <= 0.0 or top_p >= 1.0: + return + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits.masked_fill_(indices_to_remove, float("-inf")) + + +def sample(logits, top_k=1, top_p=0.0, temperature=1.0): + """Sample from top-k logits. + Arguments: + logits: Tensor of shape (batch_size, vocab_size) + """ + if top_k == 1: # Short-circuit for greedy decoding + return logits.argmax(dim=-1) + else: + if top_p > 0.0: + assert top_p <= 1.0, "top-p should be in (0, 1]." 
+ if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + logits_top, indices = torch.topk(logits, top_k, dim=-1) + if temperature != 1.0: + logits_top /= temperature + modify_logits_for_top_p_filtering(logits_top, top_p) + return indices[ + torch.arange(indices.shape[0], device=indices.device), + torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), + ] + else: + # Clone so that when we modify for top_p we don't change the original logits + logits_top = logits / temperature if temperature != 1.0 else logits.clone() + modify_logits_for_top_p_filtering(logits_top, top_p) + return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( + dim=-1 + ) + + +@torch.inference_mode() +def decode( + input_ids, + model, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + eos_token_id=None, + teacher_outputs=None, + vocab_size=None, + tensor_parallel=1, + cg=False, + enable_timing=False, +): + """Decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the + logits, the next token is taken from the teacher_outputs. Useful for testing. + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + if cg: + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + tensor_parallel=tensor_parallel, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params): + decoding = inference_params.seqlen_offset > 0 + if decoding: + position_ids = torch.full( + (batch_size, 1), + inference_params.seqlen_offset, + dtype=torch.long, + device=input_ids.device, + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=1, + ).logits.squeeze(dim=1) + else: + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + ).squeeze(dim=1) + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(logits, inference_params): + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: + token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + token = teacher_outputs[:, inference_params.seqlen_offset] + # return rearrange(token, "b -> b 1") + return token.unsqueeze(1) + + def should_stop(current_token, inference_params): + if inference_params.seqlen_offset == 0: + return False + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if inference_params.seqlen_offset >= max_length - 1: + return True + return False + + start = 
torch.cuda.Event(enable_timing=enable_timing) + end = torch.cuda.Event(enable_timing=enable_timing) + + if enable_timing: + if tensor_parallel > 1: + torch.distributed.barrier() + start.record() + scores, sequences = [], [input_ids] + while not should_stop(sequences[-1], inference_params): + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + sequences.append(sample_tokens(scores[-1], inference_params)) + if enable_timing: + end.record() + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) + + +class GenerationMixin: + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate( + self, + input_ids, + max_length, + top_k=1, + top_p=0.0, + temperature=1.0, + return_dict_in_generate=False, + output_scores=False, + **kwargs, + ): + output = decode( + input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs + ) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + + +def allocate_inference_cache( + max_batch_size, + max_seqlen, + nheads, + headdim, + layers: Union[int, Sequence], + device, + dtype=torch.float16, +): + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) + if isinstance(layers, int): + layers = range(layers) + return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} + + +@dataclass +class DecodingCGCache: + max_batch_size: int = 0 + max_seqlen: int = 0 + device = None + dtype = None + callables: dict = field(default_factory=dict) + mempool = None + inference_params: Optional[InferenceParams] = None + run: Optional[Callable] = None + + +@torch.inference_mode() +def update_graph_cache( + model, + cache, + batch_size, + seqlen_og, + max_seqlen, + decoding_seqlens=(1,), + tensor_parallel=1, + dtype=None, + n_warmups=2, +): + if cache is None: + cache = DecodingCGCache() + param_example = next(iter(model.parameters())) + device = param_example.device + if dtype is None: + dtype = param_example.dtype + if ( + (device, dtype) != (cache.device, cache.dtype) + or batch_size > cache.max_batch_size + or max_seqlen > cache.max_seqlen + ): # Invalidate the cache + cache.callables = {} + cache.mempool = None + cache.inference_params = None + gc.collect() + cache.device, cache.dtype = device, dtype + cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen + if hasattr(model, "allocate_inference_cache"): + inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) + else: + headdim = getattr( + model.config, + "head_dim", + model.config.hidden_size // model.config.num_attention_heads, + ) + inf_cache = allocate_inference_cache( + batch_size, + max_seqlen, + model.config.num_attention_heads // tensor_parallel, + headdim, + model.config.num_hidden_layers, + device, + dtype, + ) + lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) + cache.inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + 
cache.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if (batch_size, decoding_seqlen) not in cache.callables: + cache.callables[batch_size, decoding_seqlen] = capture_graph( + model, + cache.inference_params, + batch_size, + max_seqlen, + decoding_seqlen=decoding_seqlen, + mempool=cache.mempool, + n_warmups=n_warmups, + ) + + def dispatch(input_ids, position_ids, seqlen): + batch_size, decoding_seqlen = input_ids.shape[:2] + return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) + + cache.run = dispatch + cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing + return cache + + +def capture_graph( + model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 +): + device = next(iter(model.parameters())).device + input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + seqlen_offset_og = inference_params.seqlen_offset + inference_params.seqlen_offset = max_seqlen - decoding_seqlen + inference_params.lengths_per_sample[:] = inference_params.seqlen_offset + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(n_warmups): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + s.synchronize() + # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, + # which requires that graph launch and non-captured launch to not overlap (I think, + # that's how I interpret the documentation). I'm not sure if this is required. 
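+    # The captured graph reads from the static input_ids / position_ids buffers allocated above; +    # the run() closure defined below copies fresh tokens into those buffers and replays the graph.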
+ if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + + def run(new_input_ids, new_position_ids, seqlen): + inference_params.lengths_per_sample[:] = seqlen + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits.clone() + + inference_params.seqlen_offset = seqlen_offset_og + return run diff --git a/mamba_ssm/utils/hf.py b/mamba_ssm/utils/hf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7555acddbd260636d1d14d5bd6324f6af0056a --- /dev/null +++ b/mamba_ssm/utils/hf.py @@ -0,0 +1,23 @@ +import json + +import torch + +from transformers.utils import WEIGHTS_NAME, CONFIG_NAME +from transformers.utils.hub import cached_file + + +def load_config_hf(model_name): + resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) + return json.load(open(resolved_archive_file)) + + +def load_state_dict_hf(model_name, device=None, dtype=None): + # If not fp32, then we don't want to load directly to the GPU + mapped_device = "cpu" if dtype not in [torch.float32, None] else device + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) + return torch.load(resolved_archive_file, map_location=mapped_device) + # Convert dtype before moving to GPU to save memory + if dtype is not None: + state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} + state_dict = {k: v.to(device=device) for k, v in state_dict.items()} + return state_dict diff --git a/neural_methods/__init__.py b/neural_methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/neural_methods/loss/NegPearsonLoss.py b/neural_methods/loss/NegPearsonLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6bcf5d222c70f3c6c78aea00b4887b125eaf53 --- /dev/null +++ b/neural_methods/loss/NegPearsonLoss.py @@ -0,0 +1,23 @@ +from __future__ import print_function, division +import torch +import matplotlib.pyplot as plt +import argparse, os +import pandas as pd +import numpy as np +import random +import math +from torchvision import transforms +from torch import nn + + +class Neg_Pearson(nn.Module): + def __init__(self): + super(Neg_Pearson, self).__init__() + return + + def forward(self, preds, labels): + cos = nn.CosineSimilarity(dim=0, eps=1e-6) + pearson = cos(preds - preds.mean(dim=0, keepdim=True), labels - labels.mean(dim=0, keepdim=True)) + return torch.mean(1 - pearson) + + diff --git a/neural_methods/loss/PhysFormerLossComputer.py b/neural_methods/loss/PhysFormerLossComputer.py new file mode 100644 index 0000000000000000000000000000000000000000..a55335ff5effa677403f71859b7f2285ad5a17ed --- /dev/null +++ b/neural_methods/loss/PhysFormerLossComputer.py @@ -0,0 +1,120 @@ +''' + Adapted from here: https://github.com/ZitongYu/PhysFormer/blob/main/TorchLossComputer.py + Modifed based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn +''' +import math +import torch +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F +import pdb +import 
torch.nn as nn + +def normal_sampling(mean, label_k, std): + return math.exp(-(label_k-mean)**2/(2*std**2))/(math.sqrt(2*math.pi)*std) + +def kl_loss(inputs, labels): + # Reshape the labels tensor to match the shape of inputs + labels = labels.view(1, -1) + + # Compute the KL Div Loss + criterion = nn.KLDivLoss(reduction='sum') + loss = criterion(F.log_softmax(inputs, dim=-1), labels) + return loss + +class TorchLossComputer(object): + @staticmethod + def compute_complex_absolute_given_k(output, k, N): + two_pi_n_over_N = torch.autograd.Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N + hanning = torch.autograd.Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1) + + k = k.type(torch.FloatTensor).cuda() + two_pi_n_over_N = two_pi_n_over_N.cuda() + hanning = hanning.cuda() + + output = output.view(1, -1) * hanning + output = output.view(1, 1, -1).type(torch.cuda.FloatTensor) + k = k.view(1, -1, 1) + two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1) + complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \ + + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2 + + return complex_absolute + + @staticmethod + def complex_absolute(output, Fs, bpm_range=None): + output = output.view(1, -1) + + N = output.size()[1] + + unit_per_hz = Fs / N + feasible_bpm = bpm_range / 60.0 + k = feasible_bpm / unit_per_hz + + # only calculate feasible PSD range [0.7,4] Hz + complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N) + + return (1.0 / complex_absolute.sum()) * complex_absolute # Analogous Softmax operator + + @staticmethod + def cross_entropy_power_spectrum_loss(inputs, target, Fs): + inputs = inputs.view(1, -1) + target = target.view(1, -1) + bpm_range = torch.arange(40, 180, dtype=torch.float).cuda() + + complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + + return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx) + + @staticmethod + def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma): + inputs = inputs.view(1, -1) + target = target.view(1, -1) + bpm_range = torch.arange(40, 180, dtype=torch.float).cuda() + + complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + + #pdb.set_trace() + criterion = FocalLoss(gamma=gamma) + + return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx) + + + @staticmethod + def cross_entropy_power_spectrum_forward_pred(inputs, Fs): + inputs = inputs.view(1, -1) + bpm_range = torch.arange(40, 190, dtype=torch.float).cuda() + + complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + + return whole_max_idx + + @staticmethod + def cross_entropy_power_spectrum_DLDL_softmax2(inputs, target, Fs, std): + target_distribution = [normal_sampling(int(target), i, std) for i in range(40, 180)] + target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution] + target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda')) + + inputs = inputs.view(1, -1) + target = 
target.view(1, -1) + + bpm_range = torch.arange(40, 180, dtype=torch.float).to(torch.device('cuda')) + + ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + fre_distribution = ca/torch.sum(ca) + loss_distribution_kl = kl_loss(fre_distribution, target_distribution) + + whole_max_val, whole_max_idx = ca.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + return loss_distribution_kl, F.cross_entropy(ca, (target-bpm_range[0]).view(1).type(torch.long)), torch.abs(target[0]-bpm_range[0]-whole_max_idx) + \ No newline at end of file diff --git a/neural_methods/loss/PhysNetNegPearsonLoss.py b/neural_methods/loss/PhysNetNegPearsonLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7bee6d2a7c1378c856063cb087bf07ae963bac --- /dev/null +++ b/neural_methods/loss/PhysNetNegPearsonLoss.py @@ -0,0 +1,43 @@ +from __future__ import print_function, division +import torch +import matplotlib.pyplot as plt +import argparse, os +import pandas as pd +import numpy as np +import random +import math +from torchvision import transforms +from torch import nn + + +class Neg_Pearson(nn.Module): + """ + The Neg_Pearson Module is from the orignal author of Physnet. + Code of 'Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks' + source: https://github.com/ZitongYu/PhysNet/blob/master/NegPearsonLoss.py + """ + + def __init__(self): + super(Neg_Pearson, self).__init__() + return + + + def forward(self, preds, labels): + loss = 0 + for i in range(preds.shape[0]): + sum_x = torch.sum(preds[i]) + sum_y = torch.sum(labels[i]) + sum_xy = torch.sum(preds[i]*labels[i]) + sum_x2 = torch.sum(torch.pow(preds[i],2)) + sum_y2 = torch.sum(torch.pow(labels[i],2)) + N = preds.shape[1] + pearson = (N*sum_xy - sum_x*sum_y)/(torch.sqrt((N*sum_x2 - torch.pow(sum_x,2))*(N*sum_y2 - torch.pow(sum_y,2)))) + loss += 1 - pearson + + + loss = loss/preds.shape[0] + return loss + + + + diff --git a/neural_methods/loss/RythmFormerLossComputer.py b/neural_methods/loss/RythmFormerLossComputer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c441eaa73b7f2e8b16380b337ef8cc33932513b --- /dev/null +++ b/neural_methods/loss/RythmFormerLossComputer.py @@ -0,0 +1,167 @@ +''' + Adapted from here: https://github.com/ZitongYu/PhysFormer/TorchLossComputer.py + Modifed based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn +''' +import math +import torch +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F +import torch.nn as nn +from evaluation.post_process import calculate_metric_per_video + +def normal_sampling(mean, label_k, std): + return math.exp(-(label_k-mean)**2/(2*std**2))/(math.sqrt(2*math.pi)*std) + +def kl_loss(inputs, labels): + criterion = nn.KLDivLoss(reduce=False) + outputs = torch.log(inputs) + loss = criterion(outputs, labels) + #loss = loss.sum()/loss.shape[0] + loss = loss.sum() + return loss + +class Neg_Pearson(nn.Module): # Pearson range [-1, 1] so if < 0, abs|loss| ; if >0, 1- loss + def __init__(self): + super(Neg_Pearson,self).__init__() + + def forward(self, preds, labels): # all variable operation + loss = 0 + for i in range(preds.shape[0]): + sum_x = torch.sum(preds[i]) # x + sum_y = torch.sum(labels[i]) # y + sum_xy = torch.sum(preds[i]*labels[i]) # xy + sum_x2 = torch.sum(torch.pow(preds[i],2)) # x^2 + sum_y2 = torch.sum(torch.pow(labels[i],2)) # y^2 + N = preds.shape[1] + pearson = (N*sum_xy - sum_x*sum_y)/(torch.sqrt((N*sum_x2 - torch.pow(sum_x,2))*(N*sum_y2 - 
torch.pow(sum_y,2)))) + loss += 1 - pearson + + loss = loss/preds.shape[0] + return loss + +class RhythmFormer_Loss(nn.Module): + def __init__(self): + super(RhythmFormer_Loss,self).__init__() + self.criterion_Pearson = Neg_Pearson() + def forward(self, pred_ppg, labels , epoch , FS , diff_flag): + loss_time = self.criterion_Pearson(pred_ppg.view(1,-1) , labels.view(1,-1)) + loss_CE , loss_distribution_kl = TorchLossComputer.Frequency_loss(pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0) + loss_hr = TorchLossComputer.HR_loss(pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0) + if torch.isnan(loss_time) : + loss_time = 0 + + loss = 0.2 * loss_time + 1.0 * loss_CE + 1.0 * loss_hr + return loss + +class TorchLossComputer(object): + @staticmethod + def compute_complex_absolute_given_k(output, k, N): + two_pi_n_over_N = Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N + hanning = Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1) + + k = k.type(torch.FloatTensor).cuda() + two_pi_n_over_N = two_pi_n_over_N.cuda() + hanning = hanning.cuda() + + output = output.view(1, -1) * hanning + output = output.view(1, 1, -1).type(torch.cuda.FloatTensor) + k = k.view(1, -1, 1) + two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1) + complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \ + + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2 + + return complex_absolute + + @staticmethod + def complex_absolute(output, Fs, bpm_range=None): + output = output.view(1, -1) + + N = output.size()[1] + + unit_per_hz = Fs / N + feasible_bpm = bpm_range / 60.0 + k = feasible_bpm / unit_per_hz + + # only calculate feasible PSD range [0.7,4]Hz + complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N) + + return (1.0 / complex_absolute.sum()) * complex_absolute # Analogous Softmax operator + + + @staticmethod + def cross_entropy_power_spectrum_loss(inputs, target, Fs): + inputs = inputs.view(1, -1) + target = target.view(1, -1) + bpm_range = torch.arange(40, 180, dtype=torch.float).cuda() + #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda() + + complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + + #pdb.set_trace() + + #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2 + return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx) + + @staticmethod + def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma): + inputs = inputs.view(1, -1) + target = target.view(1, -1) + bpm_range = torch.arange(40, 180, dtype=torch.float).cuda() + #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda() + + complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + + #pdb.set_trace() + criterion = FocalLoss(gamma=gamma) + + #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2 + return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx) + + + @staticmethod + def 
cross_entropy_power_spectrum_forward_pred(inputs, Fs): + inputs = inputs.view(1, -1) + bpm_range = torch.arange(40, 190, dtype=torch.float).cuda() + #bpm_range = torch.arange(40, 180, dtype=torch.float).cuda() + #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda() + + complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + + whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0) + whole_max_idx = whole_max_idx.type(torch.float) + + return whole_max_idx + + @staticmethod + def Frequency_loss(inputs, target, diff_flag , Fs, std): + hr_gt, pred_hr_peak, SNR, macc = calculate_metric_per_video(inputs.detach().cpu(), target.detach().cpu(), diff_flag = diff_flag, fs=Fs, hr_method='FFT') + inputs = inputs.view(1, -1) + target = target.view(1, -1) + bpm_range = torch.arange(45, 150, dtype=torch.float).to(torch.device('cuda')) + ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range) + sa = ca/torch.sum(ca) + + target_distribution = [normal_sampling(int(hr_gt), i, std) for i in range(45, 150)] + target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution] + target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda')) + + hr_gt = torch.tensor(hr_gt-45).view(1).type(torch.long).to(torch.device('cuda')) + return F.cross_entropy(ca, hr_gt) , kl_loss(sa , target_distribution) + + @staticmethod + def HR_loss(inputs, target, diff_flag , Fs, std): + psd_gt, psd_pred, SNR, macc = calculate_metric_per_video(inputs.detach().cpu(), target.detach().cpu(), diff_flag = diff_flag, fs=Fs, hr_method='Peak') + pred_distribution = [normal_sampling(np.argmax(psd_pred), i, std) for i in range(psd_pred.size)] + pred_distribution = [i if i > 1e-15 else 1e-15 for i in pred_distribution] + pred_distribution = torch.Tensor(pred_distribution).to(torch.device('cuda')) + target_distribution = [normal_sampling(np.argmax(psd_gt), i, std) for i in range(psd_gt.size)] + target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution] + target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda')) + return kl_loss(pred_distribution , target_distribution) diff --git a/neural_methods/loss/__init__.py b/neural_methods/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/neural_methods/model/BigSmall.py b/neural_methods/model/BigSmall.py new file mode 100644 index 0000000000000000000000000000000000000000..6da5a180b036cfcd861513aaec6d82d65737544d --- /dev/null +++ b/neural_methods/model/BigSmall.py @@ -0,0 +1,177 @@ +"""BigSmall: Multitask Network for AU / Respiration / PPG + +BigSmall: Efficient Multi-Task Learning +For Physiological Measurements +Girish Narayanswamy, Yujia (Nancy) Liu, Yuzhe Yang, Chengqian (Jack) Ma, +Xin Liu, Daniel McDuff, Shwetak Patel + +https://arxiv.org/abs/2303.11573 +""" + +import torch +import torch.nn as nn + + +##################################################### +############ Wrapping Time Shift Module ############# +##################################################### +class WTSM(nn.Module): + def __init__(self, n_segment=3, fold_div=3): + super(WTSM, self).__init__() + self.n_segment = n_segment + self.fold_div = fold_div + + def forward(self, x): + nt, c, h, w = x.size() + n_batch = nt // self.n_segment + x = x.view(n_batch, self.n_segment, c, h, w) + fold = c // self.fold_div + out = torch.zeros_like(x) + out[:, :-1, :fold] = x[:, 1:, :fold] # shift left + out[:, -1, :fold] = x[:, 0, :fold] # wrap 
left + out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right + out[:, 0, fold: 2 * fold] = x[:, -1, fold: 2 * fold] # wrap right + out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # no shift for final fold + return out.view(nt, c, h, w) + + + +####################################################################################### +##################################### BigSmall Model ################################## +####################################################################################### +class BigSmall(nn.Module): + + def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, + dropout_rate1=0.25, dropout_rate2=0.5, dropout_rate3=0.5, pool_size1=(2, 2), pool_size2=(4,4), + nb_dense=128, out_size_bvp=1, out_size_resp=1, out_size_au=12, n_segment=3): + + super(BigSmall, self).__init__() + + self.in_channels = in_channels + self.kernel_size = kernel_size + self.dropout_rate1 = dropout_rate1 + self.dropout_rate2 = dropout_rate2 + self.dropout_rate3 = dropout_rate3 + self.pool_size1 = pool_size1 + self.pool_size2 = pool_size2 + self.nb_filters1 = nb_filters1 + self.nb_filters2 = nb_filters2 + self.nb_dense = nb_dense + + self.out_size_bvp = out_size_bvp + self.out_size_resp = out_size_resp + self.out_size_au = out_size_au + + self.n_segment = n_segment + + # Big Convolutional Layers + self.big_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True) + self.big_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True) + self.big_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), bias=True) + self.big_conv4 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True) + self.big_conv5 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True) + self.big_conv6 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), bias=True) + + # Big Avg Pooling / Dropout Layers + self.big_avg_pooling1 = nn.AvgPool2d(self.pool_size1) + self.big_dropout1 = nn.Dropout(self.dropout_rate1) + self.big_avg_pooling2 = nn.AvgPool2d(self.pool_size1) + self.big_dropout2 = nn.Dropout(self.dropout_rate2) + self.big_avg_pooling3 = nn.AvgPool2d(self.pool_size2) + self.big_dropout3 = nn.Dropout(self.dropout_rate3) + + # TSM layers + self.TSM_1 = WTSM(n_segment=self.n_segment) + self.TSM_2 = WTSM(n_segment=self.n_segment) + self.TSM_3 = WTSM(n_segment=self.n_segment) + self.TSM_4 = WTSM(n_segment=self.n_segment) + + # Small Convolutional Layers + self.small_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True) + self.small_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True) + self.small_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, padding=(1,1), bias=True) + self.small_conv4 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1,1), bias=True) + + # AU Fully Connected Layers + self.au_fc1 = nn.Linear(5184, self.nb_dense, bias=True) + self.au_fc2 = nn.Linear(self.nb_dense, self.out_size_au, bias=True) + + # BVP Fully Connected Layers + self.bvp_fc1 = nn.Linear(5184, self.nb_dense, bias=True) + self.bvp_fc2 = nn.Linear(self.nb_dense, self.out_size_bvp, bias=True) + + # Resp Fully Connected Layers + self.resp_fc1 = 
nn.Linear(5184, self.nb_dense, bias=True) + self.resp_fc2 = nn.Linear(self.nb_dense, self.out_size_resp, bias=True) + + + def forward(self, inputs, params=None): + + big_input = inputs[0] # big res + small_input = inputs[1] # small res + + # reshape Big + nt, c, h, w = big_input.size() + n_batch = nt // self.n_segment + big_input = big_input.view(n_batch, self.n_segment, c, h, w) + big_input = torch.moveaxis(big_input, 1, 2) # color channel to idx 1, sequence channel to idx 2 + big_input = big_input[:, :, 0, :, :] # use only first frame in sequences + + + # Big Conv block 1 + b1 = nn.functional.relu(self.big_conv1(big_input)) + b2 = nn.functional.relu(self.big_conv2(b1)) + b3 = self.big_avg_pooling1(b2) + b4 = self.big_dropout1(b3) + + # Big Conv block 2 + b5 = nn.functional.relu(self.big_conv3(b4)) + b6 = nn.functional.relu(self.big_conv4(b5)) + b7 = self.big_avg_pooling2(b6) + b8 = self.big_dropout2(b7) + + # Big Conv block 3 + b9 = nn.functional.relu(self.big_conv5(b8)) + b10 = nn.functional.relu(self.big_conv6(b9)) + b11 = self.big_avg_pooling3(b10) + b12 = self.big_dropout3(b11) + + # Reformat Big Shape For Concat w/ Small Branch + b13 = torch.stack((b12, b12, b12), 2) #TODO: this is hardcoded for num_segs = 3: change this... + b14 = torch.moveaxis(b13, 1, 2) + bN, bD, bC, bH, bW = b14.size() + b15 = b14.reshape(int(bN*bD), bC, bH, bW) + + # Small Conv block 1 + s1 = self.TSM_1(small_input) + s2 = nn.functional.relu(self.small_conv1(s1)) + s3 = self.TSM_2(s2) + s4 = nn.functional.relu(self.small_conv2(s3)) + + # Small Conv block 2 + s5 = self.TSM_3(s4) + s6 = nn.functional.relu(self.small_conv3(s5)) + s7 = self.TSM_4(s6) + s8 = nn.functional.relu(self.small_conv4(s7)) + + # Shared Layers + concat = b15 + s8 # sum layers + + # share1 = concat.view(concat.size(0), -1) # flatten entire tensors + share1 = concat.reshape(concat.size(0), -1) + + # AU Output Layers + aufc1 = nn.functional.relu(self.au_fc1(share1)) + au_out = self.au_fc2(aufc1) + + # BVP Output Layers + bvpfc1 = nn.functional.relu(self.bvp_fc1(share1)) + bvp_out = self.bvp_fc2(bvpfc1) + + # Resp Output Layers + respfc1 = nn.functional.relu(self.resp_fc1(share1)) + resp_out = self.resp_fc2(respfc1) + + return au_out, bvp_out, resp_out + + diff --git a/neural_methods/model/DeepPhys.py b/neural_methods/model/DeepPhys.py new file mode 100644 index 0000000000000000000000000000000000000000..8d53021e38bdd5805d806318e8a8e7911b6ab358 --- /dev/null +++ b/neural_methods/model/DeepPhys.py @@ -0,0 +1,125 @@ +"""DeepPhys - 2D Convolutional Attention Network. +DeepPhys: Video-Based Physiological Measurement Using Convolutional Attention Networks +ECCV, 2018 +Weixuan Chen, Daniel McDuff +""" + +import torch +import torch.nn as nn + + +class Attention_mask(nn.Module): + def __init__(self): + super(Attention_mask, self).__init__() + + def forward(self, x): + xsum = torch.sum(x, dim=2, keepdim=True) + xsum = torch.sum(xsum, dim=3, keepdim=True) + xshape = tuple(x.size()) + return x / xsum * xshape[2] * xshape[3] * 0.5 + + def get_config(self): + """May be generated manually. """ + config = super(Attention_mask, self).get_config() + return config + + +class DeepPhys(nn.Module): + + def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25, + dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, img_size=36): + """Definition of DeepPhys. + Args: + in_channels: the number of input channel. Default: 3 + img_size: height/width of each frame. Default: 36. + Returns: + DeepPhys model. 
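+        Example (illustrative, assuming 72x72 input frames): +            model = DeepPhys(in_channels=3, img_size=72) +            # each input stacks 3 normalized-difference channels and 3 raw channels +            out = model(torch.randn(4, 6, 72, 72))  # -> shape (4, 1)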
+ """ + super(DeepPhys, self).__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.dropout_rate1 = dropout_rate1 + self.dropout_rate2 = dropout_rate2 + self.pool_size = pool_size + self.nb_filters1 = nb_filters1 + self.nb_filters2 = nb_filters2 + self.nb_dense = nb_dense + # Motion branch convs + self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Apperance branch convs + self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, + padding=(1, 1), bias=True) + self.apperance_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, + padding=(1, 1), bias=True) + self.apperance_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Attention layers + self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_1 = Attention_mask() + self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_2 = Attention_mask() + # Avg pooling + self.avg_pooling_1 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_2 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_3 = nn.AvgPool2d(self.pool_size) + # Dropout layers + self.dropout_1 = nn.Dropout(self.dropout_rate1) + self.dropout_2 = nn.Dropout(self.dropout_rate1) + self.dropout_3 = nn.Dropout(self.dropout_rate1) + self.dropout_4 = nn.Dropout(self.dropout_rate2) + # Dense layers + if img_size == 36: + self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True) + elif img_size == 72: + self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True) + elif img_size == 96: + self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True) + else: + raise Exception('Unsupported image size') + self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True) + + def forward(self, inputs, params=None): + + diff_input = inputs[:, :3, :, :] + raw_input = inputs[:, 3:, :, :] + + d1 = torch.tanh(self.motion_conv1(diff_input)) + d2 = torch.tanh(self.motion_conv2(d1)) + + r1 = torch.tanh(self.apperance_conv1(raw_input)) + r2 = torch.tanh(self.apperance_conv2(r1)) + + g1 = torch.sigmoid(self.apperance_att_conv1(r2)) + g1 = self.attn_mask_1(g1) + gated1 = d2 * g1 + + d3 = self.avg_pooling_1(gated1) + d4 = self.dropout_1(d3) + + r3 = self.avg_pooling_2(r2) + r4 = self.dropout_2(r3) + + d5 = torch.tanh(self.motion_conv3(d4)) + d6 = torch.tanh(self.motion_conv4(d5)) + + r5 = torch.tanh(self.apperance_conv3(r4)) + r6 = torch.tanh(self.apperance_conv4(r5)) + + g2 = torch.sigmoid(self.apperance_att_conv2(r6)) + g2 = self.attn_mask_2(g2) + gated2 = d6 * g2 + + d7 = self.avg_pooling_3(gated2) + d8 = self.dropout_3(d7) + d9 = d8.view(d8.size(0), -1) + d10 = torch.tanh(self.final_dense_1(d9)) + d11 = self.dropout_4(d10) + out = self.final_dense_2(d11) + + return out + diff --git a/neural_methods/model/EfficientPhys.py b/neural_methods/model/EfficientPhys.py new file mode 100644 index 
0000000000000000000000000000000000000000..781acbf599b8babfd2f9aac1f3f71095f057d852 --- /dev/null +++ b/neural_methods/model/EfficientPhys.py @@ -0,0 +1,128 @@ +"""EfficientPhys: Enabling Simple, Fast and Accurate Camera-Based Vitals Measurement +Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2023) +Xin Liu, Brial Hill, Ziheng Jiang, Shwetak Patel, Daniel McDuff +""" + +import torch +import torch.nn as nn + + +class Attention_mask(nn.Module): + def __init__(self): + super(Attention_mask, self).__init__() + + def forward(self, x): + xsum = torch.sum(x, dim=2, keepdim=True) + xsum = torch.sum(xsum, dim=3, keepdim=True) + xshape = tuple(x.size()) + return x / xsum * xshape[2] * xshape[3] * 0.5 + + def get_config(self): + """May be generated manually. """ + config = super(Attention_mask, self).get_config() + return config + + +class TSM(nn.Module): + def __init__(self, n_segment=10, fold_div=3): + super(TSM, self).__init__() + self.n_segment = n_segment + self.fold_div = fold_div + + def forward(self, x): + nt, c, h, w = x.size() + n_batch = nt // self.n_segment + x = x.view(n_batch, self.n_segment, c, h, w) + fold = c // self.fold_div + out = torch.zeros_like(x) + out[:, :-1, :fold] = x[:, 1:, :fold] # shift left + out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right + out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift + return out.view(nt, c, h, w) + + +class EfficientPhys(nn.Module): + + def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25, + dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36, channel='raw'): + super(EfficientPhys, self).__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.dropout_rate1 = dropout_rate1 + self.dropout_rate2 = dropout_rate2 + self.pool_size = pool_size + self.nb_filters1 = nb_filters1 + self.nb_filters2 = nb_filters2 + self.nb_dense = nb_dense + # TSM layers + self.TSM_1 = TSM(n_segment=frame_depth) + self.TSM_2 = TSM(n_segment=frame_depth) + self.TSM_3 = TSM(n_segment=frame_depth) + self.TSM_4 = TSM(n_segment=frame_depth) + # Motion branch convs + self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Attention layers + self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_1 = Attention_mask() + self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_2 = Attention_mask() + # Avg pooling + self.avg_pooling_1 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_2 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_3 = nn.AvgPool2d(self.pool_size) + # Dropout layers + self.dropout_1 = nn.Dropout(self.dropout_rate1) + self.dropout_2 = nn.Dropout(self.dropout_rate1) + self.dropout_3 = nn.Dropout(self.dropout_rate1) + self.dropout_4 = nn.Dropout(self.dropout_rate2) + # Dense layers + if img_size == 36: + self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True) + elif img_size == 72: + self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True) + elif img_size == 96: 
+ self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True) + else: + raise Exception('Unsupported image size') + self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True) + self.batch_norm = nn.BatchNorm2d(3) + self.channel = channel + + def forward(self, inputs, params=None): + inputs = torch.diff(inputs, dim=0) + inputs = self.batch_norm(inputs) + + network_input = self.TSM_1(inputs) + d1 = torch.tanh(self.motion_conv1(network_input)) + d1 = self.TSM_2(d1) + d2 = torch.tanh(self.motion_conv2(d1)) + + g1 = torch.sigmoid(self.apperance_att_conv1(d2)) + g1 = self.attn_mask_1(g1) + gated1 = d2 * g1 + + d3 = self.avg_pooling_1(gated1) + d4 = self.dropout_1(d3) + + d4 = self.TSM_3(d4) + d5 = torch.tanh(self.motion_conv3(d4)) + d5 = self.TSM_4(d5) + d6 = torch.tanh(self.motion_conv4(d5)) + + g2 = torch.sigmoid(self.apperance_att_conv2(d6)) + g2 = self.attn_mask_2(g2) + gated2 = d6 * g2 + + d7 = self.avg_pooling_3(gated2) + d8 = self.dropout_3(d7) + d9 = d8.view(d8.size(0), -1) + d10 = torch.tanh(self.final_dense_1(d9)) + d11 = self.dropout_4(d10) + out = self.final_dense_2(d11) + + return out diff --git a/neural_methods/model/FactorizePhys/FSAM.py b/neural_methods/model/FactorizePhys/FSAM.py new file mode 100644 index 0000000000000000000000000000000000000000..71a5a53947fbfc0bd7901ff382f1458a6122501a --- /dev/null +++ b/neural_methods/model/FactorizePhys/FSAM.py @@ -0,0 +1,530 @@ +""" +FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. Agaian, and Youngjun Cho +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.batchnorm import _BatchNorm +import numpy as np +import neurokit2 as nk + + +class _MatrixDecompositionBase(nn.Module): + def __init__(self, device, md_config, debug=False, dim="3D"): + super().__init__() + + self.dim = dim + self.md_type = md_config["MD_TYPE"] + if dim == "3D": + self.transform = md_config["MD_TRANSFORM"] + self.S = md_config["MD_S"] + self.R = md_config["MD_R"] + self.debug = debug + + self.train_steps = md_config["MD_STEPS"] + self.eval_steps = md_config["MD_STEPS"] + + self.inv_t = md_config["INV_T"] + self.eta = md_config["ETA"] + + self.rand_init = md_config["RAND_INIT"] + self.device = device + + # print('Dimension:', self.dim) + # print('S', self.S) + # print('D', self.D) + # print('R', self.R) + # print('train_steps', self.train_steps) + # print('eval_steps', self.eval_steps) + # print('inv_t', self.inv_t) + # print('eta', self.eta) + # print('rand_init', self.rand_init) + + def _build_bases(self, B, S, D, R): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + @torch.no_grad() + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + + if self.debug: + print("Org x.shape", x.shape) + + if self.dim == "3D": # (B, C, T, H, W) -> (B * S, D, N) + B, C, T, H, W = x.shape + + # t = Time, k = Channels, a & B = height and width + if self.transform.lower() == "t_kab": + # # dimension of vector of our interest is T (rPPG signal as T dimension), so forming this as vector + # # From 
spatial and channel dimension, which are features, only 2-4 shall be enough to generate the approximated attention matrix + D = T // self.S + N = C * H * W + + elif self.transform.lower() == "tk_ab": + D = T * C // self.S + N = H * W + + elif self.transform.lower() == "k_tab": + D = C // self.S + N = T * H * W + + else: + print("Invalid MD_TRANSFORM specified:", self.transform) + exit() + + # # smoothening the temporal dimension + # x = x.view(B * self.S, N, D) + # # print("Intermediate-1 x", x.shape) + + # sample_1 = x[:, :, 0].unsqueeze(2) + # sample_2 = x[:, :, -1].unsqueeze(2) + # x = torch.cat([sample_1, x, sample_2], dim=2) + # gaussian_kernel = [1.0, 1.0, 1.0] + # kernels = torch.FloatTensor([[gaussian_kernel]]).repeat(N, N, 1).to(self.device) + # bias = torch.FloatTensor(torch.zeros(N)).to(self.device) + # x = F.conv1d(x, kernels, bias=bias, padding="valid") + # x = (x - x.min()) / (x.max() - x.min()) + + # x = x.permute(0, 2, 1) + # # print("Intermediate-2 x", x.shape) + + x = x.view(B * self.S, D, N) + + elif self.dim == "2D": # (B, C, H, W) -> (B * S, D, N) + B, C, H, W = x.shape + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + + elif self.dim == "2D_TSM": # (B*frame_depth, C, H, W) -> (B, D, N) + B, C, H, W = x.shape + BN = B + B = B // self.S + D = self.S + N = C * H * W + x = x.view(B, D, N) + self.S = 1 # re-setting this for local inference + + elif self.dim == "1D": # (B, C, L) -> (B * S, D, N) + B, C, L = x.shape + D = L // self.S + N = C + x = x.view(B * self.S, D, N) + + else: + print("Dimension not supported") + exit() + + if self.debug: + print("MD_Type", self.md_type) + print("MD_S", self.S) + print("MD_D", D) + print("MD_N", N) + print("MD_R", self.R) + print("MD_TRAIN_STEPS", self.train_steps) + print("MD_EVAL_STEPS", self.eval_steps) + print("x.view(B * self.S, D, N)", x.shape) + + if not self.rand_init and not hasattr(self, 'bases'): + bases = self._build_bases(1, self.S, D, self.R) + self.register_buffer('bases', bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R) + else: + bases = self.bases.repeat(B, 1, 1).to(self.device) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + + if self.dim == "3D": + + apply_smoothening = False + if apply_smoothening: + # smoothening the temporal dimension + x = x.view(B, D * self.S, N) #Joining temporal dimension for contiguous smoothening + # print("Intermediate-0 x", x.shape) + x = x.permute(0, 2, 1) + # print("Intermediate-1 x", x.shape) + + sample_1 = x[:, :, 0].unsqueeze(2) + # sample_2 = x[:, :, 0].unsqueeze(2) + sample_3 = x[:, :, -1].unsqueeze(2) + # sample_4 = x[:, :, -1].unsqueeze(2) + x = torch.cat([sample_1, x, sample_3], dim=2) + # x = torch.cat([sample_1, sample_2, x, sample_3, sample_4], dim=2) + # gaussian_kernel = [0.25, 0.50, 0.75, 0.50, 0.25] + # gaussian_kernel = [0.33, 0.66, 1.00, 0.66, 0.33] + # gaussian_kernel = [0.3, 0.7, 1.0, 0.7, 0.3] + # gaussian_kernel = [0.3, 1.0, 1.0, 1.0, 0.3] + # gaussian_kernel = [0.20, 0.80, 1.00, 0.80, 0.20] + # gaussian_kernel = [1.0, 1.0, 1.0] + gaussian_kernel = [0.8, 1.0, 0.8] + kernels = torch.FloatTensor([[gaussian_kernel]]).repeat(N, N, 1).to(self.device) + bias = torch.FloatTensor(torch.zeros(N)).to(self.device) + x = F.conv1d(x, kernels, bias=bias, padding="valid") + # x = (x - x.min()) / (x.max() - x.min()) + # x = (x - x.mean()) / 
(x.std()) + # x = x - x.min() + x = (x - x.min())/(x.std()) + + # print("Intermediate-2 x", x.shape) + + # (B * S, D, N) -> (B, C, T, H, W) + x = x.view(B, C, T, H, W) + elif self.dim == "2D": + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + + elif self.dim == "2D_TSM": + # (B, D, N) -> (B, C, H, W) + x = x.view(BN, C, H, W) + + else: + # (B * S, D, N) -> (B, C, L) + x = x.view(B, C, L) + + # (B * L, D, R) -> (B, L, N, D) + bases = bases.view(B, self.S, D, self.R) + + if not self.rand_init and not self.training and not return_bases: + self.online_update(bases) + + # if not self.rand_init or return_bases: + # return x, bases + # else: + return x + + @torch.no_grad() + def online_update(self, bases): + # (B, S, D, R) -> (S, D, R) + update = bases.mean(dim=0) + self.bases += self.eta * (update - self.bases) + self.bases = F.normalize(self.bases, dim=1) + + +class NMF(_MatrixDecompositionBase): + def __init__(self, device, md_config, debug=False, dim="3D"): + super().__init__(device, md_config, debug=debug, dim=dim) + self.device = device + self.inv_t = 1 + + def _build_bases(self, B, S, D, R): + # bases = torch.rand((B * S, D, R)).to(self.device) + bases = torch.ones((B * S, D, R)).to(self.device) + bases = F.normalize(bases, dim=1) + + return bases + + @torch.no_grad() + def local_step(self, x, bases, coef): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class VQ(_MatrixDecompositionBase): + def __init__(self, device, md_config, debug=False, dim="3D"): + super().__init__(device, md_config, debug=debug, dim=dim) + self.device = device + + def _build_bases(self, B, S, D, R): + # bases = torch.randn((B * S, D, R)).to(self.device) + bases = torch.ones((B * S, D, R)).to(self.device) + bases = F.normalize(bases, dim=1) + return bases + + @torch.no_grad() + def local_step(self, x, bases, _): + # (B * S, D, N), normalize x along D (for cosine similarity) + std_x = F.normalize(x, dim=1) + + # (B * S, D, R), normalize bases along D (for cosine similarity) + std_bases = F.normalize(bases, dim=1, eps=1e-6) + + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(std_x.transpose(1, 2), std_bases) + + # softmax along R + coef = F.softmax(self.inv_t * coef, dim=-1) + + # normalize along N + coef = coef / (1e-6 + coef.sum(dim=1, keepdim=True)) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + bases = torch.bmm(x, coef) + + return bases, coef + + + def compute_coef(self, x, bases, _): + with torch.no_grad(): + # (B * S, D, N) -> (B * S, 1, N) + x_norm = x.norm(dim=1, keepdim=True) + + # (B * S, D, N) / (B * S, 1, N) -> (B * S, D, N) + std_x = x / 
(1e-6 + x_norm) + + # (B * S, D, R), normalize bases along D (for cosine similarity) + std_bases = F.normalize(bases, dim=1, eps=1e-6) + + # (B * S, N, D)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(std_x.transpose(1, 2), std_bases) + + # softmax along R + coef = F.softmax(self.inv_t * coef, dim=-1) + + return coef + + +class ConvBNReLU(nn.Module): + @classmethod + def _same_paddings(cls, kernel_size, dim): + if dim == "3D": + if kernel_size == (1, 1, 1): + return (0, 0, 0) + elif kernel_size == (3, 3, 3): + return (1, 1, 1) + elif dim == "2D" or dim == "2D_TSM": + if kernel_size == (1, 1): + return (0, 0) + elif kernel_size == (3, 3): + return (1, 1) + else: + if kernel_size == 1: + return 0 + elif kernel_size == 3: + return 1 + + def __init__(self, in_c, out_c, dim, + kernel_size=1, stride=1, padding='same', + dilation=1, groups=1, act='relu', apply_bn=False, apply_act=True): + super().__init__() + + self.apply_bn = apply_bn + self.apply_act = apply_act + self.dim = dim + if dilation == 1: + if self.dim == "3D": + dilation = (1, 1, 1) + elif self.dim == "2D" or dim == "2D_TSM": + dilation = (1, 1) + else: + dilation = 1 + + if kernel_size == 1: + if self.dim == "3D": + kernel_size = (1, 1, 1) + elif self.dim == "2D" or dim == "2D_TSM": + kernel_size = (1, 1) + else: + kernel_size = 1 + + if stride == 1: + if self.dim == "3D": + stride = (1, 1, 1) + elif self.dim == "2D" or dim == "2D_TSM": + stride = (1, 1) + else: + stride = 1 + + if padding == 'same': + padding = self._same_paddings(kernel_size, dim) + + if self.dim == "3D": + self.conv = nn.Conv3d(in_c, out_c, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + groups=groups, + bias=False) + elif self.dim == "2D" or dim == "2D_TSM": + self.conv = nn.Conv2d(in_c, out_c, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + groups=groups, + bias=False) + else: + self.conv = nn.Conv1d(in_c, out_c, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + groups=groups, + bias=False) + + if act == "sigmoid": + self.act = nn.Sigmoid() + else: + self.act = nn.ReLU(inplace=True) + + if self.apply_bn: + if self.dim == "3D": + self.bn = nn.InstanceNorm3d(out_c) + elif self.dim == "2D" or dim == "2D_TSM": + self.bn = nn.InstanceNorm2d(out_c) + else: + self.bn = nn.InstanceNorm1d(out_c) + + def forward(self, x): + x = self.conv(x) + if self.apply_act: + x = self.act(x) + if self.apply_bn: + x = self.bn(x) + return x + + +class FeaturesFactorizationModule(nn.Module): + def __init__(self, inC, device, md_config, dim="3D", debug=False): + super().__init__() + + self.device = device + self.dim = dim + md_type = md_config["MD_TYPE"] + align_C = md_config["align_channels"] # inC // 2 # // 2 #// 8 + + if self.dim == "3D": + if "nmf" in md_type.lower(): + self.pre_conv_block = nn.Sequential( + nn.Conv3d(inC, align_C, (1, 1, 1)), + nn.ReLU(inplace=True)) + else: + self.pre_conv_block = nn.Conv3d(inC, align_C, (1, 1, 1)) + elif self.dim == "2D" or self.dim == "2D_TSM": + if "nmf" in md_type.lower(): + self.pre_conv_block = nn.Sequential( + nn.Conv2d(inC, align_C, (1, 1)), + nn.ReLU(inplace=True) + ) + else: + self.pre_conv_block = nn.Conv2d(inC, align_C, (1, 1)) + elif self.dim == "1D": + if "nmf" in md_type.lower(): + self.pre_conv_block = nn.Sequential( + nn.Conv1d(inC, align_C, 1), + nn.ReLU(inplace=True) + ) + else: + self.pre_conv_block = nn.Conv1d(inC, align_C, 1) + else: + print("Dimension not supported") + + if "nmf" in md_type.lower(): + self.md_block = 
NMF(self.device, md_config, dim=self.dim, debug=debug) + elif "vq" in md_type.lower(): + self.md_block = VQ(self.device, md_config, dim=self.dim, debug=debug) + else: + print("Unknown type specified for MD_TYPE:", md_type) + exit() + + if self.dim == "3D": + if "nmf" in md_type.lower(): + self.post_conv_block = nn.Sequential( + ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1), + nn.Conv3d(align_C, inC, 1, bias=False) + ) + else: + self.post_conv_block = nn.Sequential( + ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False), + nn.Conv3d(align_C, inC, 1, bias=False) + ) + elif self.dim == "2D" or self.dim == "2D_TSM": + if "nmf" in md_type.lower(): + self.post_conv_block = nn.Sequential( + ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1), + nn.Conv2d(align_C, inC, 1, bias=False) + ) + else: + self.post_conv_block = nn.Sequential( + ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False), + nn.Conv2d(align_C, inC, 1, bias=False) + ) + else: + if "nmf" in md_type.lower(): + self.post_conv_block = nn.Sequential( + ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1), + nn.Conv1d(align_C, inC, 1, bias=False) + ) + else: + self.post_conv_block = nn.Sequential( + ConvBNReLU(align_C, align_C, dim=self.dim, kernel_size=1, apply_act=False), + nn.Conv1d(align_C, inC, 1, bias=False) + ) + + self._init_weight() + + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d): + N = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels + m.weight.data.normal_(0, np.sqrt(2. / N)) + elif isinstance(m, nn.Conv2d): + N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, np.sqrt(2. / N)) + elif isinstance(m, nn.Conv1d): + N = m.kernel_size[0] * m.out_channels + m.weight.data.normal_(0, np.sqrt(2. / N)) + elif isinstance(m, _BatchNorm): + m.weight.data.fill_(1) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.pre_conv_block(x) + att = self.md_block(x) + dist = torch.dist(x, att) + att = self.post_conv_block(att) + + return att, dist + + def online_update(self, bases): + if hasattr(self.md_block, 'online_update'): + self.md_block.online_update(bases) + diff --git a/neural_methods/model/FactorizePhys/FactorizePhys.py b/neural_methods/model/FactorizePhys/FactorizePhys.py new file mode 100644 index 0000000000000000000000000000000000000000..16b211c665fe61c2a6f82f9e81e1adb692e7050c --- /dev/null +++ b/neural_methods/model/FactorizePhys/FactorizePhys.py @@ -0,0 +1,251 @@ +""" +FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. 
Agaian, and Youngjun Cho +""" + +import torch +import torch.nn as nn +from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule + +nf = [8, 12, 16] + +model_config = { + "MD_FSAM": True, + "MD_TYPE": "NMF", + "MD_TRANSFORM": "T_KAB", + "MD_R": 1, + "MD_S": 1, + "MD_STEPS": 4, + "MD_INFERENCE": False, + "MD_RESIDUAL": False, + "INV_T": 1, + "ETA": 0.9, + "RAND_INIT": True, + "in_channels": 3, + "data_channels": 4, + "align_channels": nf[2] // 2, + "height": 72, + "weight": 72, + "batch_size": 4, + "frames": 160, + "debug": False, + "assess_latency": False, + "num_trials": 20, + "visualize": False, + "ckpt_path": "", + "data_path": "", + "label_path": "" +} + + +class ConvBlock3D(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride, padding): + super(ConvBlock3D, self).__init__() + self.conv_block_3d = nn.Sequential( + nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding=padding, bias=False), + nn.Tanh(), + nn.InstanceNorm3d(out_channel), + ) + + def forward(self, x): + return self.conv_block_3d(x) + + +class rPPG_FeatureExtractor(nn.Module): + def __init__(self, inCh, dropout_rate=0.1, debug=False): + super(rPPG_FeatureExtractor, self).__init__() + # inCh, out_channel, kernel_size, stride, padding + + self.debug = debug + # Input: #B, inCh, 160, 72, 72 + self.FeatureExtractor = nn.Sequential( + ConvBlock3D(inCh, nf[0], [3, 3, 3], [1, 1, 1], [1, 1, 1]), #B, nf[0], 160, 72, 72 + ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[1], 160, 35, 35 + ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 160, 33, 33 + nn.Dropout3d(p=dropout_rate), + + ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 160, 31, 31 + ConvBlock3D(nf[1], nf[2], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[2], 160, 15, 15 + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 13, 13 + nn.Dropout3d(p=dropout_rate), + ) + + def forward(self, x): + voxel_embeddings = self.FeatureExtractor(x) + if self.debug: + print("rPPG Feature Extractor") + print(" voxel_embeddings.shape", voxel_embeddings.shape) + return voxel_embeddings + + +class BVP_Head(nn.Module): + def __init__(self, md_config, device, dropout_rate=0.1, debug=False): + super(BVP_Head, self).__init__() + self.debug = debug + + self.use_fsam = md_config["MD_FSAM"] + self.md_type = md_config["MD_TYPE"] + self.md_infer = md_config["MD_INFERENCE"] + self.md_res = md_config["MD_RESIDUAL"] + + self.conv_block = nn.Sequential( + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 11, 11 + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 9, 9 + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 160, 7, 7 + nn.Dropout3d(p=dropout_rate), + ) + + if self.use_fsam: + inC = nf[2] + self.fsam = FeaturesFactorizationModule(inC, device, md_config, dim="3D", debug=debug) + self.fsam_norm = nn.InstanceNorm3d(inC) + self.bias1 = nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device) + else: + inC = nf[2] + + self.final_layer = nn.Sequential( + ConvBlock3D(inC, nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 160, 5, 5 + ConvBlock3D(nf[1], nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[0], 160, 3, 3 + nn.Conv3d(nf[0], 1, (3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), #B, 1, 160, 1, 1 + ) + + + def forward(self, voxel_embeddings, batch, length): + + if self.debug: + print("BVP Head") + print(" voxel_embeddings.shape", voxel_embeddings.shape) + + 
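+ # The extra 3x3x3 conv blocks below shrink the spatial footprint of the voxel embeddings (13x13 -> 7x7, per the shape annotations in __init__) before the optional FSAM attention branch.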
voxel_embeddings = self.conv_block(voxel_embeddings) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + if "NMF" in self.md_type: + att_mask, appx_error = self.fsam(voxel_embeddings - voxel_embeddings.min()) # to make it positive (>= 0) + else: + att_mask, appx_error = self.fsam(voxel_embeddings) + + if self.debug: + print("att_mask.shape", att_mask.shape) + + # # directly use att_mask ---> difficult to converge without Residual connection. Needs high rank + # factorized_embeddings = self.fsam_norm(att_mask) + + # # Residual connection: + # factorized_embeddings = voxel_embeddings + self.fsam_norm(att_mask) + + if self.md_res: + # Multiplication with Residual connection + x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1) + factorized_embeddings = self.fsam_norm(x) + factorized_embeddings = voxel_embeddings + factorized_embeddings + else: + # Multiplication + x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1) + factorized_embeddings = self.fsam_norm(x) + + # # Concatenate + # factorized_embeddings = torch.cat([voxel_embeddings, self.fsam_norm(x)], dim=1) + + x = self.final_layer(factorized_embeddings) + + else: + x = self.final_layer(voxel_embeddings) + + rPPG = x.view(-1, length) + + if self.debug: + print(" rPPG.shape", rPPG.shape) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + return rPPG, factorized_embeddings, appx_error + else: + return rPPG + + + +class FactorizePhys(nn.Module): + def __init__(self, frames, md_config, in_channels=3, dropout=0.1, device=torch.device("cpu"), debug=False): + super(FactorizePhys, self).__init__() + self.debug = debug + + self.in_channels = in_channels + if self.in_channels == 1 or self.in_channels == 3: + self.norm = nn.InstanceNorm3d(self.in_channels) + elif self.in_channels == 4: + self.rgb_norm = nn.InstanceNorm3d(3) + self.thermal_norm = nn.InstanceNorm3d(1) + else: + print("Unsupported input channels") + + self.use_fsam = md_config["MD_FSAM"] + self.md_infer = md_config["MD_INFERENCE"] + + for key in model_config: + if key not in md_config: + md_config[key] = model_config[key] + + if self.debug: + print("nf:", nf) + + self.rppg_feature_extractor = rPPG_FeatureExtractor(self.in_channels, dropout_rate=dropout, debug=debug) + + self.rppg_head = BVP_Head(md_config, device=device, dropout_rate=dropout, debug=debug) + + + def forward(self, x): # [batch, Features=3, Temp=frames, Width=32, Height=32] + + [batch, channel, length, width, height] = x.shape + + # if self.in_channels == 1: + # x = x[:, :, :-1, :, :] + # else: + # x = torch.diff(x, dim=2) + + x = torch.diff(x, dim=2) + + if self.debug: + print("Input.shape", x.shape) + + if self.in_channels == 1: + x = self.norm(x[:, -1:, :, :, :]) + elif self.in_channels == 3: + x = self.norm(x[:, :3, :, :, :]) + elif self.in_channels == 4: + rgb_x = self.rgb_norm(x[:, :3, :, :, :]) + thermal_x = self.thermal_norm(x[:, -1:, :, :, :]) + x = torch.concat([rgb_x, thermal_x], dim = 1) + else: + try: + print("Specified input channels:", self.in_channels) + print("Data channels", channel) + assert self.in_channels <= channel + except: + print("Incorrectly preprocessed data provided as input. 
Number of channels exceed the specified or default channels") + print("Default or specified channels:", self.in_channels) + print("Data channels [B, C, N, W, H]", x.shape) + print("Exiting") + exit() + + if self.debug: + print("Diff Normalized shape", x.shape) + + voxel_embeddings = self.rppg_feature_extractor(x) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + rPPG, factorized_embeddings, appx_error = self.rppg_head(voxel_embeddings, batch, length-1) + else: + rPPG = self.rppg_head(voxel_embeddings, batch, length-1) + + # if self.debug: + # print("rppg_feats.shape", rppg_feats.shape) + + # rPPG = rppg_feats.view(-1, length-1) + + if self.debug: + print("rPPG.shape", rPPG.shape) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + return rPPG, voxel_embeddings, factorized_embeddings, appx_error + else: + return rPPG, voxel_embeddings \ No newline at end of file diff --git a/neural_methods/model/FactorizePhys/FactorizePhysBig.py b/neural_methods/model/FactorizePhys/FactorizePhysBig.py new file mode 100644 index 0000000000000000000000000000000000000000..10d616ce8b85a76e9f1e37a9934d32f961b7de4c --- /dev/null +++ b/neural_methods/model/FactorizePhys/FactorizePhysBig.py @@ -0,0 +1,251 @@ +""" +FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. Agaian, and Youngjun Cho +""" + +import torch +import torch.nn as nn +from neural_methods.model.FactorizePhys.FSAM import FeaturesFactorizationModule + +nf = [8, 12, 16] + +model_config = { + "MD_FSAM": True, + "MD_TYPE": "NMF", + "MD_TRANSFORM": "T_KAB", + "MD_R": 1, + "MD_S": 1, + "MD_STEPS": 4, + "MD_INFERENCE": False, + "MD_RESIDUAL": False, + "INV_T": 1, + "ETA": 0.9, + "RAND_INIT": True, + "in_channels": 3, + "data_channels": 4, + "align_channels": nf[2] // 2, + "height": 128, + "weight": 128, + "batch_size": 4, + "frames": 240, + "debug": False, + "assess_latency": False, + "num_trials": 20, + "visualize": False, + "ckpt_path": "", + "data_path": "", + "label_path": "" +} + + +class ConvBlock3D(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride, padding): + super(ConvBlock3D, self).__init__() + self.conv_block_3d = nn.Sequential( + nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding=padding, bias=False), + nn.Tanh(), + nn.InstanceNorm3d(out_channel), + ) + + def forward(self, x): + return self.conv_block_3d(x) + + +class rPPG_FeatureExtractor(nn.Module): + def __init__(self, inCh, dropout_rate=0.1, debug=False): + super(rPPG_FeatureExtractor, self).__init__() + # inCh, out_channel, kernel_size, stride, padding + + self.debug = debug + # Input: #B, inCh, 240, 128, 128 + self.FeatureExtractor = nn.Sequential( + ConvBlock3D(inCh, nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[0], 240, 126, 126 + ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[1], 240, 62, 62 + ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[1], 240, 60, 60 + nn.Dropout3d(p=dropout_rate), + + ConvBlock3D(nf[1], nf[1], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[1], 240, 29, 29 + ConvBlock3D(nf[1], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 27, 27 + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 25, 25 + nn.Dropout3d(p=dropout_rate), + ) + + def forward(self, x): + voxel_embeddings = self.FeatureExtractor(x) + if self.debug: + print("rPPG Feature Extractor") + print(" voxel_embeddings.shape", voxel_embeddings.shape) + return voxel_embeddings 
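+ # Minimal usage sketch (assumes the default model_config defined above; shapes follow the annotations in this file):
+ #   model = FactorizePhysBig(frames=240, md_config=dict(model_config))
+ #   model.eval()
+ #   with torch.no_grad():
+ #       rPPG, voxel_embeddings = model(torch.rand(1, 3, 241, 128, 128))  # frames + 1 images, since forward() takes a temporal diff
+ #   # rPPG: [1, 240]; FSAM outputs are additionally returned only in training/debug mode or when MD_INFERENCE is True.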
+ + +class BVP_Head(nn.Module): + def __init__(self, md_config, device, dropout_rate=0.1, debug=False): + super(BVP_Head, self).__init__() + self.debug = debug + + self.use_fsam = md_config["MD_FSAM"] + self.md_type = md_config["MD_TYPE"] + self.md_infer = md_config["MD_INFERENCE"] + self.md_res = md_config["MD_RESIDUAL"] + + self.conv_block = nn.Sequential( + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 2, 2], [1, 0, 0]), #B, nf[2], 240, 12, 12 + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 10, 10 + ConvBlock3D(nf[2], nf[2], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[2], 240, 8, 8 + nn.Dropout3d(p=dropout_rate), + ) + + if self.use_fsam: + inC = nf[2] + self.fsam = FeaturesFactorizationModule(inC, device, md_config, dim="3D", debug=debug) + self.fsam_norm = nn.InstanceNorm3d(inC) + self.bias1 = nn.Parameter(torch.tensor(1.0), requires_grad=True).to(device) + else: + inC = nf[2] + + self.final_layer = nn.Sequential( + ConvBlock3D(inC, nf[1], [3, 4, 4], [1, 1, 1], [1, 0, 0]), #B, nf[1], 240, 5, 5 + ConvBlock3D(nf[1], nf[0], [3, 3, 3], [1, 1, 1], [1, 0, 0]), #B, nf[0], 240, 3, 3 + nn.Conv3d(nf[0], 1, (3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), #B, 1, 240, 1, 1 + ) + + + def forward(self, voxel_embeddings, batch, length): + + if self.debug: + print("BVP Head") + print(" voxel_embeddings.shape", voxel_embeddings.shape) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + if "NMF" in self.md_type: + att_mask, appx_error = self.fsam(voxel_embeddings - voxel_embeddings.min()) # to make it positive (>= 0) + else: + att_mask, appx_error = self.fsam(voxel_embeddings) + + if self.debug: + print("att_mask.shape", att_mask.shape) + + # # directly use att_mask ---> difficult to converge without Residual connection. 
Needs high rank + # factorized_embeddings = self.fsam_norm(att_mask) + + # # Residual connection: + # factorized_embeddings = voxel_embeddings + self.fsam_norm(att_mask) + + if self.md_res: + # Multiplication with Residual connection + x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1) + factorized_embeddings = self.fsam_norm(x) + factorized_embeddings = voxel_embeddings + factorized_embeddings + else: + # Multiplication + x = torch.mul(voxel_embeddings - voxel_embeddings.min() + self.bias1, att_mask - att_mask.min() + self.bias1) + factorized_embeddings = self.fsam_norm(x) + + # # Concatenate + # factorized_embeddings = torch.cat([voxel_embeddings, self.fsam_norm(x)], dim=1) + + x = self.conv_block(factorized_embeddings) + x = self.final_layer(x) + + else: + x = self.conv_block(voxel_embeddings) + x = self.final_layer(x) + + rPPG = x.view(-1, length) + + if self.debug: + print(" rPPG.shape", rPPG.shape) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + return rPPG, factorized_embeddings, appx_error + else: + return rPPG + + + +class FactorizePhysBig(nn.Module): + def __init__(self, frames, md_config, in_channels=3, dropout=0.1, device=torch.device("cpu"), debug=False): + super(FactorizePhysBig, self).__init__() + self.debug = debug + + self.in_channels = in_channels + if self.in_channels == 1 or self.in_channels == 3: + self.norm = nn.InstanceNorm3d(self.in_channels) + elif self.in_channels == 4: + self.rgb_norm = nn.InstanceNorm3d(3) + self.thermal_norm = nn.InstanceNorm3d(1) + else: + print("Unsupported input channels") + + self.use_fsam = md_config["MD_FSAM"] + self.md_infer = md_config["MD_INFERENCE"] + + for key in model_config: + if key not in md_config: + md_config[key] = model_config[key] + + if self.debug: + print("nf:", nf) + + self.rppg_feature_extractor = rPPG_FeatureExtractor(self.in_channels, dropout_rate=dropout, debug=debug) + + self.rppg_head = BVP_Head(md_config, device=device, dropout_rate=dropout, debug=debug) + + + def forward(self, x): # [batch, Features=3, Temp=frames, Width=32, Height=32] + + [batch, channel, length, width, height] = x.shape + + # if self.in_channels == 1: + # x = x[:, :, :-1, :, :] + # else: + # x = torch.diff(x, dim=2) + + x = torch.diff(x, dim=2) + + if self.debug: + print("Input.shape", x.shape) + + if self.in_channels == 1: + x = self.norm(x[:, -1:, :, :, :]) + elif self.in_channels == 3: + x = self.norm(x[:, :3, :, :, :]) + elif self.in_channels == 4: + rgb_x = self.rgb_norm(x[:, :3, :, :, :]) + thermal_x = self.thermal_norm(x[:, -1:, :, :, :]) + x = torch.concat([rgb_x, thermal_x], dim = 1) + else: + try: + print("Specified input channels:", self.in_channels) + print("Data channels", channel) + assert self.in_channels <= channel + except: + print("Incorrectly preprocessed data provided as input. 
Number of channels exceed the specified or default channels") + print("Default or specified channels:", self.in_channels) + print("Data channels [B, C, N, W, H]", x.shape) + print("Exiting") + exit() + + if self.debug: + print("Diff Normalized shape", x.shape) + + voxel_embeddings = self.rppg_feature_extractor(x) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + rPPG, factorized_embeddings, appx_error = self.rppg_head(voxel_embeddings, batch, length-1) + else: + rPPG = self.rppg_head(voxel_embeddings, batch, length-1) + + # if self.debug: + # print("rppg_feats.shape", rppg_feats.shape) + + # rPPG = rppg_feats.view(-1, length-1) + + if self.debug: + print("rPPG.shape", rPPG.shape) + + if (self.md_infer or self.training or self.debug) and self.use_fsam: + return rPPG, voxel_embeddings, factorized_embeddings, appx_error + else: + return rPPG, voxel_embeddings \ No newline at end of file diff --git a/neural_methods/model/FactorizePhys/__init__.py b/neural_methods/model/FactorizePhys/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/neural_methods/model/FactorizePhys/test_FactorizePhys.py b/neural_methods/model/FactorizePhys/test_FactorizePhys.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9d3298e5ae507d502ab29cbc1d87cad8e5ffc0 --- /dev/null +++ b/neural_methods/model/FactorizePhys/test_FactorizePhys.py @@ -0,0 +1,286 @@ +""" +FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. Agaian, and Youngjun Cho +""" + +import time +import numpy as np +from pathlib import Path +import matplotlib.pyplot as plt +import matplotlib.cm as cm +from scipy.signal import resample +import torch +import torch.nn as nn +import torch.nn.functional as F +from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys +# from torch.utils.tensorboard import SummaryWriter + +model_config = { + "MD_FSAM": True, + "MD_TYPE": "NMF", + "MD_TRANSFORM": "T_KAB", + "MD_R": 1, + "MD_S": 1, + "MD_STEPS": 4, + "MD_INFERENCE": True, + "MD_RESIDUAL": True, + "in_channels": 3, + "data_channels": 4, + "height": 72, + "weight": 72, + "batch_size": 2, + "frames": 160, + "debug": True, + "assess_latency": False, + "num_trials": 20, + "visualize": False, + "ckpt_path": "./final_model_release/iBVP_FactorizePhys_FSAM_Res.pth", + "data_path": "/mnt/sda/data/prep/iBVP_Dataset/iBVP_RGB_160_72x72", + "label_path": "/mnt/sda/data/prep/iBVP_Dataset/iBVP_RGB_160_72x72" +} + + +class TestFactorizePhysBig(object): + def __init__(self) -> None: + self.ckpt_path = Path(model_config["ckpt_path"]) + self.data_path = Path(model_config["data_path"]) + self.label_path = Path(model_config["label_path"]) + + self.use_fsam = model_config["MD_FSAM"] + self.md_infer = model_config["MD_INFERENCE"] + + self.batch_size = model_config["batch_size"] + self.frames = model_config["frames"] + self.in_channels = model_config["in_channels"] + self.data_channels = model_config["data_channels"] + self.height = model_config["height"] + self.width = model_config["weight"] + self.debug = bool(model_config["debug"]) + self.assess_latency = bool(model_config["assess_latency"]) + self.visualize = model_config["visualize"] + + if self.visualize: + self.data_files = list(sorted(self.data_path.rglob("*input*.npy"))) + self.label_files = list(sorted(self.data_path.rglob("*label*.npy"))) + self.num_trials = len(self.data_files) + + self.plot_dir = 
Path.cwd().joinpath("plots").joinpath("inference") + self.plot_dir.mkdir(parents=True, exist_ok=True) + + self.attention_map_dir = self.plot_dir.joinpath("attention_maps").joinpath(self.data_path.name).joinpath(self.ckpt_path.name) + self.attention_map_dir.mkdir(parents=True, exist_ok=True) + + else: + if self.assess_latency: + self.num_trials = model_config["num_trials"] + else: + self.num_trials = 1 + + if torch.cuda.is_available(): + self.device = torch.device(0) + else: + self.device = torch.device("cpu") + + md_config = {} + md_config["FRAME_NUM"] = model_config["frames"] + md_config["MD_S"] = model_config["MD_S"] + md_config["MD_R"] = model_config["MD_R"] + md_config["MD_STEPS"] = model_config["MD_STEPS"] + md_config["MD_FSAM"] = model_config["MD_FSAM"] + md_config["MD_TYPE"] = model_config["MD_TYPE"] + md_config["MD_TRANSFORM"] = model_config["MD_TRANSFORM"] + md_config["MD_INFERENCE"] = model_config["MD_INFERENCE"] + md_config["MD_RESIDUAL"] = model_config["MD_RESIDUAL"] + + if self.visualize: + self.net = nn.DataParallel(FactorizePhys(frames=self.frames, md_config=md_config, + device=self.device, in_channels=self.in_channels, debug=self.debug), device_ids=[0]).to(self.device) + self.net.load_state_dict(torch.load(str(self.ckpt_path), map_location=self.device)) + else: + self.net = FactorizePhys(frames=self.frames, md_config=md_config, + device=self.device, in_channels=self.in_channels, debug=self.debug).to(self.device) + + self.net.eval() + if self.assess_latency: + self.time_vec = [] + + if self.debug: + self.appx_error_list = [] + + + def load_data(self, num_trial): + + if self.visualize: + self.np_data = np.load(str(self.data_files[num_trial])) + self.np_label = np.load(str(self.label_files[num_trial])) + self.np_label = np.expand_dims(self.np_label, 0) + self.np_label = torch.tensor(self.np_label) + + # print("Chunk data shape", self.np_data.shape) + # print("Chunk label shape", self.np_label.shape) + # print("Min Max of input data:", np.min(self.np_data), np.max(self.np_data)) + # exit() + + self.test_data = np.transpose(self.np_data, (3, 0, 1, 2)) + self.test_data = torch.from_numpy(self.test_data) + self.test_data = self.test_data.unsqueeze(0) + + last_frame = torch.unsqueeze(self.test_data[:, :, -1, :, :], 2).repeat(1, 1, 1, 1, 1) + self.test_data = torch.cat((self.test_data, last_frame), 2) + self.test_data = self.test_data.to(torch.float32).to(self.device) + else: + self.test_data = torch.rand(self.batch_size, self.data_channels, self.frames + 1, self.height, self.width) + self.test_data = self.test_data.to(torch.float32).to(self.device) + + + def run_inference(self, num_trial): + + if self.visualize: + print("Processing:", self.data_files[num_trial].name) + if self.assess_latency: + t0 = time.time() + + if (self.md_infer or self.net.training or self.debug) and self.use_fsam: + self.pred, self.vox_embed, self.factorized_embed, self.appx_error = self.net(self.test_data) + else: + self.pred, self.vox_embed = self.net(self.test_data) + + if self.assess_latency: + t1 = time.time() + self.time_vec.append(t1-t0) + + if self.debug: + print("pred.shape", self.pred.shape) + if (self.md_infer or self.net.training or self.debug) and self.use_fsam: + self.appx_error_list.append(self.appx_error.item()) + + if self.visualize: + self.save_attention_maps(num_trial) + + + def save_attention_maps(self, num_trial): + b, channels, enc_frames, enc_height, enc_width = self.vox_embed.shape + label_matrix = self.np_label.unsqueeze(0).repeat(1, channels, 1).unsqueeze( + 
2).unsqueeze(2).permute(0, 1, 4, 3, 2).repeat(1, 1, 1, enc_height, enc_width) + label_matrix = label_matrix.to(device=self.device) + corr_matrix = F.cosine_similarity(self.vox_embed, label_matrix, dim=2).abs() + + # avg_emb = torch.mean(self.vox_embed, dim=1) + # b, enc_frames, enc_height, enc_width = avg_emb.shape + # label_matrix = np_label.unsqueeze(0).unsqueeze(2).permute(0, 3, 2, 1).repeat(1, 1, enc_height, enc_width) + # label_matrix = label_matrix.to(device=device) + # corr_matrix = F.cosine_similarity(avg_emb, label_matrix, dim=1) + + if self.debug: + print("corr_matrix.shape", corr_matrix.shape) + print("self.test_data.shape:", self.test_data.shape) + print("self.vox_embed.shape:", self.vox_embed.shape) + + self.test_data = self.test_data.detach().cpu().numpy() + self.vox_embed = self.vox_embed.detach().cpu().numpy() + corr_matrix = corr_matrix.detach().cpu().numpy() + + fig, ax = plt.subplots(4, 4, figsize=[16, 16]) + fig.tight_layout() + + ax[0, 0].imshow(self.np_data[enc_frames//2, ...].astype(np.uint8)) + ax[0, 0].axis('off') + cmap = "coolwarm" + + ch = 0 + ax[0, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[0, 1].axis('off') + + ch = 1 + ax[0, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[0, 2].axis('off') + + ch = 2 + ax[0, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[0, 3].axis('off') + + ch = 3 + ax[1, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 0].axis('off') + + ch = 4 + ax[1, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 1].axis('off') + + ch = 5 + ax[1, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 2].axis('off') + + ch = 6 + ax[1, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 3].axis('off') + + ch = 7 + ax[2, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 0].axis('off') + + ch = 8 + ax[2, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 1].axis('off') + + ch = 9 + ax[2, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 2].axis('off') + + ch = 10 + ax[2, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 3].axis('off') + + ch = 11 + ax[3, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 0].axis('off') + + ch = 12 + ax[3, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 1].axis('off') + + ch = 13 + ax[3, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 2].axis('off') + + ch = 14 + ax[3, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 3].axis('off') + + # plt.show() + plt.savefig(str(self.attention_map_dir.joinpath(str(self.data_files[num_trial].name.replace(".npy", "_attention_map.jpg"))))) + plt.close(fig) + + + def output_summary_results(self): + if self.assess_latency: + print("Median time: ", np.median(self.time_vec)) + plt.plot(self.time_vec) + plt.savefig(str(self.plot_dir.joinpath("Latency.jpg"))) + + if self.debug: + if (self.md_infer or self.net.training or self.debug) and self.use_fsam: + print("Median error:", np.median(self.appx_error_list)) + + pytorch_total_params = sum(p.numel() for p in self.net.parameters()) + print("Total parameters = ", pytorch_total_params) + + pytorch_trainable_params = sum(p.numel() + for p in self.net.parameters() if p.requires_grad) + print("Trainable parameters = ", pytorch_trainable_params) + + +if __name__ == "__main__": + + testObj = TestFactorizePhysBig() + + 
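+ # With "visualize" and "assess_latency" disabled in model_config, this runs a single trial on random input to sanity-check the forward pass and report parameter counts.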
print("testObj.num_trials:", testObj.num_trials) + for trial_num in range(testObj.num_trials): + testObj.load_data(trial_num) + testObj.run_inference(trial_num) + + testObj.output_summary_results() + + # writer.add_graph(net, test_data) + # writer.close() \ No newline at end of file diff --git a/neural_methods/model/FactorizePhys/test_FactorizePhysBig.py b/neural_methods/model/FactorizePhys/test_FactorizePhysBig.py new file mode 100644 index 0000000000000000000000000000000000000000..43177d0096fec3ff672cdb46ab07e6a86d5f76aa --- /dev/null +++ b/neural_methods/model/FactorizePhys/test_FactorizePhysBig.py @@ -0,0 +1,292 @@ +""" +FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. Agaian, and Youngjun Cho +""" + +import time +import numpy as np +from pathlib import Path +import matplotlib.pyplot as plt +import matplotlib.cm as cm +from scipy.signal import resample +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig +# from torch.utils.tensorboard import SummaryWriter + +model_config = { + "MD_FSAM": True, + "MD_TYPE": "NMF", + "MD_TRANSFORM": "T_KAB", + "MD_R": 1, + "MD_S": 1, + "MD_STEPS": 4, + "MD_INFERENCE": True, + "MD_RESIDUAL": True, + "in_channels": 3, + "data_channels": 4, + "height": 128, + "weight": 128, + "batch_size": 1, + "frames": 240, + "debug": True, + "assess_latency": False, + "num_trials": 20, + "visualize": False, + # "ckpt_path": "./final_model_release/UBFC-rPPG_Intra_FactorizePhys_Base_HighRes.pth", + "ckpt_path": "./final_model_release/UBFC-rPPG_Intra_FactorizePhys_FSAM_Res_HighRes.pth", + "data_path": "/mnt/sda/data/prep/UBFC-rPPG/UBFC-rPPG_Raw_240_128x128", + "label_path": "/mnt/sda/data/prep/UBFC-rPPG/UBFC-rPPG_Raw_240_128x128" +} + +# default `log_dir` is "runs" - we'll be more specific here +# writer = SummaryWriter('runs/FactorizePhys') + +class TestFactorizePhysBig(object): + def __init__(self) -> None: + self.ckpt_path = Path(model_config["ckpt_path"]) + self.data_path = Path(model_config["data_path"]) + self.label_path = Path(model_config["label_path"]) + + self.use_fsam = model_config["MD_FSAM"] + self.md_infer = model_config["MD_INFERENCE"] + + self.batch_size = model_config["batch_size"] + self.frames = model_config["frames"] + self.in_channels = model_config["in_channels"] + self.data_channels = model_config["data_channels"] + self.height = model_config["height"] + self.width = model_config["weight"] + self.debug = bool(model_config["debug"]) + self.assess_latency = bool(model_config["assess_latency"]) + self.visualize = model_config["visualize"] + + if self.visualize: + # self.data_files = list(sorted(self.data_path.rglob("*subject12*input*.npy"))) + # self.label_files = list(sorted(self.data_path.rglob("*subject12*label*.npy"))) + self.data_files = list(sorted(self.data_path.rglob("*input*.npy"))) + self.label_files = list(sorted(self.data_path.rglob("*label*.npy"))) + self.num_trials = len(self.data_files) + + self.plot_dir = Path.cwd().joinpath("plots").joinpath("inference") + self.plot_dir.mkdir(parents=True, exist_ok=True) + + self.attention_map_dir = self.plot_dir.joinpath("attention_maps").joinpath(self.data_path.name).joinpath(self.ckpt_path.name) + self.attention_map_dir.mkdir(parents=True, exist_ok=True) + + else: + if self.assess_latency: + self.num_trials = model_config["num_trials"] + else: + self.num_trials = 1 + + if torch.cuda.is_available(): + self.device = 
torch.device(0) + else: + self.device = torch.device("cpu") + + md_config = {} + md_config["FRAME_NUM"] = model_config["frames"] + md_config["MD_S"] = model_config["MD_S"] + md_config["MD_R"] = model_config["MD_R"] + md_config["MD_STEPS"] = model_config["MD_STEPS"] + md_config["MD_FSAM"] = model_config["MD_FSAM"] + md_config["MD_TYPE"] = model_config["MD_TYPE"] + md_config["MD_TRANSFORM"] = model_config["MD_TRANSFORM"] + md_config["MD_INFERENCE"] = model_config["MD_INFERENCE"] + md_config["MD_RESIDUAL"] = model_config["MD_RESIDUAL"] + + if self.visualize: + self.net = nn.DataParallel(FactorizePhysBig(frames=self.frames, md_config=md_config, + device=self.device, in_channels=self.in_channels, debug=self.debug), device_ids=[0]).to(self.device) + self.net.load_state_dict(torch.load(str(self.ckpt_path), map_location=self.device)) + else: + self.net = FactorizePhysBig(frames=self.frames, md_config=md_config, + device=self.device, in_channels=self.in_channels, debug=self.debug).to(self.device) + + self.net.eval() + if self.assess_latency: + self.time_vec = [] + + if self.debug: + self.appx_error_list = [] + + + def load_data(self, num_trial): + + if self.visualize: + self.np_data = np.load(str(self.data_files[num_trial])) + self.np_label = np.load(str(self.label_files[num_trial])) + self.np_label = np.expand_dims(self.np_label, 0) + self.np_label = torch.tensor(self.np_label) + + # print("Chunk data shape", self.np_data.shape) + # print("Chunk label shape", self.np_label.shape) + # print("Min Max of input data:", np.min(self.np_data), np.max(self.np_data)) + # exit() + + self.test_data = np.transpose(self.np_data, (3, 0, 1, 2)) + self.test_data = torch.from_numpy(self.test_data) + self.test_data = self.test_data.unsqueeze(0) + + last_frame = torch.unsqueeze(self.test_data[:, :, -1, :, :], 2).repeat(1, 1, 1, 1, 1) + self.test_data = torch.cat((self.test_data, last_frame), 2) + self.test_data = self.test_data.to(torch.float32).to(self.device) + else: + self.test_data = torch.rand(self.batch_size, self.data_channels, self.frames + 1, self.height, self.width) + self.test_data = self.test_data.to(torch.float32).to(self.device) + + + def run_inference(self, num_trial): + + if self.visualize: + print("Processing:", self.data_files[num_trial].name) + if self.assess_latency: + t0 = time.time() + + if (self.md_infer or self.net.training or self.debug) and self.use_fsam: + self.pred, self.vox_embed, self.factorized_embed, self.appx_error = self.net(self.test_data) + else: + self.pred, self.vox_embed = self.net(self.test_data) + + if self.assess_latency: + t1 = time.time() + self.time_vec.append(t1-t0) + + if self.debug: + print("pred.shape", self.pred.shape) + if (self.md_infer or self.net.training or self.debug) and self.use_fsam: + self.appx_error_list.append(self.appx_error.item()) + + if self.visualize: + self.save_attention_maps(num_trial) + + + def save_attention_maps(self, num_trial): + b, channels, enc_frames, enc_height, enc_width = self.vox_embed.shape + label_matrix = self.np_label.unsqueeze(0).repeat(1, channels, 1).unsqueeze( + 2).unsqueeze(2).permute(0, 1, 4, 3, 2).repeat(1, 1, 1, enc_height, enc_width) + label_matrix = label_matrix.to(device=self.device) + corr_matrix = F.cosine_similarity(self.vox_embed, label_matrix, dim=2).abs() + + # avg_emb = torch.mean(self.vox_embed, dim=1) + # b, enc_frames, enc_height, enc_width = avg_emb.shape + # label_matrix = np_label.unsqueeze(0).unsqueeze(2).permute(0, 3, 2, 1).repeat(1, 1, enc_height, enc_width) + # label_matrix = 
label_matrix.to(device=device) + # corr_matrix = F.cosine_similarity(avg_emb, label_matrix, dim=1) + + if self.debug: + print("corr_matrix.shape", corr_matrix.shape) + print("self.test_data.shape:", self.test_data.shape) + print("self.vox_embed.shape:", self.vox_embed.shape) + + self.test_data = self.test_data.detach().cpu().numpy() + self.vox_embed = self.vox_embed.detach().cpu().numpy() + corr_matrix = corr_matrix.detach().cpu().numpy() + + fig, ax = plt.subplots(4, 4, figsize=[16, 16]) + fig.tight_layout() + cmap = "coolwarm" + + ax[0, 0].imshow(self.np_data[enc_frames//2, ...].astype(np.uint8)) + ax[0, 0].axis('off') + + ch = 0 + ax[0, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[0, 1].axis('off') + + ch = 1 + ax[0, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[0, 2].axis('off') + + ch = 2 + ax[0, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[0, 3].axis('off') + + ch = 3 + ax[1, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 0].axis('off') + + ch = 4 + ax[1, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 1].axis('off') + + ch = 5 + ax[1, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 2].axis('off') + + ch = 6 + ax[1, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[1, 3].axis('off') + + ch = 7 + ax[2, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 0].axis('off') + + ch = 8 + ax[2, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 1].axis('off') + + ch = 9 + ax[2, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 2].axis('off') + + ch = 10 + ax[2, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[2, 3].axis('off') + + ch = 11 + ax[3, 0].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 0].axis('off') + + ch = 12 + ax[3, 1].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 1].axis('off') + + ch = 13 + ax[3, 2].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 2].axis('off') + + ch = 14 + ax[3, 3].imshow(corr_matrix[0, ch, :, :], cmap=cmap, vmin=0, vmax=1) + ax[3, 3].axis('off') + + # plt.show() + plt.savefig(str(self.attention_map_dir.joinpath(str(self.data_files[num_trial].name.replace(".npy", "_attention_map.jpg"))))) + plt.close(fig) + + + def output_summary_results(self): + if self.assess_latency: + print("Median time: ", np.median(self.time_vec)) + plt.plot(self.time_vec) + plt.savefig(str(self.plot_dir.joinpath("Latency.jpg"))) + + if self.debug: + if (self.md_infer or self.net.training or self.debug) and self.use_fsam: + print("Median error:", np.median(self.appx_error_list)) + + pytorch_total_params = sum(p.numel() for p in self.net.parameters()) + print("Total parameters = ", pytorch_total_params) + + pytorch_trainable_params = sum(p.numel() + for p in self.net.parameters() if p.requires_grad) + print("Trainable parameters = ", pytorch_trainable_params) + + +if __name__ == "__main__": + + testObj = TestFactorizePhysBig() + + print("testObj.num_trials:", testObj.num_trials) + for trial_num in range(testObj.num_trials): + testObj.load_data(trial_num) + testObj.run_inference(trial_num) + + testObj.output_summary_results() + + # writer.add_graph(net, test_data) + # writer.close() \ No newline at end of file diff --git a/neural_methods/model/PhysFormer.py b/neural_methods/model/PhysFormer.py new file mode 100644 index 
0000000000000000000000000000000000000000..e6d96067399626cbbba749d771a0bfa3a5597785 --- /dev/null +++ b/neural_methods/model/PhysFormer.py @@ -0,0 +1,313 @@ +"""This file is a combination of Physformer.py and transformer_layer.py + in the official PhysFormer implementation here: + https://github.com/ZitongYu/PhysFormer + + model.py - Model and module class for ViT. + They are built to mirror those in the official Jax implementation. +""" + +import numpy as np +from typing import Optional +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F +import math + +def as_tuple(x): + return x if isinstance(x, tuple) else (x, x) + +''' +Temporal Center-difference based Convolutional layer (3D version) +theta: control the percentage of original convolution and centeral-difference convolution +''' +class CDC_T(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False, theta=0.6): + + super(CDC_T, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.theta = theta + + def forward(self, x): + out_normal = self.conv(x) + + if math.fabs(self.theta - 0.0) < 1e-8: + return out_normal + else: + [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape + + # only CD works on temporal kernel size>1 + if self.conv.weight.shape[2] > 1: + kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum( + 2).sum(2) + kernel_diff = kernel_diff[:, :, None, None, None] + out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, + padding=0, dilation=self.conv.dilation, groups=self.conv.groups) + return out_normal - self.theta * out_diff + + else: + return out_normal + + +def split_last(x, shape): + "split the last dimension to given shape" + shape = list(shape) + assert shape.count(-1) <= 1 + if -1 in shape: + shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) + return x.view(*x.size()[:-1], *shape) + + +def merge_last(x, n_dims): + "merge the last n_dims to a dimension" + s = x.size() + assert n_dims > 1 and n_dims < len(s) + return x.view(*s[:-n_dims], -1) + +class MultiHeadedSelfAttention_TDC_gra_sharp(nn.Module): + """Multi-Headed Dot Product Attention with depth-wise Conv3d""" + def __init__(self, dim, num_heads, dropout, theta): + super().__init__() + + self.proj_q = nn.Sequential( + CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta), + nn.BatchNorm3d(dim), + ) + self.proj_k = nn.Sequential( + CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta), + nn.BatchNorm3d(dim), + ) + self.proj_v = nn.Sequential( + nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False), + ) + + self.drop = nn.Dropout(dropout) + self.n_heads = num_heads + self.scores = None # for visualization + + def forward(self, x, gra_sharp): # [B, 4*4*40, 128] + """ + x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) + mask : (B(batch_size) x S(seq_len)) + * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W + """ + # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) + + [B, P, C]=x.shape + x = x.transpose(1, 2).view(B, C, P//16, 4, 4) # [B, dim, 40, 4, 4] + q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) + q = q.flatten(2).transpose(1, 2) # [B, 4*4*40, dim] + k = k.flatten(2).transpose(1, 2) # [B, 4*4*40, dim] 
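+ # Note: q and k are produced by temporal center-difference convolutions (CDC_T), while v uses a plain 1x1x1 conv, so the temporal-difference cue only shapes the attention scores.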
+ v = v.flatten(2).transpose(1, 2) # [B, 4*4*40, dim] + + q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) + # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) + scores = q @ k.transpose(-2, -1) / gra_sharp + + scores = self.drop(F.softmax(scores, dim=-1)) + # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) + h = (scores @ v).transpose(1, 2).contiguous() + # -merge-> (B, S, D) + h = merge_last(h, 2) + self.scores = scores + return h, scores + + + + +class PositionWiseFeedForward_ST(nn.Module): + """FeedForward Neural Networks for each position""" + def __init__(self, dim, ff_dim): + super().__init__() + + self.fc1 = nn.Sequential( + nn.Conv3d(dim, ff_dim, 1, stride=1, padding=0, bias=False), + nn.BatchNorm3d(ff_dim), + nn.ELU(), + ) + + self.STConv = nn.Sequential( + nn.Conv3d(ff_dim, ff_dim, 3, stride=1, padding=1, groups=ff_dim, bias=False), + nn.BatchNorm3d(ff_dim), + nn.ELU(), + ) + + self.fc2 = nn.Sequential( + nn.Conv3d(ff_dim, dim, 1, stride=1, padding=0, bias=False), + nn.BatchNorm3d(dim), + ) + + def forward(self, x): # [B, 4*4*40, 128] + [B, P, C]=x.shape + x = x.transpose(1, 2).view(B, C, P//16, 4, 4) # [B, dim, 40, 4, 4] + x = self.fc1(x) # x [B, ff_dim, 40, 4, 4] + x = self.STConv(x) # x [B, ff_dim, 40, 4, 4] + x = self.fc2(x) # x [B, dim, 40, 4, 4] + x = x.flatten(2).transpose(1, 2) # [B, 4*4*40, dim] + + return x + +class Block_ST_TDC_gra_sharp(nn.Module): + """Transformer Block""" + def __init__(self, dim, num_heads, ff_dim, dropout, theta): + super().__init__() + self.attn = MultiHeadedSelfAttention_TDC_gra_sharp(dim, num_heads, dropout, theta) + self.proj = nn.Linear(dim, dim) + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.pwff = PositionWiseFeedForward_ST(dim, ff_dim) + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + self.drop = nn.Dropout(dropout) + + def forward(self, x, gra_sharp): + Atten, Score = self.attn(self.norm1(x), gra_sharp) + h = self.drop(self.proj(Atten)) + x = x + h + h = self.drop(self.pwff(self.norm2(x))) + x = x + h + return x, Score + +class Transformer_ST_TDC_gra_sharp(nn.Module): + """Transformer with Self-Attentive Blocks""" + def __init__(self, num_layers, dim, num_heads, ff_dim, dropout, theta): + super().__init__() + self.blocks = nn.ModuleList([ + Block_ST_TDC_gra_sharp(dim, num_heads, ff_dim, dropout, theta) for _ in range(num_layers)]) + + def forward(self, x, gra_sharp): + for block in self.blocks: + x, Score = block(x, gra_sharp) + return x, Score + +# stem_3DCNN + ST-ViT with local Depthwise Spatio-Temporal MLP +class ViT_ST_ST_Compact3_TDC_gra_sharp(nn.Module): + + def __init__( + self, + name: Optional[str] = None, + pretrained: bool = False, + patches: int = 16, + dim: int = 768, + ff_dim: int = 3072, + num_heads: int = 12, + num_layers: int = 12, + attention_dropout_rate: float = 0.0, + dropout_rate: float = 0.2, + representation_size: Optional[int] = None, + load_repr_layer: bool = False, + classifier: str = 'token', + #positional_embedding: str = '1d', + in_channels: int = 3, + frame: int = 160, + theta: float = 0.2, + image_size: Optional[int] = None, + ): + super().__init__() + + + self.image_size = image_size + self.frame = frame + self.dim = dim + + # Image and patch sizes + t, h, w = as_tuple(image_size) # tube sizes + ft, fh, fw = as_tuple(patches) # patch sizes, ft = 4 ==> 160/4=40 + gt, gh, gw = t//ft, h // fh, w // fw # number of patches + seq_len = gh * gw * gt + + # Patch embedding [4x16x16]conv + self.patch_embedding = nn.Conv3d(dim, dim, kernel_size=(ft, 
fh, fw), stride=(ft, fh, fw)) + + # Transformer + self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads, + ff_dim=ff_dim, dropout=dropout_rate, theta=theta) + # Transformer + self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads, + ff_dim=ff_dim, dropout=dropout_rate, theta=theta) + # Transformer + self.transformer3 = Transformer_ST_TDC_gra_sharp(num_layers=num_layers//3, dim=dim, num_heads=num_heads, + ff_dim=ff_dim, dropout=dropout_rate, theta=theta) + + + + self.Stem0 = nn.Sequential( + nn.Conv3d(3, dim//4, [1, 5, 5], stride=1, padding=[0,2,2]), + nn.BatchNorm3d(dim//4), + nn.ReLU(inplace=True), + nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)), + ) + + self.Stem1 = nn.Sequential( + nn.Conv3d(dim//4, dim//2, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(dim//2), + nn.ReLU(inplace=True), + nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)), + ) + self.Stem2 = nn.Sequential( + nn.Conv3d(dim//2, dim, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(dim), + nn.ReLU(inplace=True), + nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)), + ) + + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=(2,1,1)), + nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1,0,0)), + nn.BatchNorm3d(dim), + nn.ELU(), + ) + self.upsample2 = nn.Sequential( + nn.Upsample(scale_factor=(2,1,1)), + nn.Conv3d(dim, dim//2, [3, 1, 1], stride=1, padding=(1,0,0)), + nn.BatchNorm3d(dim//2), + nn.ELU(), + ) + + self.ConvBlockLast = nn.Conv1d(dim//2, 1, 1,stride=1, padding=0) + + + # Initialize weights + self.init_weights() + + @torch.no_grad() + def init_weights(self): + def _init(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal + if hasattr(m, 'bias') and m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0) + self.apply(_init) + + + def forward(self, x, gra_sharp): + + # b is batch number, c channels, t frame, fh frame height, and fw frame width + b, c, t, fh, fw = x.shape + + x = self.Stem0(x) + x = self.Stem1(x) + x = self.Stem2(x) # [B, 64, 160, 64, 64] + + x = self.patch_embedding(x) # [B, 64, 40, 4, 4] + x = x.flatten(2).transpose(1, 2) # [B, 40*4*4, 64] + + + Trans_features, Score1 = self.transformer1(x, gra_sharp) # [B, 4*4*40, 64] + Trans_features2, Score2 = self.transformer2(Trans_features, gra_sharp) # [B, 4*4*40, 64] + Trans_features3, Score3 = self.transformer3(Trans_features2, gra_sharp) # [B, 4*4*40, 64] + + # upsampling heads + #features_last = Trans_features3.transpose(1, 2).view(b, self.dim, 40, 4, 4) # [B, 64, 40, 4, 4] + features_last = Trans_features3.transpose(1, 2).view(b, self.dim, t//4, 4, 4) # [B, 64, 40, 4, 4] + + features_last = self.upsample(features_last) # x [B, 64, 7*7, 80] + features_last = self.upsample2(features_last) # x [B, 32, 7*7, 160] + + features_last = torch.mean(features_last,3) # x [B, 32, 160, 4] + features_last = torch.mean(features_last,3) # x [B, 32, 160] + rPPG = self.ConvBlockLast(features_last) # x [B, 1, 160] + + rPPG = rPPG.squeeze(1) + + return rPPG, Score1, Score2, Score3 diff --git a/neural_methods/model/PhysMamba.py b/neural_methods/model/PhysMamba.py new file mode 100644 index 0000000000000000000000000000000000000000..96ce07e7f5b6ef666051badbfdd2985daec75594 --- /dev/null +++ b/neural_methods/model/PhysMamba.py @@ -0,0 +1,246 @@ +import math +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_, DropPath +from mamba_ssm 
import Mamba +from torch.nn import functional as F + +class ChannelAttention3D(nn.Module): + def __init__(self, in_channels, reduction): + super(ChannelAttention3D, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool3d(1) + self.max_pool = nn.AdaptiveMaxPool3d(1) + + self.fc = nn.Sequential( + nn.Conv3d(in_channels, in_channels // reduction, 1, bias=False), + nn.ReLU(), + nn.Conv3d(in_channels // reduction, in_channels, 1, bias=False) + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc(self.avg_pool(x)) + max_out = self.fc(self.max_pool(x)) + out = avg_out + max_out + attention = self.sigmoid(out) + return x*attention + +class LateralConnection(nn.Module): + def __init__(self, fast_channels=32, slow_channels=64): + super(LateralConnection, self).__init__() + self.conv = nn.Sequential( + nn.Conv3d(fast_channels, slow_channels, [3, 1, 1], stride=[2, 1, 1], padding=[1,0,0]), + nn.BatchNorm3d(64), + nn.ReLU(), + ) + + def forward(self, slow_path, fast_path): + fast_path = self.conv(fast_path) + return fast_path + slow_path + +class CDC_T(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False, theta=0.2): + + super(CDC_T, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.theta = theta + + def forward(self, x): + + out_normal = self.conv(x) + + if math.fabs(self.theta - 0.0) < 1e-8: + return out_normal + else: + [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape + + # only CD works on temporal kernel size>1 + if self.conv.weight.shape[2] > 1: + kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum( + 2).sum(2) + kernel_diff = kernel_diff[:, :, None, None, None] + out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, + padding=0, dilation=self.conv.dilation, groups=self.conv.groups) + return out_normal - self.theta * out_diff + + else: + return out_normal + +class MambaLayer(nn.Module): + def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2, channel_token = False): + super(MambaLayer, self).__init__() + self.dim = dim + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + drop_path = 0 + self.mamba = Mamba( + d_model=dim, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + bimamba=True, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward_patch_token(self, x): + B, C, nf, H, W = x.shape + B, d_model = x.shape[:2] + assert d_model == self.dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2) + x_norm = self.norm1(x_flat) + x_mamba = self.mamba(x_norm) + x_out = self.norm2(x_flat + self.drop_path(x_mamba)) + out = x_out.transpose(-1, -2).reshape(B, d_model, *img_dims) + return out + + def forward(self, x): + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: + x = x.type(torch.float32) + out = self.forward_patch_token(x) + return out + +def conv_block(in_channels, out_channels, kernel_size, stride, padding, bn=True, activation='relu'): + layers = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)] + if bn: + layers.append(nn.BatchNorm3d(out_channels)) + if activation == 'relu': + layers.append(nn.ReLU(inplace=True)) + elif activation == 'elu': + layers.append(nn.ELU(inplace=True)) + return nn.Sequential(*layers) + + +class PhysMamba(nn.Module): + def __init__(self, theta=0.5, drop_rate1=0.25, drop_rate2=0.5, frames=128): + super(PhysMamba, self).__init__() + + self.ConvBlock1 = conv_block(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2]) + self.ConvBlock2 = conv_block(16, 32, [3, 3, 3], stride=1, padding=1) + self.ConvBlock3 = conv_block(32, 64, [3, 3, 3], stride=1, padding=1) + self.ConvBlock4 = conv_block(64, 64, [4, 1, 1], stride=[4, 1, 1], padding=0) + self.ConvBlock5 = conv_block(64, 32, [2, 1, 1], stride=[2, 1, 1], padding=0) + self.ConvBlock6 = conv_block(32, 32, [3, 1, 1], stride=1, padding=[1, 0, 0], activation='elu') + + # Temporal Difference Mamba Blocks + # Slow Stream + self.Block1 = self._build_block(64, theta) + self.Block2 = self._build_block(64, theta) + self.Block3 = self._build_block(64, theta) + # Fast Stream + self.Block4 = self._build_block(32, theta) + self.Block5 = self._build_block(32, theta) + self.Block6 = self._build_block(32, theta) + + # Upsampling + self.upsample1 = nn.Sequential( + nn.Upsample(scale_factor=(2,1,1)), + nn.Conv3d(64, 64, [3, 1, 1], stride=1, padding=(1,0,0)), + nn.BatchNorm3d(64), + nn.ELU(), + ) + self.upsample2 = nn.Sequential( + nn.Upsample(scale_factor=(2,1,1)), + nn.Conv3d(96, 48, [3, 1, 1], stride=1, padding=(1,0,0)), + nn.BatchNorm3d(48), + nn.ELU(), + ) + + self.ConvBlockLast = nn.Conv3d(48, 1, [1, 1, 1], stride=1, padding=0) + self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)) + self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2) + + self.fuse_1 = LateralConnection(fast_channels=32, slow_channels=64) + self.fuse_2 = LateralConnection(fast_channels=32, slow_channels=64) + + self.drop_1 = nn.Dropout(drop_rate1) + self.drop_2 = nn.Dropout(drop_rate1) + self.drop_3 = nn.Dropout(drop_rate2) + self.drop_4 = nn.Dropout(drop_rate2) + self.drop_5 = nn.Dropout(drop_rate2) + self.drop_6 = nn.Dropout(drop_rate2) + + self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1)) + + def _build_block(self, channels, theta): + return 
nn.Sequential( + CDC_T(channels, channels, theta=theta), + nn.BatchNorm3d(channels), + nn.ReLU(), + MambaLayer(dim=channels), + ChannelAttention3D(in_channels=channels, reduction=2), + ) + + def forward(self, x): + [batch, channel, length, width, height] = x.shape + + x = self.ConvBlock1(x) + x = self.MaxpoolSpa(x) + x = self.ConvBlock2(x) + x = self.ConvBlock3(x) + x = self.MaxpoolSpa(x) + + # Process streams + s_x = self.ConvBlock4(x) # Slow stream + f_x = self.ConvBlock5(x) # Fast stream + + # First set of blocks and fusion + s_x1 = self.Block1(s_x) + s_x1 = self.MaxpoolSpa(s_x1) + s_x1 = self.drop_1(s_x1) + + f_x1 = self.Block4(f_x) + f_x1 = self.MaxpoolSpa(f_x1) + f_x1 = self.drop_2(f_x1) + + s_x1 = self.fuse_1(s_x1,f_x1) # LateralConnection + + # Second set of blocks and fusion + s_x2 = self.Block2(s_x1) + s_x2 = self.MaxpoolSpa(s_x2) + s_x2 = self.drop_3(s_x2) + + f_x2 = self.Block5(f_x1) + f_x2 = self.MaxpoolSpa(f_x2) + f_x2 = self.drop_4(f_x2) + + s_x2 = self.fuse_2(s_x2,f_x2) # LateralConnection + + # Third blocks and upsampling + s_x3 = self.Block3(s_x2) + s_x3 = self.upsample1(s_x3) + s_x3 = self.drop_5(s_x3) + + f_x3 = self.Block6(f_x2) + f_x3 = self.ConvBlock6(f_x3) + f_x3 = self.drop_6(f_x3) + + # Final fusion and upsampling + x_fusion = torch.cat((f_x3, s_x3), dim=1) + x_final = self.upsample2(x_fusion) + + x_final = self.poolspa(x_final) + x_final = self.ConvBlockLast(x_final) + + rPPG = x_final.view(-1, length) + + return rPPG diff --git a/neural_methods/model/PhysNet.py b/neural_methods/model/PhysNet.py new file mode 100644 index 0000000000000000000000000000000000000000..126312c9b1abe54b3c92e3d8c48a633fdb342b0f --- /dev/null +++ b/neural_methods/model/PhysNet.py @@ -0,0 +1,124 @@ +""" PhysNet +We repulicate the net pipeline of the orginal paper, but set the input as diffnormalized data. +orginal source: +Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks +British Machine Vision Conference (BMVC)} 2019, +By Zitong Yu, 2019/05/05 +Only for research purpose, and commercial use is not allowed. 
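As a quick orientation for the two-stream forward pass above, here is a minimal usage sketch. It assumes a build of mamba_ssm that accepts the bimamba flag and a CUDA device, since the Mamba selective-scan kernel is normally GPU-only; the 128x128 input resolution is only an example (any spatial size divisible by 16 survives the four spatial poolings).

import torch

model = PhysMamba(frames=128).cuda().eval()
clip = torch.rand(2, 3, 128, 128, 128, device="cuda")  # [B, C, T, H, W], T == frames
with torch.no_grad():
    rppg = model(clip)                                  # [B, T] predicted rPPG waveform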
+MIT License +Copyright (c) 2019 +""" + +import math +import pdb + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _triple + + +class PhysNet_padding_Encoder_Decoder_MAX(nn.Module): + def __init__(self, frames=128): + super(PhysNet_padding_Encoder_Decoder_MAX, self).__init__() + + self.ConvBlock1 = nn.Sequential( + nn.Conv3d(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2]), + nn.BatchNorm3d(16), + nn.ReLU(inplace=True), + ) + + self.ConvBlock2 = nn.Sequential( + nn.Conv3d(16, 32, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(32), + nn.ReLU(inplace=True), + ) + self.ConvBlock3 = nn.Sequential( + nn.Conv3d(32, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + + self.ConvBlock4 = nn.Sequential( + nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + self.ConvBlock5 = nn.Sequential( + nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + self.ConvBlock6 = nn.Sequential( + nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + self.ConvBlock7 = nn.Sequential( + nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + self.ConvBlock8 = nn.Sequential( + nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + self.ConvBlock9 = nn.Sequential( + nn.Conv3d(64, 64, [3, 3, 3], stride=1, padding=1), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + + self.upsample = nn.Sequential( + nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[ + 4, 1, 1], stride=[2, 1, 1], padding=[1, 0, 0]), # [1, 128, 32] + nn.BatchNorm3d(64), + nn.ELU(), + ) + self.upsample2 = nn.Sequential( + nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=[ + 4, 1, 1], stride=[2, 1, 1], padding=[1, 0, 0]), # [1, 128, 32] + nn.BatchNorm3d(64), + nn.ELU(), + ) + + self.ConvBlock10 = nn.Conv3d(64, 1, [1, 1, 1], stride=1, padding=0) + + self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)) + self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2) + + # self.poolspa = nn.AdaptiveMaxPool3d((frames,1,1)) # pool only spatial space + self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1)) + + def forward(self, x): # Batch_size*[3, T, 128,128] + x_visual = x + [batch, channel, length, width, height] = x.shape + + x = self.ConvBlock1(x) # x [3, T, 128,128] + x = self.MaxpoolSpa(x) # x [16, T, 64,64] + + x = self.ConvBlock2(x) # x [32, T, 64,64] + x_visual6464 = self.ConvBlock3(x) # x [32, T, 64,64] + # x [32, T/2, 32,32] Temporal halve + x = self.MaxpoolSpaTem(x_visual6464) + + x = self.ConvBlock4(x) # x [64, T/2, 32,32] + x_visual3232 = self.ConvBlock5(x) # x [64, T/2, 32,32] + x = self.MaxpoolSpaTem(x_visual3232) # x [64, T/4, 16,16] + + x = self.ConvBlock6(x) # x [64, T/4, 16,16] + x_visual1616 = self.ConvBlock7(x) # x [64, T/4, 16,16] + x = self.MaxpoolSpa(x_visual1616) # x [64, T/4, 8,8] + + x = self.ConvBlock8(x) # x [64, T/4, 8, 8] + x = self.ConvBlock9(x) # x [64, T/4, 8, 8] + x = self.upsample(x) # x [64, T/2, 8, 8] + x = self.upsample2(x) # x [64, T, 8, 8] + + # x [64, T, 1,1] --> groundtruth left and right - 7 + x = self.poolspa(x) + x = self.ConvBlock10(x) # x [1, T, 1,1] + + rPPG = x.view(-1, length) + + return rPPG, x_visual, x_visual3232, x_visual1616 diff --git a/neural_methods/model/RhythmFormer.py b/neural_methods/model/RhythmFormer.py new file mode 100644 index 
0000000000000000000000000000000000000000..46739b66e93b1985e8f8d096c817738743a9f77d --- /dev/null +++ b/neural_methods/model/RhythmFormer.py @@ -0,0 +1,418 @@ +""" +RhythmFormer:Extracting rPPG Signals Based on Hierarchical Temporal Periodic Transformer +""" +from typing import Optional +import torch +from torch import nn, Tensor, LongTensor +from torch.nn import functional as F +import math +from typing import Tuple, Union +from timm.models.layers import trunc_normal_, DropPath + + + +""" +Adapted from here: https://github.com/rayleizhu/BiFormer +""" +import torch +from torch import Tensor, LongTensor , nn +import torch.nn.functional as F +from typing import Optional, Tuple + +def _grid2seq(x:Tensor, region_size:Tuple[int], num_heads:int): + """ + Args: + x: BCTHW tensor + region size: int + num_heads: number of attention heads + Return: + out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim) + region_t, region_h, region_w: number of regions per t/col/row + """ + B, C, T, H, W = x.size() + region_t ,region_h, region_w = T//region_size[0], H//region_size[1], W//region_size[2] + x = x.view(B, num_heads, C//num_heads, region_t, region_size[0],region_h, region_size[1], region_w, region_size[2]) + x = torch.einsum('bmdtohpwq->bmthwopqd', x).flatten(2, 4).flatten(-4, -2) # (bs, nhead, nregion, reg_size, head_dim) + return x, region_t, region_h, region_w + + +def _seq2grid(x:Tensor, region_t:int, region_h:int, region_w:int, region_size:Tuple[int]): + """ + Args: + x: (bs, nhead, nregion, reg_size^2, head_dim) + Return: + x: (bs, C, T, H, W) + """ + bs, nhead, nregion, reg_size_square, head_dim = x.size() + x = x.view(bs, nhead, region_t, region_h, region_w, region_size[0], region_size[1], region_size[2], head_dim) + x = torch.einsum('bmthwopqd->bmdtohpwq', x).reshape(bs, nhead*head_dim, + region_t*region_size[0],region_h*region_size[1], region_w*region_size[2]) + return x + + +def video_regional_routing_attention_torch( + query:Tensor, key:Tensor, value:Tensor, scale:float, + region_graph:LongTensor, region_size:Tuple[int], + kv_region_size:Optional[Tuple[int]]=None, + auto_pad=False)->Tensor: + """ + Args: + query, key, value: (B, C, T, H, W) tensor + scale: the scale/temperature for dot product attention + region_graph: (B, nhead, t_q*h_q*w_q, topk) tensor, topk <= t_k*h_k*w_k + region_size: region/window size for queries, (rt, rh, rw) + key_region_size: optional, if None, key_region_size=region_size + Return: + output: (B, C, T, H, W) tensor + attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_size) attention matrix + """ + kv_region_size = kv_region_size or region_size + bs, nhead, q_nregion, topk = region_graph.size() + + # # Auto pad to deal with any input size + # q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0 + # if auto_pad: + # _, _, Hq, Wq = query.size() + # q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0] + # q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1] + # if (q_pad_b > 0 or q_pad_r > 0): + # query = F.pad(query, (0, q_pad_r, 0, q_pad_b)) # zero padding + + # _, _, Hk, Wk = key.size() + # kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0] + # kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1] + # if (kv_pad_r > 0 or kv_pad_b > 0): + # key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b)) # zero padding + # value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b)) # zero padding + + # to sequence format, i.e. 
(bs, nhead, nregion, reg_size, head_dim) + query, q_region_t, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead) + key, _, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead) + value, _, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead) + + # gather key and values. + # torch.gather does not support broadcasting, hence we do it manually + bs, nhead, kv_nregion, kv_region_size, head_dim = key.size() + broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1).\ + expand(-1, -1, -1, -1, kv_region_size, head_dim) + key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\ + expand(-1, -1, query.size(2), -1, -1, -1), dim=3, + index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim) + value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\ + expand(-1, -1, query.size(2), -1, -1, -1), dim=3, + index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim) + + # token-to-token attention + # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_size) + # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_size) + attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2) + attn = torch.softmax(attn, dim=-1) + # (bs, nhead, q_nregion, reg_size, topk*kv_region_size) @ (bs, nhead, q_nregion, topk*kv_region_size, head_dim) + # -> (bs, nhead, q_nregion, reg_size, head_dim) + output = attn @ value_g.flatten(-3, -2) + + # to BCTHW format + output = _seq2grid(output, region_t=q_region_t, region_h=q_region_h, region_w=q_region_w, region_size=region_size) + + # remove paddings if needed + # if auto_pad and (q_pad_b > 0 or q_pad_r > 0): + # output = output[:, :, :Hq, :Wq] + + return output, attn + + + + +class CDC_T(nn.Module): + """ + The CDC_T Module is from here: https://github.com/ZitongYu/PhysFormer/model/transformer_layer.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False, theta=0.6): + + super(CDC_T, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.theta = theta + + def forward(self, x): + out_normal = self.conv(x) + + if math.fabs(self.theta - 0.0) < 1e-8: + return out_normal + else: + # pdb.set_trace() + [C_out, C_in, t, kernel_size, kernel_size] = self.conv.weight.shape + + # only CD works on temporal kernel size>1 + if self.conv.weight.shape[2] > 1: + kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + self.conv.weight[:, :, 2, :, :].sum( + 2).sum(2) + kernel_diff = kernel_diff[:, :, None, None, None] + out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, + padding=0, dilation=self.conv.dilation, groups=self.conv.groups) + return out_normal - self.theta * out_diff + + else: + return out_normal + +class video_BRA(nn.Module): + + def __init__(self, dim, num_heads=8, t_patch=8, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False, attn_backend='torch'): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!' + self.head_dim = self.dim // self.num_heads + self.scale = qk_scale or self.dim ** -0.5 + self.topk = topk + self.t_patch = t_patch # frame of patch + ################side_dwconv (i.e. 
LCE in Shunted Transformer)########### + self.lepe = nn.Conv3d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \ + lambda x: torch.zeros_like(x) + ########################################## + self.qkv_linear = nn.Conv3d(self.dim, 3*self.dim, kernel_size=1) + self.output_linear = nn.Conv3d(self.dim, self.dim, kernel_size=1) + self.proj_q = nn.Sequential( + CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2), + nn.BatchNorm3d(dim), + ) + self.proj_k = nn.Sequential( + CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2), + nn.BatchNorm3d(dim), + ) + self.proj_v = nn.Sequential( + nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False), + ) + if attn_backend == 'torch': + self.attn_fn = video_regional_routing_attention_torch + else: + raise ValueError('CUDA implementation is not available yet. Please stay tuned.') + + def forward(self, x:Tensor): + + N, C, T, H, W = x.size() + t_region = max(4 // self.t_patch , 1) + region_size = (t_region, H//4 , W//4) + + # STEP 1: linear projection + q , k , v = self.proj_q(x) , self.proj_k(x) ,self.proj_v(x) + + # STEP 2: pre attention + q_r = F.avg_pool3d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False) + k_r = F.avg_pool3d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False) # ncthw + q_r:Tensor = q_r.permute(0, 2, 3, 4, 1).flatten(1, 3) # n(thw)c + k_r:Tensor = k_r.flatten(2, 4) # nc(thw) + a_r = q_r @ k_r # n(thw)(thw) + _, idx_r = torch.topk(a_r, k=self.topk, dim=-1) # n(thw)k + idx_r:LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1) + + # STEP 3: refined attention + output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale, + region_graph=idx_r, region_size=region_size) + + output = output + self.lepe(v) # nctHW + output = self.output_linear(output) # nctHW + + return output + +class video_BiFormerBlock(nn.Module): + def __init__(self, dim, drop_path=0., num_heads=4, t_patch=1,qk_scale=None, topk=4, mlp_ratio=2, side_dwconv=5): + super().__init__() + self.t_patch = t_patch + self.norm1 = nn.BatchNorm3d(dim) + self.attn = video_BRA(dim=dim, num_heads=num_heads, t_patch=t_patch,qk_scale=qk_scale, topk=topk, side_dwconv=side_dwconv) + self.norm2 = nn.BatchNorm3d(dim) + self.mlp = nn.Sequential(nn.Conv3d(dim, int(mlp_ratio*dim), kernel_size=1), + nn.BatchNorm3d(int(mlp_ratio*dim)), + nn.GELU(), + nn.Conv3d(int(mlp_ratio*dim), int(mlp_ratio*dim), 3, stride=1, padding=1), + nn.BatchNorm3d(int(mlp_ratio*dim)), + nn.GELU(), + nn.Conv3d(int(mlp_ratio*dim), dim, kernel_size=1), + nn.BatchNorm3d(dim), + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class Fusion_Stem(nn.Module): + def __init__(self,apha=0.5,belta=0.5): + super(Fusion_Stem, self).__init__() + + self.stem11 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), + nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + ) + + self.stem12 = nn.Sequential(nn.Conv2d(12, 64, kernel_size=5, stride=2, padding=2), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + ) + + self.stem21 =nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + + self.stem22 =nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + + self.apha = apha + self.belta = belta + + def forward(self, x): + """Definition of Fusion_Stem. + Args: + x [N,D,C,H,W] + Returns: + fusion_x [N*D,C,H/4,W/4] + """ + N, D, C, H, W = x.shape + x1 = torch.cat([x[:,:1,:,:,:],x[:,:1,:,:,:],x[:,:D-2,:,:,:]],1) + x2 = torch.cat([x[:,:1,:,:,:],x[:,:D-1,:,:,:]],1) + x3 = x + x4 = torch.cat([x[:,1:,:,:,:],x[:,D-1:,:,:,:]],1) + x5 = torch.cat([x[:,2:,:,:,:],x[:,D-1:,:,:,:],x[:,D-1:,:,:,:]],1) + x_diff = self.stem12(torch.cat([x2-x1,x3-x2,x4-x3,x5-x4],2).view(N * D, 12, H, W)) + x3 = x3.contiguous().view(N * D, C, H, W) + x = self.stem11(x3) + + #fusion layer1 + x_path1 = self.apha*x + self.belta*x_diff + x_path1 = self.stem21(x_path1) + #fusion layer2 + x_path2 = self.stem22(x_diff) + x = self.apha*x_path1 + self.belta*x_path2 + + return x + +class TPT_Block(nn.Module): + def __init__(self, dim, depth, num_heads, t_patch, topk, + mlp_ratio=4., drop_path=0., side_dwconv=5): + super().__init__() + self.dim = dim + self.depth = depth + ############ downsample layers & upsample layers ##################### + self.downsample_layers = nn.ModuleList() + self.upsample_layers = nn.ModuleList() + self.layer_n = int(math.log(t_patch,2)) + for i in range(self.layer_n): + downsample_layer = nn.Sequential( + nn.BatchNorm3d(dim), + nn.Conv3d(dim , dim , kernel_size=(2, 1, 1), stride=(2, 1, 1)), + ) + self.downsample_layers.append(downsample_layer) + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=(2, 1, 1)), + nn.Conv3d(dim , dim , [3, 1, 1], stride=1, padding=(1, 0, 0)), + nn.BatchNorm3d(dim), + nn.ELU(), + ) + self.upsample_layers.append(upsample_layer) + ###################################################################### + self.blocks = nn.ModuleList([ + video_BiFormerBlock( + dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + num_heads=num_heads, + t_patch=t_patch, + topk=topk, + mlp_ratio=mlp_ratio, + side_dwconv=side_dwconv, + ) + for i in range(depth) + ]) + def forward(self, x:torch.Tensor): + """Definition of TPT_Block. 
+ Args: + x [N,C,D,H,W] + Returns: + x [N,C,D,H,W] + """ + for i in range(self.layer_n) : + x = self.downsample_layers[i](x) + for blk in self.blocks: + x = blk(x) + for i in range(self.layer_n) : + x = self.upsample_layers[i](x) + + return x + +class RhythmFormer(nn.Module): + + def __init__( + self, + name: Optional[str] = None, + pretrained: bool = False, + dim: int = 64, frame: int = 160, + image_size: Optional[int] = (160,128,128), + in_chans=64, head_dim=16, + stage_n = 3, + embed_dim=[64, 64, 64], mlp_ratios=[1.5, 1.5, 1.5], + depth=[2, 2, 2], + t_patchs:Union[int, Tuple[int]]=(2, 4, 8), + topks:Union[int, Tuple[int]]=(40, 40, 40), + side_dwconv:int=3, + drop_path_rate=0., + use_checkpoint_stages=[], + ): + super().__init__() + + self.image_size = image_size + self.frame = frame + self.dim = dim + self.stage_n = stage_n + + self.Fusion_Stem = Fusion_Stem() + self.patch_embedding = nn.Conv3d(in_chans,embed_dim[0], kernel_size=(1, 4, 4), stride=(1, 4, 4)) + self.ConvBlockLast = nn.Conv1d(embed_dim[-1], 1, kernel_size=1,stride=1, padding=0) + + ########################################################################## + self.stages = nn.ModuleList() + nheads= [dim // head_dim for dim in embed_dim] + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] + for i in range(stage_n): + stage = TPT_Block(dim=embed_dim[i], + depth=depth[i], + num_heads=nheads[i], + mlp_ratio=mlp_ratios[i], + drop_path=dp_rates[sum(depth[:i]):sum(depth[:i+1])], + t_patch=t_patchs[i], topk=topks[i], side_dwconv=side_dwconv + ) + self.stages.append(stage) + ########################################################################## + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + N, D, C, H, W = x.shape + x = self.Fusion_Stem(x) #[N*D 64 H/4 W/4] + x = x.view(N,D,64,H//4,W//4).permute(0,2,1,3,4) + x = self.patch_embedding(x) #[N 64 D 8 8] + for i in range(3): + x = self.stages[i](x) #[N 64 D 8 8] + features_last = torch.mean(x,3) #[N, 64, D, 8] + features_last = torch.mean(features_last,3) #[N, 64, D] + rPPG = self.ConvBlockLast(features_last) #[N, 1, D] + rPPG = rPPG.squeeze(1) + return rPPG diff --git a/neural_methods/model/TS_CAN.py b/neural_methods/model/TS_CAN.py new file mode 100644 index 0000000000000000000000000000000000000000..4af6be923a3641c2201f0486c5a0b351ab5e65ec --- /dev/null +++ b/neural_methods/model/TS_CAN.py @@ -0,0 +1,269 @@ +"""Temporal Shift Convolutional Attention Network (TS-CAN). +Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement +NeurIPS, 2020 +Xin Liu, Josh Fromm, Shwetak Patel, Daniel McDuff +""" + +import torch +import torch.nn as nn + + +class Attention_mask(nn.Module): + def __init__(self): + super(Attention_mask, self).__init__() + + def forward(self, x): + xsum = torch.sum(x, dim=2, keepdim=True) + xsum = torch.sum(xsum, dim=3, keepdim=True) + xshape = tuple(x.size()) + return x / xsum * xshape[2] * xshape[3] * 0.5 + + def get_config(self): + """May be generated manually. 
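A minimal usage sketch for the RhythmFormer defined above, assuming the default configuration (160 frames at 128x128). The frame count should be divisible by the largest temporal patch size (8) so the TPT down/upsampling in each stage round-trips cleanly.

import torch

net = RhythmFormer().eval()
video = torch.rand(1, 160, 3, 128, 128)  # [N, D, C, H, W]
with torch.no_grad():
    rppg = net(video)                    # [N, D] rPPG waveform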
""" + config = super(Attention_mask, self).get_config() + return config + + +class TSM(nn.Module): + def __init__(self, n_segment=10, fold_div=3): + super(TSM, self).__init__() + self.n_segment = n_segment + self.fold_div = fold_div + + def forward(self, x): + nt, c, h, w = x.size() + n_batch = nt // self.n_segment + x = x.view(n_batch, self.n_segment, c, h, w) + fold = c // self.fold_div + out = torch.zeros_like(x) + out[:, :-1, :fold] = x[:, 1:, :fold] # shift left + out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right + out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift + return out.view(nt, c, h, w) + + +class TSCAN(nn.Module): + + def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25, + dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36): + """Definition of TS_CAN. + Args: + in_channels: the number of input channel. Default: 3 + frame_depth: the number of frame (window size) used in temport shift. Default: 20 + img_size: height/width of each frame. Default: 36. + Returns: + TS_CAN model. + """ + super(TSCAN, self).__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.dropout_rate1 = dropout_rate1 + self.dropout_rate2 = dropout_rate2 + self.pool_size = pool_size + self.nb_filters1 = nb_filters1 + self.nb_filters2 = nb_filters2 + self.nb_dense = nb_dense + # TSM layers + self.TSM_1 = TSM(n_segment=frame_depth) + self.TSM_2 = TSM(n_segment=frame_depth) + self.TSM_3 = TSM(n_segment=frame_depth) + self.TSM_4 = TSM(n_segment=frame_depth) + # Motion branch convs + self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv2 = nn.Conv2d( + self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv4 = nn.Conv2d( + self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Apperance branch convs + self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, + padding=(1, 1), bias=True) + self.apperance_conv2 = nn.Conv2d( + self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, + padding=(1, 1), bias=True) + self.apperance_conv4 = nn.Conv2d( + self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Attention layers + self.apperance_att_conv1 = nn.Conv2d( + self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_1 = Attention_mask() + self.apperance_att_conv2 = nn.Conv2d( + self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_2 = Attention_mask() + # Avg pooling + self.avg_pooling_1 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_2 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_3 = nn.AvgPool2d(self.pool_size) + # Dropout layers + self.dropout_1 = nn.Dropout(self.dropout_rate1) + self.dropout_2 = nn.Dropout(self.dropout_rate1) + self.dropout_3 = nn.Dropout(self.dropout_rate1) + self.dropout_4 = nn.Dropout(self.dropout_rate2) + # Dense layers + if img_size == 36: + self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True) + elif img_size == 72: + self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True) + elif img_size == 96: + self.final_dense_1 = 
nn.Linear(30976, self.nb_dense, bias=True) + elif img_size == 128: + self.final_dense_1 = nn.Linear(57600, self.nb_dense, bias=True) + else: + raise Exception('Unsupported image size') + self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True) + + def forward(self, inputs, params=None): + diff_input = inputs[:, :3, :, :] + raw_input = inputs[:, 3:, :, :] + + diff_input = self.TSM_1(diff_input) + d1 = torch.tanh(self.motion_conv1(diff_input)) + d1 = self.TSM_2(d1) + d2 = torch.tanh(self.motion_conv2(d1)) + + r1 = torch.tanh(self.apperance_conv1(raw_input)) + r2 = torch.tanh(self.apperance_conv2(r1)) + + g1 = torch.sigmoid(self.apperance_att_conv1(r2)) + g1 = self.attn_mask_1(g1) + gated1 = d2 * g1 + + d3 = self.avg_pooling_1(gated1) + d4 = self.dropout_1(d3) + + r3 = self.avg_pooling_2(r2) + r4 = self.dropout_2(r3) + + d4 = self.TSM_3(d4) + d5 = torch.tanh(self.motion_conv3(d4)) + d5 = self.TSM_4(d5) + d6 = torch.tanh(self.motion_conv4(d5)) + + r5 = torch.tanh(self.apperance_conv3(r4)) + r6 = torch.tanh(self.apperance_conv4(r5)) + + g2 = torch.sigmoid(self.apperance_att_conv2(r6)) + g2 = self.attn_mask_2(g2) + gated2 = d6 * g2 + + d7 = self.avg_pooling_3(gated2) + d8 = self.dropout_3(d7) + d9 = d8.view(d8.size(0), -1) + d10 = torch.tanh(self.final_dense_1(d9)) + d11 = self.dropout_4(d10) + out = self.final_dense_2(d11) + + return out + + +class MTTS_CAN(nn.Module): + """MTTS_CAN is the multi-task (respiration) version of TS-CAN""" + + def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25, + dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20): + super(MTTS_CAN, self).__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.dropout_rate1 = dropout_rate1 + self.dropout_rate2 = dropout_rate2 + self.pool_size = pool_size + self.nb_filters1 = nb_filters1 + self.nb_filters2 = nb_filters2 + self.nb_dense = nb_dense + # TSM layers + self.TSM_1 = TSM(n_segment=frame_depth) + self.TSM_2 = TSM(n_segment=frame_depth) + self.TSM_3 = TSM(n_segment=frame_depth) + self.TSM_4 = TSM(n_segment=frame_depth) + # Motion branch convs + self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv2 = nn.Conv2d( + self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, padding=(1, 1), + bias=True) + self.motion_conv4 = nn.Conv2d( + self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Apperance branch convs + self.apperance_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size, + padding=(1, 1), bias=True) + self.apperance_conv2 = nn.Conv2d( + self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True) + self.apperance_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size, + padding=(1, 1), bias=True) + self.apperance_conv4 = nn.Conv2d( + self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True) + # Attention layers + self.apperance_att_conv1 = nn.Conv2d( + self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_1 = Attention_mask() + self.apperance_att_conv2 = nn.Conv2d( + self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True) + self.attn_mask_2 = Attention_mask() + # Avg pooling + self.avg_pooling_1 = nn.AvgPool2d(self.pool_size) + self.avg_pooling_2 = 
nn.AvgPool2d(self.pool_size) + self.avg_pooling_3 = nn.AvgPool2d(self.pool_size) + # Dropout layers + self.dropout_1 = nn.Dropout(self.dropout_rate1) + self.dropout_2 = nn.Dropout(self.dropout_rate1) + self.dropout_3 = nn.Dropout(self.dropout_rate1) + self.dropout_4_y = nn.Dropout(self.dropout_rate2) + self.dropout_4_r = nn.Dropout(self.dropout_rate2) + + # Dense layers + self.final_dense_1_y = nn.Linear(16384, self.nb_dense, bias=True) + self.final_dense_2_y = nn.Linear(self.nb_dense, 1, bias=True) + self.final_dense_1_r = nn.Linear(16384, self.nb_dense, bias=True) + self.final_dense_2_r = nn.Linear(self.nb_dense, 1, bias=True) + + def forward(self, inputs, params=None): + diff_input = inputs[:, :3, :, :] + raw_input = inputs[:, 3:, :, :] + + diff_input = self.TSM_1(diff_input) + d1 = torch.tanh(self.motion_conv1(diff_input)) + d1 = self.TSM_2(d1) + d2 = torch.tanh(self.motion_conv2(d1)) + + r1 = torch.tanh(self.apperance_conv1(raw_input)) + r2 = torch.tanh(self.apperance_conv2(r1)) + + g1 = torch.sigmoid(self.apperance_att_conv1(r2)) + g1 = self.attn_mask_1(g1) + gated1 = d2 * g1 + + d3 = self.avg_pooling_1(gated1) + d4 = self.dropout_1(d3) + + r3 = self.avg_pooling_2(r2) + r4 = self.dropout_2(r3) + + d4 = self.TSM_3(d4) + d5 = torch.tanh(self.motion_conv3(d4)) + d5 = self.TSM_4(d5) + d6 = torch.tanh(self.motion_conv4(d5)) + + r5 = torch.tanh(self.apperance_conv3(r4)) + r6 = torch.tanh(self.apperance_conv4(r5)) + + g2 = torch.sigmoid(self.apperance_att_conv2(r6)) + g2 = self.attn_mask_2(g2) + gated2 = d6 * g2 + + d7 = self.avg_pooling_3(gated2) + d8 = self.dropout_3(d7) + d9 = d8.view(d8.size(0), -1) + + d10 = torch.tanh(self.final_dense_1_y(d9)) + d11 = self.dropout_4_y(d10) + out_y = self.final_dense_2_y(d11) + + d10 = torch.tanh(self.final_dense_1_r(d9)) + d11 = self.dropout_4_r(d10) + out_r = self.final_dense_2_r(d11) + + return out_y, out_r diff --git a/neural_methods/model/__init__.py b/neural_methods/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/neural_methods/model/iBVPNet.py b/neural_methods/model/iBVPNet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6f8b1d592b9194b434f5555ddbceb4bc5c11c5 --- /dev/null +++ b/neural_methods/model/iBVPNet.py @@ -0,0 +1,194 @@ +"""iBVPNet - 3D Convolutional Network. +Proposed along with the iBVP Dataset, see https://doi.org/10.3390/electronics13071334 + +Joshi, Jitesh, and Youngjun Cho. 2024. "iBVP Dataset: RGB-Thermal rPPG Dataset with High Resolution Signal Quality Labels" Electronics 13, no. 7: 1334. 
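A minimal usage sketch for the TSCAN model defined above: the input stacks diff-normalized and raw frames on the channel axis (3 + 3 = 6 channels), and the flattened batch length must be a multiple of frame_depth so the temporal shift module sees whole segments; 72x72 is one of the supported image sizes.

import torch

net = TSCAN(frame_depth=20, img_size=72).eval()
frames = torch.rand(40, 6, 72, 72)  # two 20-frame clips flattened to [N*D, 6, H, W]
with torch.no_grad():
    pulse = net(frames)             # [40, 1] per-frame BVP estimate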
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvBlock3D(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride, padding): + super(ConvBlock3D, self).__init__() + self.conv_block_3d = nn.Sequential( + nn.Conv3d(in_channel, out_channel, kernel_size, stride, padding), + nn.Tanh(), + nn.InstanceNorm3d(out_channel), + ) + + def forward(self, x): + return self.conv_block_3d(x) + + +class DeConvBlock3D(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride, padding): + super(DeConvBlock3D, self).__init__() + k_t, k_s1, k_s2 = kernel_size + s_t, s_s1, s_s2 = stride + self.deconv_block_3d = nn.Sequential( + nn.ConvTranspose3d(in_channel, in_channel, (k_t, 1, 1), (s_t, 1, 1), padding), + nn.Tanh(), + nn.InstanceNorm3d(in_channel), + + nn.Conv3d(in_channel, out_channel, (1, k_s1, k_s2), (1, s_s1, s_s2), padding), + nn.Tanh(), + nn.InstanceNorm3d(out_channel), + ) + + def forward(self, x): + return self.deconv_block_3d(x) + +# num_filters +nf = [8, 16, 24, 40, 64] + +class encoder_block(nn.Module): + def __init__(self, in_channel, debug=False): + super(encoder_block, self).__init__() + # in_channel, out_channel, kernel_size, stride, padding + + self.debug = debug + self.spatio_temporal_encoder = nn.Sequential( + ConvBlock3D(in_channel, nf[0], [1, 3, 3], [1, 1, 1], [0, 1, 1]), + ConvBlock3D(nf[0], nf[1], [3, 3, 3], [1, 1, 1], [1, 1, 1]), + nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)), + ConvBlock3D(nf[1], nf[2], [1, 3, 3], [1, 1, 1], [0, 1, 1]), + ConvBlock3D(nf[2], nf[3], [3, 3, 3], [1, 1, 1], [1, 1, 1]), + nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)), + ConvBlock3D(nf[3], nf[4], [1, 3, 3], [1, 1, 1], [0, 1, 1]), + ConvBlock3D(nf[4], nf[4], [3, 3, 3], [1, 1, 1], [1, 1, 1]), + ) + + self.temporal_encoder = nn.Sequential( + ConvBlock3D(nf[4], nf[4], [11, 1, 1], [1, 1, 1], [5, 0, 0]), + ConvBlock3D(nf[4], nf[4], [11, 3, 3], [1, 1, 1], [5, 1, 1]), + nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)), + ConvBlock3D(nf[4], nf[4], [11, 1, 1], [1, 1, 1], [5, 0, 0]), + ConvBlock3D(nf[4], nf[4], [11, 3, 3], [1, 1, 1], [5, 1, 1]), + nn.MaxPool3d((2, 2, 2), stride=(2, 1, 1)), + ConvBlock3D(nf[4], nf[4], [7, 1, 1], [1, 1, 1], [3, 0, 0]), + ConvBlock3D(nf[4], nf[4], [7, 3, 3], [1, 1, 1], [3, 1, 1]) + ) + + def forward(self, x): + if self.debug: + print("Encoder") + print("x.shape", x.shape) + st_x = self.spatio_temporal_encoder(x) + if self.debug: + print("st_x.shape", st_x.shape) + t_x = self.temporal_encoder(st_x) + if self.debug: + print("t_x.shape", t_x.shape) + return t_x + + +class decoder_block(nn.Module): + def __init__(self, debug=False): + super(decoder_block, self).__init__() + self.debug = debug + self.decoder_block = nn.Sequential( + DeConvBlock3D(nf[4], nf[3], [7, 3, 3], [2, 2, 2], [2, 1, 1]), + DeConvBlock3D(nf[3], nf[2], [7, 3, 3], [2, 2, 2], [2, 1, 1]) + ) + + def forward(self, x): + if self.debug: + print("Decoder") + print("x.shape", x.shape) + x = self.decoder_block(x) + if self.debug: + print("x.shape", x.shape) + return x + + + +class iBVPNet(nn.Module): + def __init__(self, frames, in_channels=3, debug=False): + super(iBVPNet, self).__init__() + self.debug = debug + + self.in_channels = in_channels + if self.in_channels == 1 or self.in_channels == 3: + self.norm = nn.InstanceNorm3d(self.in_channels) + elif self.in_channels == 4: + self.rgb_norm = nn.InstanceNorm3d(3) + self.thermal_norm = nn.InstanceNorm3d(1) + else: + print("Unsupported input channels") + + self.ibvpnet = nn.Sequential( + encoder_block(in_channels, 
debug), + decoder_block(debug), + # spatial adaptive pooling + nn.AdaptiveMaxPool3d((frames, 1, 1)), + nn.Conv3d(nf[2], 1, [1, 1, 1], stride=1, padding=0) + ) + + + def forward(self, x): # [batch, Features=3, Temp=frames, Width=32, Height=32] + + [batch, channel, length, width, height] = x.shape + + x = torch.diff(x, dim=2) + + if self.debug: + print("Input.shape", x.shape) + + if self.in_channels == 1: + x = self.norm(x[:, -1:, :, :, :]) + elif self.in_channels == 3: + x = self.norm(x[:, :3, :, :, :]) + elif self.in_channels == 4: + rgb_x = self.rgb_norm(x[:, :3, :, :, :]) + thermal_x = self.thermal_norm(x[:, -1:, :, :, :]) + x = torch.concat([rgb_x, thermal_x], dim = 1) + else: + try: + print("Specified input channels:", self.in_channels) + print("Data channels", channel) + assert self.in_channels <= channel + except: + print("Incorrectly preprocessed data provided as input. Number of channels exceed the specified or default channels") + print("Default or specified channels:", self.in_channels) + print("Data channels [B, C, N, W, H]", x.shape) + print("Exiting") + exit() + + if self.debug: + print("Diff Normalized shape", x.shape) + + feats = self.ibvpnet(x) + if self.debug: + print("feats.shape", feats.shape) + rPPG = feats.view(-1, length-1) + return rPPG + + +if __name__ == "__main__": + import torch + from torch.utils.tensorboard import SummaryWriter + + # default `log_dir` is "runs" - we'll be more specific here + writer = SummaryWriter('runs/iBVPNet') + + duration = 8 + fs = 25 + batch_size = 4 + frames = duration*fs + in_channels = 1 + height = 64 + width = 64 + test_data = torch.rand(batch_size, in_channels, frames, height, width) + + net = iBVPNet(in_channels=in_channels, frames=frames, debug=True) + # print("-"*100) + # print(net) + # print("-"*100) + pred = net(test_data) + + print(pred.shape) + + writer.add_graph(net, test_data) + writer.close() diff --git a/neural_methods/trainer/BaseTrainer.py b/neural_methods/trainer/BaseTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab774036aa6a246df1603993fa2dfb376de0b80a --- /dev/null +++ b/neural_methods/trainer/BaseTrainer.py @@ -0,0 +1,108 @@ +import torch +from torch.autograd import Variable +import matplotlib.pyplot as plt +from matplotlib.ticker import ScalarFormatter, MaxNLocator +import os +import pickle + + +class BaseTrainer: + @staticmethod + def add_trainer_args(parser): + """Adds arguments to Paser for training process""" + parser.add_argument('--lr', default=None, type=float) + parser.add_argument('--model_file_name', default=None, type=float) + return parser + + def __init__(self): + pass + + def train(self, data_loader): + pass + + def valid(self, data_loader): + pass + + def test(self): + pass + + def save_test_outputs(self, predictions, labels, config): + + output_dir = config.TEST.OUTPUT_SAVE_DIR + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Filename ID to be used in any output files that get saved + if config.TOOLBOX_MODE == 'train_and_test': + filename_id = self.model_file_name + elif config.TOOLBOX_MODE == 'only_test': + model_file_root = config.INFERENCE.MODEL_PATH.split("/")[-1].split(".pth")[0] + filename_id = model_file_root + "_" + config.TEST.DATA.DATASET + else: + raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!') + output_path = os.path.join(output_dir, filename_id + '_outputs.pickle') + + data = dict() + data['predictions'] = predictions + data['labels'] = labels + data['label_type'] = 
config.TEST.DATA.PREPROCESS.LABEL_TYPE + data['fs'] = config.TEST.DATA.FS + + with open(output_path, 'wb') as handle: # save out frame dict pickle file + pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) + + print('Saving outputs to:', output_path) + + def plot_losses_and_lrs(self, train_loss, valid_loss, lrs, config): + + output_dir = os.path.join(config.LOG.PATH, config.TRAIN.DATA.EXP_DATA_NAME, 'plots') + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Filename ID to be used in plots that get saved + if config.TOOLBOX_MODE == 'train_and_test': + filename_id = self.model_file_name + else: + raise ValueError('Metrics.py evaluation only supports train_and_test and only_test!') + + # Create a single plot for training and validation losses + plt.figure(figsize=(10, 6)) + epochs = range(0, len(train_loss)) # Integer values for x-axis + plt.plot(epochs, train_loss, label='Training Loss') + if len(valid_loss) > 0: + plt.plot(epochs, valid_loss, label='Validation Loss') + else: + print("The list of validation losses is empty. The validation loss will not be plotted!") + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title(f'{filename_id} Losses') + plt.legend() + plt.xticks(epochs) + + # Set y-axis ticks with more granularity + ax = plt.gca() + ax.yaxis.set_major_locator(MaxNLocator(integer=False, prune='both')) + + loss_plot_filename = os.path.join(output_dir, filename_id + '_losses.pdf') + plt.savefig(loss_plot_filename, dpi=300) + plt.close() + + # Create a separate plot for learning rates + plt.figure(figsize=(6, 4)) + scheduler_steps = range(0, len(lrs)) + plt.plot(scheduler_steps, lrs, label='Learning Rate') + plt.xlabel('Scheduler Step') + plt.ylabel('Learning Rate') + plt.title(f'{filename_id} LR Schedule') + plt.legend() + + # Set y-axis values in scientific notation + ax = plt.gca() + ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True, useOffset=False)) + ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0)) # Force scientific notation + + lr_plot_filename = os.path.join(output_dir, filename_id + '_learning_rates.pdf') + plt.savefig(lr_plot_filename, bbox_inches='tight', dpi=300) + plt.close() + + print('Saving plots of losses and learning rates to:', output_dir) diff --git a/neural_methods/trainer/BigSmallTrainer.py b/neural_methods/trainer/BigSmallTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2118dcd9006a804a3a367a00428f38b6ff8b28bf --- /dev/null +++ b/neural_methods/trainer/BigSmallTrainer.py @@ -0,0 +1,484 @@ +"""Trainer for BigSmall Multitask Models""" + +# Training / Eval Imports +import torch +import torch.optim as optim +from neural_methods.trainer.BaseTrainer import BaseTrainer +from neural_methods import loss +from neural_methods.model.BigSmall import BigSmall +from evaluation.bigsmall_multitask_metrics import (calculate_bvp_metrics, + calculate_resp_metrics, + calculate_bp4d_au_metrics) + +# Other Imports +from collections import OrderedDict +import numpy as np +import os +from tqdm import tqdm + +class BigSmallTrainer(BaseTrainer): + + def define_model(self, config): + + # BigSmall Model + model = BigSmall(n_segment=3) + + if self.using_TSM: + self.frame_depth = config.MODEL.BIGSMALL.FRAME_DEPTH + self.base_len = self.num_of_gpu * self.frame_depth + + return model + + def format_data_shape(self, data, labels): + # reshape big data + data_big = data[0] + N, D, C, H, W = data_big.shape + data_big = data_big.view(N * D, C, H, W) + + # reshape small data + data_small = data[1] + N, D, 
C, H, W = data_small.shape + data_small = data_small.view(N * D, C, H, W) + + # reshape labels + if len(labels.shape) != 3: # this training format requires labels that are of shape N_label, D_label, C_label + labels = torch.unsqueeze(labels, dim=-1) + N_label, D_label, C_label = labels.shape + labels = labels.view(N_label * D_label, C_label) + + # If using temporal shift module + if self.using_TSM: + data_big = data_big[:(N * D) // self.base_len * self.base_len] + data_small = data_small[:(N * D) // self.base_len * self.base_len] + labels = labels[:(N * D) // self.base_len * self.base_len] + + data[0] = data_big + data[1] = data_small + labels = torch.unsqueeze(labels, dim=-1) + + return data, labels + + + def send_data_to_device(self, data, labels): + big_data = data[0].to(self.device) + small_data = data[1].to(self.device) + labels = labels.to(self.device) + data = (big_data, small_data) + return data, labels + + + def get_label_idxs(self, label_list, used_labels): + label_idxs = [] + for l in used_labels: + idx = label_list.index(l) + label_idxs.append(idx) + return label_idxs + + + def remove_data_parallel(self, old_state_dict): + new_state_dict = OrderedDict() + + for k, v in old_state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + + return new_state_dict + + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) + print('') + + + def __init__(self, config, data_loader): + + print('') + print('Init BigSmall Multitask Trainer\n\n') + + self.config = config # save config file + + # Set up GPU/CPU compute device + if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0: + self.device = torch.device(config.DEVICE) # set device to primary GPU + self.num_of_gpu = config.NUM_OF_GPU_TRAIN # set number of used GPUs + else: + self.device = "cpu" # if no GPUs set device is CPU + self.num_of_gpu = 0 # no GPUs used + + # Defining model + self.using_TSM = True + self.model = self.define_model(config) # define the model + + if torch.cuda.device_count() > 1 and config.NUM_OF_GPU_TRAIN > 1: # distribute model across GPUs + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) # data parallel model + + self.model = self.model.to(self.device) # send model to primary GPU + + # Training parameters + self.batch_size = config.TRAIN.BATCH_SIZE + self.max_epoch_num = config.TRAIN.EPOCHS + self.LR = config.TRAIN.LR + + # Set Loss and Optimizer + AU_weights = torch.as_tensor([9.64, 11.74, 16.77, 1.05, 0.53, 0.56, + 0.75, 0.69, 8.51, 6.94, 5.03, 25.00]).to(self.device) + + self.criterionAU = torch.nn.BCEWithLogitsLoss(pos_weight=AU_weights).to(self.device) + self.criterionBVP = torch.nn.MSELoss().to(self.device) + self.criterionRESP = torch.nn.MSELoss().to(self.device) + self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0) + + # self.scaler = torch.cuda.amp.GradScaler() # Loss scalar + + # Model info (saved more dir, chunk len, best epoch, etc.) 
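The trimming in format_data_shape above exists because the temporal shift module shifts features within fixed-length segments: the flattened N*D batch must be divisible by base_len = num_of_gpu * frame_depth, so any trailing frames are dropped. A small illustration with hypothetical numbers:

N, D = 4, 46                    # example chunk layout, not real config values
frame_depth, num_of_gpu = 3, 1
base_len = num_of_gpu * frame_depth
kept = (N * D) // base_len * base_len   # 184 frames -> 183 kept, 1 dropped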
+ self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + + # Epoch To Use For Test + self.used_epoch = 0 + + # Indicies corresponding to used labels + label_list = ['bp_wave', 'HR_bpm', 'systolic_bp', 'diastolic_bp', 'mean_bp', + 'resp_wave', 'resp_bpm', 'eda', + 'AU01', 'AU02', 'AU04', 'AU05', 'AU06', 'AU06int', 'AU07', 'AU09', 'AU10', 'AU10int', + 'AU11', 'AU12', 'AU12int', 'AU13', 'AU14', 'AU14int', 'AU15', 'AU16', 'AU17', 'AU17int', + 'AU18', 'AU19', 'AU20', 'AU22', 'AU23', 'AU24', 'AU27', 'AU28', 'AU29', 'AU30', 'AU31', + 'AU32', 'AU33', 'AU34', 'AU35', 'AU36', 'AU37', 'AU38', 'AU39', + 'pos_bvp','pos_env_norm_bvp'] + + used_labels = ['bp_wave', 'AU01', 'AU02', 'AU04', 'AU06', 'AU07', 'AU10', 'AU12', + 'AU14', 'AU15', 'AU17', 'AU23', 'AU24', + 'pos_env_norm_bvp', 'resp_wave'] + + # Get indicies for labels from npy array + au_label_list = [label for label in used_labels if 'AU' in label] + bvp_label_list_train = [label for label in used_labels if 'bvp' in label] + bvp_label_list_test = [label for label in used_labels if 'bp_wave' in label] + resp_label_list = [label for label in used_labels if 'resp' in label] + + self.label_idx_train_au = self.get_label_idxs(label_list, au_label_list) + self.label_idx_valid_au = self.get_label_idxs(label_list, au_label_list) + self.label_idx_test_au = self.get_label_idxs(label_list, au_label_list) + + self.label_idx_train_bvp = self.get_label_idxs(label_list, bvp_label_list_train) + self.label_idx_valid_bvp = self.get_label_idxs(label_list, bvp_label_list_train) + self.label_idx_test_bvp = self.get_label_idxs(label_list, bvp_label_list_test) + + self.label_idx_train_resp = self.get_label_idxs(label_list, resp_label_list) + self.label_idx_valid_resp = self.get_label_idxs(label_list, resp_label_list) + self.label_idx_test_resp = self.get_label_idxs(label_list, resp_label_list) + + + def train(self, data_loader): + """Model Training""" + + if data_loader["train"] is None: + raise ValueError("No data for train") + + print('Starting Training Routine') + print('') + + # Init min validation loss as infinity + min_valid_loss = np.inf # minimum validation loss + + # ARRAYS TO SAVE (LOSS ARRAYS) + train_loss_dict = dict() + train_au_loss_dict = dict() + train_bvp_loss_dict = dict() + train_resp_loss_dict = dict() + + val_loss_dict = dict() + val_au_loss_dict = dict() + val_bvp_loss_dict = dict() + val_resp_loss_dict = dict() + + # TODO: Expand tracking and subsequent plotting of these losses for BigSmall + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + + # ITERATE THROUGH EPOCHS + for epoch in range(self.max_epoch_num): + print(f"====Training Epoch: {epoch}====") + + # INIT PARAMS FOR TRAINING + running_loss = 0.0 # tracks avg loss over mini batches of 100 + train_loss = [] + train_au_loss = [] + train_bvp_loss = [] + train_resp_loss = [] + self.model.train() # put model in train mode + + # MODEL TRAINING + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + + # GATHER AND FORMAT BATCH DATA + data, labels = batch[0], batch[1] + data, labels = self.format_data_shape(data, labels) + data, labels = self.send_data_to_device(data, labels) + + # FOWARD AND BACK PROPOGATE THROUGH MODEL + self.optimizer.zero_grad() + au_out, bvp_out, resp_out = self.model(data) + au_loss = self.criterionAU(au_out, labels[:, self.label_idx_train_au, 0]) # au loss + bvp_loss = 
self.criterionBVP(bvp_out, labels[:, self.label_idx_train_bvp, 0]) # bvp loss + resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_train_resp, 0]) # resp loss + loss = au_loss + bvp_loss + resp_loss # sum losses + loss.backward() + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + # self.scaler.scale(loss).backward() # Loss scaling + # self.scaler.step(self.optimizer) + # self.scaler.update() + + + + + # UPDATE RUNNING LOSS AND PRINTED TERMINAL OUTPUT AND SAVED LOSSES + train_loss.append(loss.item()) + train_au_loss.append(au_loss.item()) + train_bvp_loss.append(bvp_loss.item()) + train_resp_loss.append(resp_loss.item()) + + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print(f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + + + tbar.set_postfix({"loss:": loss.item(), "lr:": self.optimizer.param_groups[0]["lr"]}) + + # APPEND EPOCH LOSS LIST TO TRAINING LOSS DICTIONARY + train_loss_dict[epoch] = train_loss + train_au_loss_dict[epoch] = train_au_loss + train_bvp_loss_dict[epoch] = train_bvp_loss + train_resp_loss_dict[epoch] = train_resp_loss + + print('') + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + + # SAVE MODEL FOR THIS EPOCH + self.save_model(epoch) + + # VALIDATION (IF ENABLED) + if not self.config.TEST.USE_LAST_EPOCH: + + # Get validation losses + valid_loss, valid_au_loss, valid_bvp_loss, valid_resp_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + val_loss_dict[epoch] = valid_loss + val_au_loss_dict[epoch] = valid_au_loss + val_bvp_loss_dict[epoch] = valid_bvp_loss + val_resp_loss_dict[epoch] = valid_resp_loss + print('validation loss: ', valid_loss) + + # Update used model + if self.model_to_use == 'best_epoch' and (valid_loss < min_valid_loss): + min_valid_loss = valid_loss + self.used_epoch = epoch + print("Update best model! 
Best epoch: {}".format(self.used_epoch)) + elif self.model_to_use == 'last_epoch': + self.used_epoch = epoch + + # VALIDATION (NOT ENABLED) + else: + self.used_epoch = epoch + + print('') + + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + # PRINT MODEL TO BE USED FOR TESTING + print("Used model trained epoch:{}, val_loss:{}".format(self.used_epoch, min_valid_loss)) + print('') + + + + def valid(self, data_loader): + """ Model evaluation on the validation dataset.""" + + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print("===Validating===") + + # INIT PARAMS FOR VALIDATION + valid_loss = [] + valid_au_loss = [] + valid_bvp_loss = [] + valid_resp_loss = [] + self.model.eval() + + # MODEL VALIDATION + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + + # GATHER AND FORMAT BATCH DATA + data, labels = valid_batch[0], valid_batch[1] + data, labels = self.format_data_shape(data, labels) + data, labels = self.send_data_to_device(data, labels) + + au_out, bvp_out, resp_out = self.model(data) + au_loss = self.criterionAU(au_out, labels[:, self.label_idx_valid_au, 0]) # au loss + bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_valid_bvp, 0]) # bvp loss + resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_valid_resp, 0]) # resp loss + loss = au_loss + bvp_loss + resp_loss # sum losses + + # APPEND VAL LOSS + valid_loss.append(loss.item()) + valid_au_loss.append(au_loss.item()) + valid_bvp_loss.append(bvp_loss.item()) + valid_resp_loss.append(resp_loss.item()) + vbar.set_postfix(loss=loss.item()) + + valid_loss = np.asarray(valid_loss) + valid_au_loss = np.asarray(valid_au_loss) + valid_bvp_loss = np.asarray(valid_bvp_loss) + valid_resp_loss = np.asarray(valid_resp_loss) + return np.mean(valid_loss), np.mean(valid_au_loss), np.mean(valid_bvp_loss), np.mean(valid_resp_loss) + + + + def test(self, data_loader): + """ Model evaluation on the testing dataset.""" + + print("===Testing===") + print('') + + # SETUP + if data_loader["test"] is None: + raise ValueError("No data for test") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + # ARRAYS TO SAVE (PREDICTIONS AND METRICS ARRAYS) + preds_dict_au = dict() + labels_dict_au = dict() + preds_dict_bvp = dict() + labels_dict_bvp = dict() + preds_dict_resp = dict() + labels_dict_resp = dict() + + # IF ONLY_TEST MODE LOAD PRETRAINED MODEL + if self.config.TOOLBOX_MODE == "only_test": + model_path = self.config.INFERENCE.MODEL_PATH + print("Testing uses pretrained model!") + print('Model path:', model_path) + if not os.path.exists(model_path): + raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.") + + # IF USING MODEL FROM TRAINING + else: + model_path = os.path.join(self.model_dir, + self.model_file_name + '_Epoch' + str(self.used_epoch) + '.pth') + print("Testing uses non-pretrained model!") + print('Model path:', model_path) + if not os.path.exists(model_path): + raise ValueError("Something went wrong... 
can't find trained model...") + print('') + + # LOAD ABOVE-SPECIFIED MODEL FOR TESTING + self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + self.model = self.model.to(self.device) + self.model.eval() + + # MODEL TESTING + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + + # PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA + batch_size = test_batch[1].shape[0] # get batch size + + # GATHER AND FORMAT BATCH DATA + data, labels = test_batch[0], test_batch[1] + data, labels = self.format_data_shape(data, labels) + data, labels = self.send_data_to_device(data, labels) + + # Weird dataloader bug is causing the final batch to be of size 0... + if labels.shape[0] == 0: + continue + + # GET MODEL PREDICTIONS + au_out, bvp_out, resp_out = self.model(data) + au_out = torch.sigmoid(au_out) + + # GATHER AND SLICE LABELS USED FOR TEST DATASET + TEST_AU = False + if len(self.label_idx_test_au) > 0: # if test dataset has AU + TEST_AU = True + labels_au = labels[:, self.label_idx_test_au] + else: # if not, set whole AU labels array to -1 + labels_au = np.ones((batch_size, len(self.label_idx_train_au))) + labels_au = -1 * labels_au + # labels_au = torch.from_numpy(labels_au) + + TEST_BVP = False + if len(self.label_idx_test_bvp) > 0: # if test dataset has BVP + TEST_BVP = True + labels_bvp = labels[:, self.label_idx_test_bvp] + else: # if not, set whole BVP labels array to -1 + labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp))) + labels_bvp = -1 * labels_bvp + # labels_bvp = torch.from_numpy(labels_bvp) + + TEST_RESP = False + if len(self.label_idx_test_resp) > 0: # if test dataset has RESP + TEST_RESP = True + labels_resp = labels[:, self.label_idx_test_resp] + else: # if not, set whole RESP labels array to -1 + labels_resp = np.ones((batch_size, len(self.label_idx_train_resp))) + labels_resp = -1 * labels_resp + # labels_resp = torch.from_numpy(labels_resp) + + # ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY + for idx in range(batch_size): + + # if the labels are cut off due to TSM data formatting + if idx * self.chunk_len >= labels.shape[0] and self.using_TSM: + continue + + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + + # add subject to prediction / label arrays + if subj_index not in preds_dict_bvp.keys(): + preds_dict_au[subj_index] = dict() + labels_dict_au[subj_index] = dict() + preds_dict_bvp[subj_index] = dict() + labels_dict_bvp[subj_index] = dict() + preds_dict_resp[subj_index] = dict() + labels_dict_resp[subj_index] = dict() + + # append predictions and labels to subject dict + preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len] + preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len] + preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len] + + # Calculate Eval Metrics + bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config) + resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config) + au_metric_dict = 
calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config) + + + + diff --git a/neural_methods/trainer/BigSmallTrainer.py.backup b/neural_methods/trainer/BigSmallTrainer.py.backup new file mode 100644 index 0000000000000000000000000000000000000000..eb754bce0c70c5ee55ef98f67f6d3d75015f919b --- /dev/null +++ b/neural_methods/trainer/BigSmallTrainer.py.backup @@ -0,0 +1,484 @@ +"""Trainer for BigSmall Multitask Models""" + +# Training / Eval Imports +import torch +import torch.optim as optim +from neural_methods.trainer.BaseTrainer import BaseTrainer +from neural_methods import loss +from neural_methods.model.BigSmall import BigSmall +from evaluation.bigsmall_multitask_metrics import (calculate_bvp_metrics, + calculate_resp_metrics, + calculate_bp4d_au_metrics) + +# Other Imports +from collections import OrderedDict +import numpy as np +import os +from tqdm import tqdm + +class BigSmallTrainer(BaseTrainer): + + def define_model(self, config): + + # BigSmall Model + model = BigSmall(n_segment=3) + + if self.using_TSM: + self.frame_depth = config.MODEL.BIGSMALL.FRAME_DEPTH + self.base_len = self.num_of_gpu * self.frame_depth + + return model + + def format_data_shape(self, data, labels): + # reshape big data + data_big = data[0] + N, D, C, H, W = data_big.shape + data_big = data_big.view(N * D, C, H, W) + + # reshape small data + data_small = data[1] + N, D, C, H, W = data_small.shape + data_small = data_small.view(N * D, C, H, W) + + # reshape labels + if len(labels.shape) != 3: # this training format requires labels that are of shape N_label, D_label, C_label + labels = torch.unsqueeze(labels, dim=-1) + N_label, D_label, C_label = labels.shape + labels = labels.view(N_label * D_label, C_label) + + # If using temporal shift module + if self.using_TSM: + data_big = data_big[:(N * D) // self.base_len * self.base_len] + data_small = data_small[:(N * D) // self.base_len * self.base_len] + labels = labels[:(N * D) // self.base_len * self.base_len] + + data[0] = data_big + data[1] = data_small + labels = torch.unsqueeze(labels, dim=-1) + + return data, labels + + + def send_data_to_device(self, data, labels): + big_data = data[0].to(self.device) + small_data = data[1].to(self.device) + labels = labels.to(self.device) + data = (big_data, small_data) + return data, labels + + + def get_label_idxs(self, label_list, used_labels): + label_idxs = [] + for l in used_labels: + idx = label_list.index(l) + label_idxs.append(idx) + return label_idxs + + + def remove_data_parallel(self, old_state_dict): + new_state_dict = OrderedDict() + + for k, v in old_state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + + return new_state_dict + + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) + print('') + + + def __init__(self, config, data_loader): + + print('') + print('Init BigSmall Multitask Trainer\n\n') + + self.config = config # save config file + + # Set up GPU/CPU compute device + if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0: + self.device = torch.device(config.DEVICE) # set device to primary GPU + self.num_of_gpu = config.NUM_OF_GPU_TRAIN # set number of used GPUs + else: + self.device = "cpu" # if no GPUs set device is CPU + self.num_of_gpu = 0 # no GPUs used + + # Defining model + self.using_TSM = True + 
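# With TSM enabled, define_model() also sets self.frame_depth (from MODEL.BIGSMALL.FRAME_DEPTH) and + # self.base_len = num_of_gpu * frame_depth; format_data_shape() later trims each batch to a + # multiple of base_len so the data stays aligned with the temporal shift module. +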
self.model = self.define_model(config) # define the model + + if torch.cuda.device_count() > 1 and config.NUM_OF_GPU_TRAIN > 1: # distribute model across GPUs + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) # data parallel model + + self.model = self.model.to(self.device) # send model to primary GPU + + # Training parameters + self.batch_size = config.TRAIN.BATCH_SIZE + self.max_epoch_num = config.TRAIN.EPOCHS + self.LR = config.TRAIN.LR + + # Set Loss and Optimizer + AU_weights = torch.as_tensor([9.64, 11.74, 16.77, 1.05, 0.53, 0.56, + 0.75, 0.69, 8.51, 6.94, 5.03, 25.00]).to(self.device) + + self.criterionAU = torch.nn.BCEWithLogitsLoss(pos_weight=AU_weights).to(self.device) + self.criterionBVP = torch.nn.MSELoss().to(self.device) + self.criterionRESP = torch.nn.MSELoss().to(self.device) + self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0) + + # self.scaler = torch.cuda.amp.GradScaler() # Loss scalar + + # Model info (saved more dir, chunk len, best epoch, etc.) + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + + # Epoch To Use For Test + self.used_epoch = 0 + + # Indicies corresponding to used labels + label_list = ['bp_wave', 'HR_bpm', 'systolic_bp', 'diastolic_bp', 'mean_bp', + 'resp_wave', 'resp_bpm', 'eda', + 'AU01', 'AU02', 'AU04', 'AU05', 'AU06', 'AU06int', 'AU07', 'AU09', 'AU10', 'AU10int', + 'AU11', 'AU12', 'AU12int', 'AU13', 'AU14', 'AU14int', 'AU15', 'AU16', 'AU17', 'AU17int', + 'AU18', 'AU19', 'AU20', 'AU22', 'AU23', 'AU24', 'AU27', 'AU28', 'AU29', 'AU30', 'AU31', + 'AU32', 'AU33', 'AU34', 'AU35', 'AU36', 'AU37', 'AU38', 'AU39', + 'pos_bvp','pos_env_norm_bvp'] + + used_labels = ['bp_wave', 'AU01', 'AU02', 'AU04', 'AU06', 'AU07', 'AU10', 'AU12', + 'AU14', 'AU15', 'AU17', 'AU23', 'AU24', + 'pos_env_norm_bvp', 'resp_wave'] + + # Get indicies for labels from npy array + au_label_list = [label for label in used_labels if 'AU' in label] + bvp_label_list_train = [label for label in used_labels if 'bvp' in label] + bvp_label_list_test = [label for label in used_labels if 'bp_wave' in label] + resp_label_list = [label for label in used_labels if 'resp' in label] + + self.label_idx_train_au = self.get_label_idxs(label_list, au_label_list) + self.label_idx_valid_au = self.get_label_idxs(label_list, au_label_list) + self.label_idx_test_au = self.get_label_idxs(label_list, au_label_list) + + self.label_idx_train_bvp = self.get_label_idxs(label_list, bvp_label_list_train) + self.label_idx_valid_bvp = self.get_label_idxs(label_list, bvp_label_list_train) + self.label_idx_test_bvp = self.get_label_idxs(label_list, bvp_label_list_test) + + self.label_idx_train_resp = self.get_label_idxs(label_list, resp_label_list) + self.label_idx_valid_resp = self.get_label_idxs(label_list, resp_label_list) + self.label_idx_test_resp = self.get_label_idxs(label_list, resp_label_list) + + + def train(self, data_loader): + """Model Training""" + + if data_loader["train"] is None: + raise ValueError("No data for train") + + print('Starting Training Routine') + print('') + + # Init min validation loss as infinity + min_valid_loss = np.inf # minimum validation loss + + # ARRAYS TO SAVE (LOSS ARRAYS) + train_loss_dict = dict() + train_au_loss_dict = dict() + train_bvp_loss_dict = dict() + train_resp_loss_dict = dict() + + val_loss_dict = dict() + val_au_loss_dict = dict() + val_bvp_loss_dict = dict() + val_resp_loss_dict = 
dict() + + # TODO: Expand tracking and subsequent plotting of these losses for BigSmall + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + + # ITERATE THROUGH EPOCHS + for epoch in range(self.max_epoch_num): + print(f"====Training Epoch: {epoch}====") + + # INIT PARAMS FOR TRAINING + running_loss = 0.0 # tracks avg loss over mini batches of 100 + train_loss = [] + train_au_loss = [] + train_bvp_loss = [] + train_resp_loss = [] + self.model.train() # put model in train mode + + # MODEL TRAINING + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + + # GATHER AND FORMAT BATCH DATA + data, labels = batch[0], batch[1] + data, labels = self.format_data_shape(data, labels) + data, labels = self.send_data_to_device(data, labels) + + # FOWARD AND BACK PROPOGATE THROUGH MODEL + self.optimizer.zero_grad() + au_out, bvp_out, resp_out = self.model(data) + au_loss = self.criterionAU(au_out, labels[:, self.label_idx_train_au, 0]) # au loss + bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_train_bvp, 0]) # bvp loss + resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_train_resp, 0]) # resp loss + loss = au_loss + bvp_loss + resp_loss # sum losses + loss.backward() + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + # self.scaler.scale(loss).backward() # Loss scaling + # self.scaler.step(self.optimizer) + # self.scaler.update() + + + + + # UPDATE RUNNING LOSS AND PRINTED TERMINAL OUTPUT AND SAVED LOSSES + train_loss.append(loss.item()) + train_au_loss.append(au_loss.item()) + train_bvp_loss.append(bvp_loss.item()) + train_resp_loss.append(resp_loss.item()) + + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print(f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + + + tbar.set_postfix({"loss:": loss.item(), "lr:": self.optimizer.param_groups[0]["lr"]}) + + # APPEND EPOCH LOSS LIST TO TRAINING LOSS DICTIONARY + train_loss_dict[epoch] = train_loss + train_au_loss_dict[epoch] = train_au_loss + train_bvp_loss_dict[epoch] = train_bvp_loss + train_resp_loss_dict[epoch] = train_resp_loss + + print('') + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + + # SAVE MODEL FOR THIS EPOCH + self.save_model(epoch) + + # VALIDATION (IF ENABLED) + if not self.config.TEST.USE_LAST_EPOCH: + + # Get validation losses + valid_loss, valid_au_loss, valid_bvp_loss, valid_resp_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + val_loss_dict[epoch] = valid_loss + val_au_loss_dict[epoch] = valid_au_loss + val_bvp_loss_dict[epoch] = valid_bvp_loss + val_resp_loss_dict[epoch] = valid_resp_loss + print('validation loss: ', valid_loss) + + # Update used model + if self.model_to_use == 'best_epoch' and (valid_loss < min_valid_loss): + min_valid_loss = valid_loss + self.used_epoch = epoch + print("Update best model! 
Best epoch: {}".format(self.used_epoch)) + elif self.model_to_use == 'last_epoch': + self.used_epoch = epoch + + # VALIDATION (NOT ENABLED) + else: + self.used_epoch = epoch + + print('') + + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + # PRINT MODEL TO BE USED FOR TESTING + print("Used model trained epoch:{}, val_loss:{}".format(self.used_epoch, min_valid_loss)) + print('') + + + + def valid(self, data_loader): + """ Model evaluation on the validation dataset.""" + + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print("===Validating===") + + # INIT PARAMS FOR VALIDATION + valid_loss = [] + valid_au_loss = [] + valid_bvp_loss = [] + valid_resp_loss = [] + self.model.eval() + + # MODEL VALIDATION + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + + # GATHER AND FORMAT BATCH DATA + data, labels = valid_batch[0], valid_batch[1] + data, labels = self.format_data_shape(data, labels) + data, labels = self.send_data_to_device(data, labels) + + au_out, bvp_out, resp_out = self.model(data) + au_loss = self.criterionAU(au_out, labels[:, self.label_idx_valid_au, 0]) # au loss + bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_valid_bvp, 0]) # bvp loss + resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_valid_resp, 0]) # resp loss + loss = au_loss + bvp_loss + resp_loss # sum losses + + # APPEND VAL LOSS + valid_loss.append(loss.item()) + valid_au_loss.append(au_loss.item()) + valid_bvp_loss.append(bvp_loss.item()) + valid_resp_loss.append(resp_loss.item()) + vbar.set_postfix(loss=loss.item()) + + valid_loss = np.asarray(valid_loss) + valid_au_loss = np.asarray(valid_au_loss) + valid_bvp_loss = np.asarray(valid_bvp_loss) + valid_resp_loss = np.asarray(valid_resp_loss) + return np.mean(valid_loss), np.mean(valid_au_loss), np.mean(valid_bvp_loss), np.mean(valid_resp_loss) + + + + def test(self, data_loader): + """ Model evaluation on the testing dataset.""" + + print("===Testing===") + print('') + + # SETUP + if data_loader["test"] is None: + raise ValueError("No data for test") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + # ARRAYS TO SAVE (PREDICTIONS AND METRICS ARRAYS) + preds_dict_au = dict() + labels_dict_au = dict() + preds_dict_bvp = dict() + labels_dict_bvp = dict() + preds_dict_resp = dict() + labels_dict_resp = dict() + + # IF ONLY_TEST MODE LOAD PRETRAINED MODEL + if self.config.TOOLBOX_MODE == "only_test": + model_path = self.config.INFERENCE.MODEL_PATH + print("Testing uses pretrained model!") + print('Model path:', model_path) + if not os.path.exists(model_path): + raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.") + + # IF USING MODEL FROM TRAINING + else: + model_path = os.path.join(self.model_dir, + self.model_file_name + '_Epoch' + str(self.used_epoch) + '.pth') + print("Testing uses non-pretrained model!") + print('Model path:', model_path) + if not os.path.exists(model_path): + raise ValueError("Something went wrong... 
cant find trained model...") + print('') + + # LOAD ABOVED SPECIFIED MODEL FOR TESTING + self.model.load_state_dict(torch.load(model_path)) + self.model = self.model.to(self.device) + self.model.eval() + + # MODEL TESTING + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + + # PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA + batch_size = test_batch[1].shape[0] # get batch size + + # GATHER AND FORMAT BATCH DATA + data, labels = test_batch[0], test_batch[1] + data, labels = self.format_data_shape(data, labels) + data, labels = self.send_data_to_device(data, labels) + + # Weird dataloader bug is causing the final training batch to be of size 0... + if labels.shape[0] == 0: + continue + + # GET MODEL PREDICTIONS + au_out, bvp_out, resp_out = self.model(data) + au_out = torch.sigmoid(au_out) + + # GATHER AND SLICE LABELS USED FOR TEST DATASET + TEST_AU = False + if len(self.label_idx_test_au) > 0: # if test dataset has AU + TEST_AU = True + labels_au = labels[:, self.label_idx_test_au] + else: # if not set whole AU labels array to -1 + labels_au = np.ones((batch_size, len(self.label_idx_train_au))) + labels_au = -1 * labels_au + # labels_au = torch.from_numpy(labels_au) + + TEST_BVP = False + if len(self.label_idx_test_bvp) > 0: # if test dataset has BVP + TEST_BVP = True + labels_bvp = labels[:, self.label_idx_test_bvp] + else: # if not set whole BVP labels array to -1 + labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp))) + labels_bvp = -1 * labels_bvp + # labels_bvp = torch.from_numpy(labels_bvp) + + TEST_RESP = False + if len(self.label_idx_test_resp) > 0: # if test dataset has BVP + TEST_RESP = True + labels_resp = labels[:, self.label_idx_test_resp] + else: # if not set whole BVP labels array to -1 + labels_resp = np.ones((batch_size, len(self.label_idx_train_resp))) + labels_resp = -1 * labels_resp + # labels_resp = torch.from_numpy(labels_resp) + + # ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY + for idx in range(batch_size): + + # if the labels are cut off due to TSM dataformating + if idx * self.chunk_len >= labels.shape[0] and self.using_TSM: + continue + + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + + # add subject to prediction / label arrays + if subj_index not in preds_dict_bvp.keys(): + preds_dict_au[subj_index] = dict() + labels_dict_au[subj_index] = dict() + preds_dict_bvp[subj_index] = dict() + labels_dict_bvp[subj_index] = dict() + preds_dict_resp[subj_index] = dict() + labels_dict_resp[subj_index] = dict() + + # append predictions and labels to subject dict + preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len] + preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len] + preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len] + + # Calculate Eval Metrics + bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config) + resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config) + au_metric_dict = 
calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config) + + + + diff --git a/neural_methods/trainer/DeepPhysTrainer.py b/neural_methods/trainer/DeepPhysTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..04f00d5f82ec9180051bc189b0adfc8319e0f613 --- /dev/null +++ b/neural_methods/trainer/DeepPhysTrainer.py @@ -0,0 +1,209 @@ +"""Trainer for DeepPhys.""" + +import logging +import os +from collections import OrderedDict + +import numpy as np +import torch +import torch.optim as optim +from evaluation.metrics import calculate_metrics +from neural_methods.loss.NegPearsonLoss import Neg_Pearson +from neural_methods.model.DeepPhys import DeepPhys +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm + + +class DeepPhysTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.device = torch.device(config.DEVICE) + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if config.TOOLBOX_MODE == "train_and_test": + self.model = DeepPhys(img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + self.num_train_batches = len(data_loader["train"]) + self.criterion = torch.nn.MSELoss() + self.optimizer = optim.AdamW( + self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif config.TOOLBOX_MODE == "only_test": + self.model = DeepPhys(img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + else: + raise ValueError("DeepPhys trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + running_loss = 0.0 + train_loss = [] + self.model.train() + # Model Training + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + data, labels = batch[0].to( + self.device), batch[1].to(self.device) + N, D, C, H, W = data.shape + data = data.view(N * D, C, H, W) + labels = labels.view(-1, 1) + self.optimizer.zero_grad() + pred_ppg = self.model(data) + loss = self.criterion(pred_ppg, labels) + loss.backward() + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + self.scheduler.step() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + train_loss.append(loss.item()) + tbar.set_postfix({"loss": loss.item(), "lr": 
self.optimizer.param_groups[0]["lr"]}) + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + + self.save_model(epoch) + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Model evaluation on the validation dataset.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print("===Validating===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + data_valid, labels_valid = valid_batch[0].to( + self.device), valid_batch[1].to(self.device) + N, D, C, H, W = data_valid.shape + data_valid = data_valid.view(N * D, C, H, W) + labels_valid = labels_valid.view(-1, 1) + pred_ppg_valid = self.model(data_valid) + loss = self.criterion(pred_ppg_valid, labels_valid) + valid_loss.append(loss.item()) + valid_step += 1 + vbar.set_postfix(loss=loss.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Model evaluation on the testing dataset.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + config = self.config + + print('') + print("===Testing===") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + predictions = dict() + labels = dict() + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu"))) + print("Testing uses pretrained model!") + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu"))) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu"))) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data_test, labels_test = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + N, D, C, H, W = data_test.shape + data_test = data_test.view(N * D, C, H, W) + labels_test = labels_test.view(-1, 1) + pred_ppg_test = self.model(data_test) + + if self.config.TEST.OUTPUT_SAVE_DIR: + labels_test = labels_test.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + """Inits parameters from args and the writer for TensorboardX.""" + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + \ No newline at end of file diff --git a/neural_methods/trainer/DeepPhysTrainer.py.backup b/neural_methods/trainer/DeepPhysTrainer.py.backup new file mode 100644 index 0000000000000000000000000000000000000000..bd911cb09705162d530b6fa8035afd099dfb37c4 --- /dev/null +++ b/neural_methods/trainer/DeepPhysTrainer.py.backup @@ -0,0 +1,209 @@ +"""Trainer for DeepPhys.""" + +import logging +import os +from collections import OrderedDict + +import numpy as np +import torch +import torch.optim as optim +from evaluation.metrics import calculate_metrics +from neural_methods.loss.NegPearsonLoss import Neg_Pearson +from neural_methods.model.DeepPhys import DeepPhys +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm + + +class DeepPhysTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.device = torch.device(config.DEVICE) + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = 
config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if config.TOOLBOX_MODE == "train_and_test": + self.model = DeepPhys(img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to(self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + self.num_train_batches = len(data_loader["train"]) + self.criterion = torch.nn.MSELoss() + self.optimizer = optim.AdamW( + self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif config.TOOLBOX_MODE == "only_test": + self.model = DeepPhys(img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to(self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + else: + raise ValueError("DeepPhys trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + running_loss = 0.0 + train_loss = [] + self.model.train() + # Model Training + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + data, labels = batch[0].to( + self.device), batch[1].to(self.device) + N, D, C, H, W = data.shape + data = data.view(N * D, C, H, W) + labels = labels.view(-1, 1) + self.optimizer.zero_grad() + pred_ppg = self.model(data) + loss = self.criterion(pred_ppg, labels) + loss.backward() + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + self.scheduler.step() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + train_loss.append(loss.item()) + tbar.set_postfix({"loss": loss.item(), "lr": self.optimizer.param_groups[0]["lr"]}) + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + + self.save_model(epoch) + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! 
Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Model evaluation on the validation dataset.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print("===Validating===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + data_valid, labels_valid = valid_batch[0].to( + self.device), valid_batch[1].to(self.device) + N, D, C, H, W = data_valid.shape + data_valid = data_valid.view(N * D, C, H, W) + labels_valid = labels_valid.view(-1, 1) + pred_ppg_valid = self.model(data_valid) + loss = self.criterion(pred_ppg_valid, labels_valid) + valid_loss.append(loss.item()) + valid_step += 1 + vbar.set_postfix(loss=loss.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Model evaluation on the testing dataset.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + config = self.config + + print('') + print("===Testing===") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + predictions = dict() + labels = dict() + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=self.device)) + print("Testing uses pretrained model!") + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path)) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path)) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data_test, labels_test = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + N, D, C, H, W = data_test.shape + data_test = data_test.view(N * D, C, H, W) + labels_test = labels_test.view(-1, 1) + pred_ppg_test = self.model(data_test) + + if self.config.TEST.OUTPUT_SAVE_DIR: + labels_test = labels_test.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels[subj_index][sort_index] = 
labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + """Inits parameters from args and the writer for TensorboardX.""" + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + \ No newline at end of file diff --git a/neural_methods/trainer/EfficientPhysTrainer.py b/neural_methods/trainer/EfficientPhysTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac411e034f20e734f940b9b9b4c206fb9945c4f --- /dev/null +++ b/neural_methods/trainer/EfficientPhysTrainer.py @@ -0,0 +1,228 @@ +"""Trainer for EfficientPhys.""" + +import logging +import os +from collections import OrderedDict + +import numpy as np +import torch +import torch.optim as optim +from evaluation.metrics import calculate_metrics +from neural_methods.loss.NegPearsonLoss import Neg_Pearson +from neural_methods.model.EfficientPhys import EfficientPhys +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm + + +class EfficientPhysTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.device = torch.device(config.DEVICE) + self.frame_depth = config.MODEL.EFFICIENTPHYS.FRAME_DEPTH + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.num_of_gpu = config.NUM_OF_GPU_TRAIN + self.base_len = self.num_of_gpu * self.frame_depth + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if config.TOOLBOX_MODE == "train_and_test": + self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to( + self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + self.num_train_batches = len(data_loader["train"]) + self.criterion = torch.nn.MSELoss() + self.optimizer = optim.AdamW( + self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif config.TOOLBOX_MODE == "only_test": + self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to( + self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + else: + raise ValueError("EfficientPhys trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + running_loss = 0.0 + train_loss = [] + self.model.train() + # Model Training + tbar = 
tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + data, labels = batch[0].to( + self.device), batch[1].to(self.device) + N, D, C, H, W = data.shape + data = data.view(N * D, C, H, W) + labels = labels.view(-1, 1) + data = data[:(N * D) // self.base_len * self.base_len] + # Add one more frame for EfficientPhys since it does torch.diff for the input + last_frame = torch.unsqueeze(data[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1) + data = torch.cat((data, last_frame), 0) + labels = labels[:(N * D) // self.base_len * self.base_len] + self.optimizer.zero_grad() + pred_ppg = self.model(data) + loss = self.criterion(pred_ppg, labels) + loss.backward() + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + self.scheduler.step() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + train_loss.append(loss.item()) + tbar.set_postfix(loss=loss.item()) + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + + self.save_model(epoch) + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Model evaluation on the validation dataset.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print("===Validating===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + data_valid, labels_valid = valid_batch[0].to( + self.device), valid_batch[1].to(self.device) + N, D, C, H, W = data_valid.shape + data_valid = data_valid.view(N * D, C, H, W) + labels_valid = labels_valid.view(-1, 1) + data_valid = data_valid[:(N * D) // self.base_len * self.base_len] + # Add one more frame for EfficientPhys since it does torch.diff for the input + last_frame = torch.unsqueeze(data_valid[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1) + data_valid = torch.cat((data_valid, last_frame), 0) + labels_valid = labels_valid[:(N * D) // self.base_len * self.base_len] + pred_ppg_valid = self.model(data_valid) + loss = self.criterion(pred_ppg_valid, labels_valid) + valid_loss.append(loss.item()) + valid_step += 1 + vbar.set_postfix(loss=loss.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Model evaluation on the testing dataset.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + + # Change chunk length to be test chunk length + self.chunk_len = 
self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu"))) + print("Testing uses pretrained model!") + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu"))) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu"))) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data_test, labels_test = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + N, D, C, H, W = data_test.shape + data_test = data_test.view(N * D, C, H, W) + labels_test = labels_test.view(-1, 1) + data_test = data_test[:(N * D) // self.base_len * self.base_len] + # Add one more frame for EfficientPhys since it does torch.diff for the input + last_frame = torch.unsqueeze(data_test[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1) + data_test = torch.cat((data_test, last_frame), 0) + labels_test = labels_test[:(N * D) // self.base_len * self.base_len] + pred_ppg_test = self.model(data_test) + + if self.config.TEST.OUTPUT_SAVE_DIR: + labels_test = labels_test.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) diff --git a/neural_methods/trainer/EfficientPhysTrainer.py.backup b/neural_methods/trainer/EfficientPhysTrainer.py.backup new file mode 100644 index 0000000000000000000000000000000000000000..93e7f80b80a2624f3db8e3125f2856a1f64f6960 --- /dev/null +++ b/neural_methods/trainer/EfficientPhysTrainer.py.backup @@ -0,0 +1,228 @@ +"""Trainer for EfficientPhys.""" + +import logging +import os +from collections import OrderedDict + +import numpy as np +import torch +import torch.optim as optim +from 
evaluation.metrics import calculate_metrics +from neural_methods.loss.NegPearsonLoss import Neg_Pearson +from neural_methods.model.EfficientPhys import EfficientPhys +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm + + +class EfficientPhysTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.device = torch.device(config.DEVICE) + self.frame_depth = config.MODEL.EFFICIENTPHYS.FRAME_DEPTH + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.num_of_gpu = config.NUM_OF_GPU_TRAIN + self.base_len = self.num_of_gpu * self.frame_depth + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if config.TOOLBOX_MODE == "train_and_test": + self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TRAIN.DATA.PREPROCESS.RESIZE.H).to( + self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + self.num_train_batches = len(data_loader["train"]) + self.criterion = torch.nn.MSELoss() + self.optimizer = optim.AdamW( + self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif config.TOOLBOX_MODE == "only_test": + self.model = EfficientPhys(frame_depth=self.frame_depth, img_size=config.TEST.DATA.PREPROCESS.RESIZE.H).to( + self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + else: + raise ValueError("EfficientPhys trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + running_loss = 0.0 + train_loss = [] + self.model.train() + # Model Training + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + data, labels = batch[0].to( + self.device), batch[1].to(self.device) + N, D, C, H, W = data.shape + data = data.view(N * D, C, H, W) + labels = labels.view(-1, 1) + data = data[:(N * D) // self.base_len * self.base_len] + # Add one more frame for EfficientPhys since it does torch.diff for the input + last_frame = torch.unsqueeze(data[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1) + data = torch.cat((data, last_frame), 0) + labels = labels[:(N * D) // self.base_len * self.base_len] + self.optimizer.zero_grad() + pred_ppg = self.model(data) + loss = self.criterion(pred_ppg, labels) + loss.backward() + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + self.scheduler.step() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + train_loss.append(loss.item()) + 
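# Show the latest batch loss on the tqdm progress bar. +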
tbar.set_postfix(loss=loss.item()) + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + + self.save_model(epoch) + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Model evaluation on the validation dataset.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print("===Validating===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + data_valid, labels_valid = valid_batch[0].to( + self.device), valid_batch[1].to(self.device) + N, D, C, H, W = data_valid.shape + data_valid = data_valid.view(N * D, C, H, W) + labels_valid = labels_valid.view(-1, 1) + data_valid = data_valid[:(N * D) // self.base_len * self.base_len] + # Add one more frame for EfficientPhys since it does torch.diff for the input + last_frame = torch.unsqueeze(data_valid[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1) + data_valid = torch.cat((data_valid, last_frame), 0) + labels_valid = labels_valid[:(N * D) // self.base_len * self.base_len] + pred_ppg_valid = self.model(data_valid) + loss = self.criterion(pred_ppg_valid, labels_valid) + valid_loss.append(loss.item()) + valid_step += 1 + vbar.set_postfix(loss=loss.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Model evaluation on the testing dataset.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH)) + print("Testing uses pretrained model!") + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path)) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path)) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data_test, labels_test = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + N, D, C, H, W = data_test.shape + data_test = data_test.view(N * D, C, H, W) + labels_test = labels_test.view(-1, 1) + data_test = data_test[:(N * D) // self.base_len * self.base_len] + # Add one more frame for EfficientPhys since it does torch.diff for the input + last_frame = torch.unsqueeze(data_test[-1, :, :, :], 0).repeat(self.num_of_gpu, 1, 1, 1) + data_test = torch.cat((data_test, last_frame), 0) + labels_test = labels_test[:(N * D) // self.base_len * self.base_len] + pred_ppg_test = self.model(data_test) + + if self.config.TEST.OUTPUT_SAVE_DIR: + labels_test = labels_test.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + labels[subj_index][sort_index] = labels_test[idx * self.chunk_len:(idx + 1) * self.chunk_len] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) diff --git a/neural_methods/trainer/FactorizePhysTrainer.py b/neural_methods/trainer/FactorizePhysTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b565d60d37e9b097b9e3d6049a5f25e2b653d95 --- /dev/null +++ b/neural_methods/trainer/FactorizePhysTrainer.py @@ -0,0 +1,312 @@ +""" +FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. 
Agaian, and Youngjun Cho +""" + +import os +import numpy as np +import torch +import torch.optim as optim +from evaluation.metrics import calculate_metrics +from neural_methods.loss.NegPearsonLoss import Neg_Pearson +from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys +from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm + + +class FactorizePhysTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.num_of_gpu = config.NUM_OF_GPU_TRAIN + self.dropout_rate = config.MODEL.DROP_RATE + self.base_len = self.num_of_gpu + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0: + dev_list = [int(d) for d in config.DEVICE.replace("cuda:", "").split(",")] + self.device = torch.device(dev_list[0]) #currently toolbox only supports 1 GPU + self.num_of_gpu = 1 #config.NUM_OF_GPU_TRAIN # set number of used GPUs + else: + self.device = torch.device("cpu") # if no GPUs set device is CPU + self.num_of_gpu = 0 # no GPUs used + + frames = self.config.MODEL.FactorizePhys.FRAME_NUM + in_channels = self.config.MODEL.FactorizePhys.CHANNELS + model_type = self.config.MODEL.FactorizePhys.TYPE + model_type = model_type.lower() + + md_config = {} + md_config["FRAME_NUM"] = self.config.MODEL.FactorizePhys.FRAME_NUM + md_config["MD_TYPE"] = self.config.MODEL.FactorizePhys.MD_TYPE + md_config["MD_FSAM"] = self.config.MODEL.FactorizePhys.MD_FSAM + md_config["MD_TRANSFORM"] = self.config.MODEL.FactorizePhys.MD_TRANSFORM + md_config["MD_S"] = self.config.MODEL.FactorizePhys.MD_S + md_config["MD_R"] = self.config.MODEL.FactorizePhys.MD_R + md_config["MD_STEPS"] = self.config.MODEL.FactorizePhys.MD_STEPS + md_config["MD_INFERENCE"] = self.config.MODEL.FactorizePhys.MD_INFERENCE + md_config["MD_RESIDUAL"] = self.config.MODEL.FactorizePhys.MD_RESIDUAL + + self.md_infer = self.config.MODEL.FactorizePhys.MD_INFERENCE + self.use_fsam = self.config.MODEL.FactorizePhys.MD_FSAM + + if model_type == "standard": + self.model = FactorizePhys(frames=frames, md_config=md_config, in_channels=in_channels, + dropout=self.dropout_rate, device=self.device) # [3, T, 72,72] + elif model_type == "big": + self.model = FactorizePhysBig(frames=frames, md_config=md_config, in_channels=in_channels, + dropout=self.dropout_rate, device=self.device) # [3, T, 144,144] + else: + print("Unexpected model type specified. 
Should be standard or big, but specified:", model_type) + exit() + + if torch.cuda.device_count() > 0 and self.num_of_gpu > 0: # distribute model across GPUs + self.model = torch.nn.DataParallel(self.model, device_ids=[self.device]) # data parallel model + else: + self.model = torch.nn.DataParallel(self.model).to(self.device) + + if self.config.TOOLBOX_MODE == "train_and_test" or self.config.TOOLBOX_MODE == "only_train": + self.num_train_batches = len(data_loader["train"]) + self.criterion = Neg_Pearson() + self.optimizer = optim.Adam( + self.model.parameters(), lr=self.config.TRAIN.LR) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=self.config.TRAIN.LR, epochs=self.config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif self.config.TOOLBOX_MODE == "only_test": + pass + else: + raise ValueError("FactorizePhys trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + mean_training_losses = [] + mean_valid_losses = [] + mean_appx_error = [] + lrs = [] + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + running_loss = 0.0 + train_loss = [] + appx_error_list = [] + self.model.train() + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + + data = batch[0].to(self.device) + labels = batch[1].to(self.device) + + if len(labels.shape) > 2: + labels = labels[..., 0] # Compatibility wigth multi-signal labelled data + labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize + last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) + data = torch.cat((data, last_frame), 2) + + # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) + # labels = torch.cat((labels, last_sample), 0) + # labels = torch.diff(labels, dim=0) + # labels = labels/ torch.std(labels) # normalize + # labels[torch.isnan(labels)] = 0 + + self.optimizer.zero_grad() + if self.model.training and self.use_fsam: + pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data) + else: + pred_ppg, vox_embed = self.model(data) + + pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize + + loss = self.criterion(pred_ppg, labels) + + loss.backward() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + train_loss.append(loss.item()) + if self.use_fsam: + appx_error_list.append(appx_error.item()) + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + self.scheduler.step() + + if self.use_fsam: + tbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item()) + else: + tbar.set_postfix(loss=loss.item()) + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + if self.use_fsam: + mean_appx_error.append(np.mean(appx_error_list)) + print("Mean train loss: {}, Mean appx error: {}".format( + np.mean(train_loss), np.mean(appx_error_list))) + else: + print("Mean train loss: {}".format(np.mean(train_loss))) + + self.save_model(epoch) + if not 
self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format( + self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Runs the model on valid sets.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print(" ====Validing===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + + data, labels = valid_batch[0].to(self.device), valid_batch[1].to(self.device) + if len(labels.shape) > 2: + labels = labels[..., 0] # Compatibility wigth multi-signal labelled data + labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize + + last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) + data = torch.cat((data, last_frame), 2) + + # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) + # labels = torch.cat((labels, last_sample), 0) + # labels = torch.diff(labels, dim=0) + # labels = labels/ torch.std(labels) # normalize + # labels[torch.isnan(labels)] = 0 + + if self.md_infer and self.use_fsam: + pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data) + else: + pred_ppg, vox_embed = self.model(data) + pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize + loss = self.criterion(pred_ppg, labels) + + valid_loss.append(loss.item()) + valid_step += 1 + # vbar.set_postfix(loss=loss.item()) + if self.md_infer and self.use_fsam: + vbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item()) + else: + vbar.set_postfix(loss=loss.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Runs the model on test sets.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu")), strict=False) + print("Testing uses pretrained model!") + print(self.config.INFERENCE.MODEL_PATH) + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu")), strict=False) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu")), strict=False) + + self.model = self.model.to(self.device) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data, labels_test = test_batch[0].to(self.device), test_batch[1].to(self.device) + + if len(labels_test.shape) > 2: + labels_test = labels_test[..., 0] # Compatibility wigth multi-signal labelled data + labels_test = (labels_test - torch.mean(labels_test)) / torch.std(labels_test) # normalize + + last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) + data = torch.cat((data, last_frame), 2) + + # last_sample = torch.unsqueeze(labels_test[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) + # labels_test = torch.cat((labels_test, last_sample), 0) + # labels_test = torch.diff(labels_test, dim=0) + # labels_test = labels_test/ torch.std(labels_test) # normalize + # labels_test[torch.isnan(labels_test)] = 0 + + if self.md_infer and self.use_fsam: + pred_ppg_test, vox_embed, factorized_embed, appx_error = self.model(data) + else: + pred_ppg_test, vox_embed = self.model(data) + pred_ppg_test = (pred_ppg_test - torch.mean(pred_ppg_test)) / torch.std(pred_ppg_test) # normalize + + if self.config.TEST.OUTPUT_SAVE_DIR: + labels_test = labels_test.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx] + labels[subj_index][sort_index] = labels_test[idx] + + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) diff --git a/neural_methods/trainer/FactorizePhysTrainer.py.backup b/neural_methods/trainer/FactorizePhysTrainer.py.backup new file mode 100644 index 0000000000000000000000000000000000000000..b2c3e8a9bdb93c9f3c7c78fc922d0ab28d5aa9ab --- /dev/null +++ b/neural_methods/trainer/FactorizePhysTrainer.py.backup @@ -0,0 +1,312 @@ +""" +FactorizePhys: Matrix 
Factorization for Multidimensional Attention in Remote Physiological Sensing +NeurIPS 2024 +Jitesh Joshi, Sos S. Agaian, and Youngjun Cho +""" + +import os +import numpy as np +import torch +import torch.optim as optim +from evaluation.metrics import calculate_metrics +from neural_methods.loss.NegPearsonLoss import Neg_Pearson +from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys +from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm + + +class FactorizePhysTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.num_of_gpu = config.NUM_OF_GPU_TRAIN + self.dropout_rate = config.MODEL.DROP_RATE + self.base_len = self.num_of_gpu + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0: + dev_list = [int(d) for d in config.DEVICE.replace("cuda:", "").split(",")] + self.device = torch.device(dev_list[0]) #currently toolbox only supports 1 GPU + self.num_of_gpu = 1 #config.NUM_OF_GPU_TRAIN # set number of used GPUs + else: + self.device = torch.device("cpu") # if no GPUs set device is CPU + self.num_of_gpu = 0 # no GPUs used + + frames = self.config.MODEL.FactorizePhys.FRAME_NUM + in_channels = self.config.MODEL.FactorizePhys.CHANNELS + model_type = self.config.MODEL.FactorizePhys.TYPE + model_type = model_type.lower() + + md_config = {} + md_config["FRAME_NUM"] = self.config.MODEL.FactorizePhys.FRAME_NUM + md_config["MD_TYPE"] = self.config.MODEL.FactorizePhys.MD_TYPE + md_config["MD_FSAM"] = self.config.MODEL.FactorizePhys.MD_FSAM + md_config["MD_TRANSFORM"] = self.config.MODEL.FactorizePhys.MD_TRANSFORM + md_config["MD_S"] = self.config.MODEL.FactorizePhys.MD_S + md_config["MD_R"] = self.config.MODEL.FactorizePhys.MD_R + md_config["MD_STEPS"] = self.config.MODEL.FactorizePhys.MD_STEPS + md_config["MD_INFERENCE"] = self.config.MODEL.FactorizePhys.MD_INFERENCE + md_config["MD_RESIDUAL"] = self.config.MODEL.FactorizePhys.MD_RESIDUAL + + self.md_infer = self.config.MODEL.FactorizePhys.MD_INFERENCE + self.use_fsam = self.config.MODEL.FactorizePhys.MD_FSAM + + if model_type == "standard": + self.model = FactorizePhys(frames=frames, md_config=md_config, in_channels=in_channels, + dropout=self.dropout_rate, device=self.device) # [3, T, 72,72] + elif model_type == "big": + self.model = FactorizePhysBig(frames=frames, md_config=md_config, in_channels=in_channels, + dropout=self.dropout_rate, device=self.device) # [3, T, 144,144] + else: + print("Unexpected model type specified. 
Should be standard or big, but specified:", model_type) + exit() + + if torch.cuda.device_count() > 0 and self.num_of_gpu > 0: # distribute model across GPUs + self.model = torch.nn.DataParallel(self.model, device_ids=[self.device]) # data parallel model + else: + self.model = torch.nn.DataParallel(self.model).to(self.device) + + if self.config.TOOLBOX_MODE == "train_and_test" or self.config.TOOLBOX_MODE == "only_train": + self.num_train_batches = len(data_loader["train"]) + self.criterion = Neg_Pearson() + self.optimizer = optim.Adam( + self.model.parameters(), lr=self.config.TRAIN.LR) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=self.config.TRAIN.LR, epochs=self.config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif self.config.TOOLBOX_MODE == "only_test": + pass + else: + raise ValueError("FactorizePhys trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + mean_training_losses = [] + mean_valid_losses = [] + mean_appx_error = [] + lrs = [] + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + running_loss = 0.0 + train_loss = [] + appx_error_list = [] + self.model.train() + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + + data = batch[0].to(self.device) + labels = batch[1].to(self.device) + + if len(labels.shape) > 2: + labels = labels[..., 0] # Compatibility wigth multi-signal labelled data + labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize + last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) + data = torch.cat((data, last_frame), 2) + + # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) + # labels = torch.cat((labels, last_sample), 0) + # labels = torch.diff(labels, dim=0) + # labels = labels/ torch.std(labels) # normalize + # labels[torch.isnan(labels)] = 0 + + self.optimizer.zero_grad() + if self.model.training and self.use_fsam: + pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data) + else: + pred_ppg, vox_embed = self.model(data) + + pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize + + loss = self.criterion(pred_ppg, labels) + + loss.backward() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + train_loss.append(loss.item()) + if self.use_fsam: + appx_error_list.append(appx_error.item()) + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + + self.optimizer.step() + self.scheduler.step() + + if self.use_fsam: + tbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item()) + else: + tbar.set_postfix(loss=loss.item()) + + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(train_loss)) + if self.use_fsam: + mean_appx_error.append(np.mean(appx_error_list)) + print("Mean train loss: {}, Mean appx error: {}".format( + np.mean(train_loss), np.mean(appx_error_list))) + else: + print("Mean train loss: {}".format(np.mean(train_loss))) + + self.save_model(epoch) + if not 
self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format( + self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Runs the model on valid sets.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print(" ====Validing===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + + data, labels = valid_batch[0].to(self.device), valid_batch[1].to(self.device) + if len(labels.shape) > 2: + labels = labels[..., 0] # Compatibility wigth multi-signal labelled data + labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize + + last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) + data = torch.cat((data, last_frame), 2) + + # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) + # labels = torch.cat((labels, last_sample), 0) + # labels = torch.diff(labels, dim=0) + # labels = labels/ torch.std(labels) # normalize + # labels[torch.isnan(labels)] = 0 + + if self.md_infer and self.use_fsam: + pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data) + else: + pred_ppg, vox_embed = self.model(data) + pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize + loss = self.criterion(pred_ppg, labels) + + valid_loss.append(loss.item()) + valid_step += 1 + # vbar.set_postfix(loss=loss.item()) + if self.md_infer and self.use_fsam: + vbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item()) + else: + vbar.set_postfix(loss=loss.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Runs the model on test sets.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=self.device), strict=False) + print("Testing uses pretrained model!") + print(self.config.INFERENCE.MODEL_PATH) + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=self.device), strict=False) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path, map_location=self.device), strict=False) + + self.model = self.model.to(self.device) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data, labels_test = test_batch[0].to(self.device), test_batch[1].to(self.device) + + if len(labels_test.shape) > 2: + labels_test = labels_test[..., 0] # Compatibility wigth multi-signal labelled data + labels_test = (labels_test - torch.mean(labels_test)) / torch.std(labels_test) # normalize + + last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) + data = torch.cat((data, last_frame), 2) + + # last_sample = torch.unsqueeze(labels_test[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) + # labels_test = torch.cat((labels_test, last_sample), 0) + # labels_test = torch.diff(labels_test, dim=0) + # labels_test = labels_test/ torch.std(labels_test) # normalize + # labels_test[torch.isnan(labels_test)] = 0 + + if self.md_infer and self.use_fsam: + pred_ppg_test, vox_embed, factorized_embed, appx_error = self.model(data) + else: + pred_ppg_test, vox_embed = self.model(data) + pred_ppg_test = (pred_ppg_test - torch.mean(pred_ppg_test)) / torch.std(pred_ppg_test) # normalize + + if self.config.TEST.OUTPUT_SAVE_DIR: + labels_test = labels_test.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx] + labels[subj_index][sort_index] = labels_test[idx] + + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) diff --git a/neural_methods/trainer/PhysFormerTrainer.py b/neural_methods/trainer/PhysFormerTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..1fc4ece67dae710f97e6f59c0db4c9543c4fac9a --- /dev/null +++ b/neural_methods/trainer/PhysFormerTrainer.py @@ -0,0 +1,273 @@ +"""Trainer for Physformer. 
+ +Based on open-source code from the original PhysFormer authors below: +https://github.com/ZitongYu/PhysFormer/blob/main/train_Physformer_160_VIPL.py + +We also thank the PhysBench authors for their open-source code based on the code +of the original authors. Their code below provided a better reference for tuning loss +parameters of interest and utilizing RSME as a validation loss: +https://github.com/KegangWangCCNU/PhysBench/blob/main/benchmark_addition/PhysFormer_pure.ipynb + +""" + +import os +import numpy as np +import math +import torch +import torch.optim as optim +from evaluation.metrics import calculate_metrics +from neural_methods.loss.PhysNetNegPearsonLoss import Neg_Pearson +from neural_methods.loss.PhysFormerLossComputer import TorchLossComputer +from neural_methods.model.PhysFormer import ViT_ST_ST_Compact3_TDC_gra_sharp +from neural_methods.trainer.BaseTrainer import BaseTrainer +from tqdm import tqdm +from scipy.signal import welch + +class PhysFormerTrainer(BaseTrainer): + + def __init__(self, config, data_loader): + """Inits parameters from args and the writer for TensorboardX.""" + super().__init__() + self.device = torch.device(config.DEVICE) + self.max_epoch_num = config.TRAIN.EPOCHS + self.model_dir = config.MODEL.MODEL_DIR + self.dropout_rate = config.MODEL.DROP_RATE + self.patch_size = config.MODEL.PHYSFORMER.PATCH_SIZE + self.dim = config.MODEL.PHYSFORMER.DIM + self.ff_dim = config.MODEL.PHYSFORMER.FF_DIM + self.num_heads = config.MODEL.PHYSFORMER.NUM_HEADS + self.num_layers = config.MODEL.PHYSFORMER.NUM_LAYERS + self.theta = config.MODEL.PHYSFORMER.THETA + self.model_file_name = config.TRAIN.MODEL_FILE_NAME + self.batch_size = config.TRAIN.BATCH_SIZE + self.num_of_gpu = config.NUM_OF_GPU_TRAIN + self.frame_rate = config.TRAIN.DATA.FS + self.config = config + self.min_valid_loss = None + self.best_epoch = 0 + + if config.TOOLBOX_MODE == "train_and_test": + self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH + self.model = ViT_ST_ST_Compact3_TDC_gra_sharp( + image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W), + patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers, + dropout_rate=self.dropout_rate, theta=self.theta).to(self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + self.num_train_batches = len(data_loader["train"]) + self.criterion_reg = torch.nn.MSELoss() + self.criterion_L1loss = torch.nn.L1Loss() + self.criterion_class = torch.nn.CrossEntropyLoss() + self.criterion_Pearson = Neg_Pearson() + self.optimizer = optim.Adam(self.model.parameters(), lr=config.TRAIN.LR, weight_decay=0.00005) + # TODO: In both the PhysFormer repo's training example and other implementations of a PhysFormer trainer, + # a step_size that doesn't end up changing the LR always seems to be used. This seems to defeat the point + # of using StepLR in the first place. Consider investigating and using another approach (e.g., OneCycleLR). 
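# --- Annotation (not part of this patch): the TODO above suggests replacing StepLR with
# OneCycleLR, which the FactorizePhys and PhysNet trainers in this same diff already use.
# A minimal, self-contained sketch of that scheduler wiring, with a toy model and
# hypothetical stand-in values for config.TRAIN.LR, config.TRAIN.EPOCHS, and
# len(data_loader["train"]):
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
lr, epochs, steps_per_epoch = 9e-3, 10, 100   # stand-ins for TRAIN.LR, TRAIN.EPOCHS, num_train_batches
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=lr, epochs=epochs, steps_per_epoch=steps_per_epoch)

for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()   # OneCycleLR is stepped once per batch, unlike the per-epoch StepLR used below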
+ self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5) + elif config.TOOLBOX_MODE == "only_test": + self.chunk_len = config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + self.model = ViT_ST_ST_Compact3_TDC_gra_sharp( + image_size=(self.chunk_len,config.TRAIN.DATA.PREPROCESS.RESIZE.H,config.TRAIN.DATA.PREPROCESS.RESIZE.W), + patches=(self.patch_size,) * 3, dim=self.dim, ff_dim=self.ff_dim, num_heads=self.num_heads, num_layers=self.num_layers, + dropout_rate=self.dropout_rate, theta=self.theta).to(self.device) + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + else: + raise ValueError("Physformer trainer initialized in incorrect toolbox mode!") + + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + # a --> Pearson loss; b --> frequency loss + a_start = 1.0 + b_start = 1.0 + exp_a = 0.5 # Unused + exp_b = 1.0 + + # TODO: Expand tracking and subsequent plotting of these losses for PhysFormer + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + loss_rPPG_avg = [] + loss_peak_avg = [] + loss_kl_avg_test = [] + loss_hr_mae = [] + + self.model.train() + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + hr = torch.tensor([self.get_hr(i) for i in batch[1]]).float().to(self.device) + data, label = batch[0].float().to(self.device), batch[1].float().to(self.device) + + self.optimizer.zero_grad() + + gra_sharp = 2.0 + rPPG, _, _, _ = self.model(data, gra_sharp) + rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG, axis=-1).view(-1, 1) # normalize + loss_rPPG = self.criterion_Pearson(rPPG, label) + + fre_loss = 0.0 + kl_loss = 0.0 + train_mae = 0.0 + for bb in range(data.shape[0]): + loss_distribution_kl, \ + fre_loss_temp, \ + train_mae_temp = TorchLossComputer.cross_entropy_power_spectrum_DLDL_softmax2( + rPPG[bb], + hr[bb], + self.frame_rate, + std=1.0 + ) + fre_loss = fre_loss+fre_loss_temp + kl_loss = kl_loss+loss_distribution_kl + train_mae = train_mae+train_mae_temp + fre_loss /= data.shape[0] + kl_loss /= data.shape[0] + train_mae /= data.shape[0] + + if epoch>10: + a = 0.05 + b = 5.0 + else: + a = a_start + # exp ascend + b = b_start*math.pow(exp_b, epoch/10.0) + + loss = a*loss_rPPG + b*(fre_loss+kl_loss) + loss.backward() + self.optimizer.step() + + n = data.size(0) + loss_rPPG_avg.append(float(loss_rPPG.data)) + loss_peak_avg.append(float(fre_loss.data)) + loss_kl_avg_test.append(float(kl_loss.data)) + loss_hr_mae.append(float(train_mae)) + if idx % 100 == 99: # print every 100 mini-batches + print(f'\nepoch:{epoch}, batch:{idx + 1}, total:{len(data_loader["train"]) // self.batch_size}, ' + f'lr:0.0001, sharp:{gra_sharp:.3f}, a:{a:.3f}, NegPearson:{np.mean(loss_rPPG_avg[-2000:]):.4f}, ' + f'\nb:{b:.3f}, kl:{np.mean(loss_kl_avg_test[-2000:]):.3f}, fre_CEloss:{np.mean(loss_peak_avg[-2000:]):.3f}, ' + f'hr_mae:{np.mean(loss_hr_mae[-2000:]):.3f}') + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(loss_rPPG_avg)) + self.save_model(epoch) + self.scheduler.step() + self.model.eval() + + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print(f'Validation RMSE:{valid_loss:.3f}, 
batch:{idx+1}') + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format( + self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Runs the model on valid sets.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print(" ====Validating===") + self.optimizer.zero_grad() + with torch.no_grad(): + hrs = [] + vbar = tqdm(data_loader["valid"], ncols=80) + for val_idx, val_batch in enumerate(vbar): + data, label = val_batch[0].float().to(self.device), val_batch[1].float().to(self.device) + gra_sharp = 2.0 + rPPG, _, _, _ = self.model(data, gra_sharp) + rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG).view(-1, 1) + for _1, _2 in zip(rPPG, label): + hrs.append((self.get_hr(_1.cpu().detach().numpy()), self.get_hr(_2.cpu().detach().numpy()))) + RMSE = np.mean([(i-j)**2 for i, j in hrs])**0.5 + return RMSE + + def test(self, data_loader): + """ Runs the model on test sets.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu"))) + print("Testing uses pretrained model!") + print(self.config.INFERENCE.MODEL_PATH) + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu"))) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu"))) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data, label = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + gra_sharp = 2.0 + pred_ppg_test, _, _, _ = self.model(data, gra_sharp) + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx] + labels[subj_index][sort_index] = label[idx] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) + + # HR calculation based on ground truth label + def get_hr(self, y, sr=30, min=30, max=180): + p, q = welch(y, sr, nfft=1e5/sr, nperseg=np.min((len(y)-1, 256))) + return p[(p>min/60)&(pmin/60)&(p Pearson loss; b --> frequency loss + a_start = 1.0 + b_start = 1.0 + exp_a = 0.5 # Unused + exp_b = 1.0 + + # TODO: Expand tracking and subsequent plotting of these losses for PhysFormer + mean_training_losses = [] + mean_valid_losses = [] + lrs = [] + + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + loss_rPPG_avg = [] + loss_peak_avg = [] + loss_kl_avg_test = [] + loss_hr_mae = [] + + self.model.train() + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + hr = torch.tensor([self.get_hr(i) for i in batch[1]]).float().to(self.device) + data, label = batch[0].float().to(self.device), batch[1].float().to(self.device) + + self.optimizer.zero_grad() + + gra_sharp = 2.0 + rPPG, _, _, _ = self.model(data, gra_sharp) + rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG, axis=-1).view(-1, 1) # normalize + loss_rPPG = self.criterion_Pearson(rPPG, label) + + fre_loss = 0.0 + kl_loss = 0.0 + train_mae = 0.0 + for bb in range(data.shape[0]): + loss_distribution_kl, \ + fre_loss_temp, \ + train_mae_temp = TorchLossComputer.cross_entropy_power_spectrum_DLDL_softmax2( + rPPG[bb], 
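# Annotation (not in the patch): arguments passed to cross_entropy_power_spectrum_DLDL_softmax2 above/below --
# rPPG[bb] is the predicted rPPG waveform for sample bb; hr[bb] is the ground-truth heart rate in bpm
# computed from the label via get_hr(); self.frame_rate is the video sampling rate (config.TRAIN.DATA.FS);
# std=1.0 is presumably the spread of the Gaussian label distribution used by this DLDL-style spectral loss.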
+ hr[bb], + self.frame_rate, + std=1.0 + ) + fre_loss = fre_loss+fre_loss_temp + kl_loss = kl_loss+loss_distribution_kl + train_mae = train_mae+train_mae_temp + fre_loss /= data.shape[0] + kl_loss /= data.shape[0] + train_mae /= data.shape[0] + + if epoch>10: + a = 0.05 + b = 5.0 + else: + a = a_start + # exp ascend + b = b_start*math.pow(exp_b, epoch/10.0) + + loss = a*loss_rPPG + b*(fre_loss+kl_loss) + loss.backward() + self.optimizer.step() + + n = data.size(0) + loss_rPPG_avg.append(float(loss_rPPG.data)) + loss_peak_avg.append(float(fre_loss.data)) + loss_kl_avg_test.append(float(kl_loss.data)) + loss_hr_mae.append(float(train_mae)) + if idx % 100 == 99: # print every 100 mini-batches + print(f'\nepoch:{epoch}, batch:{idx + 1}, total:{len(data_loader["train"]) // self.batch_size}, ' + f'lr:0.0001, sharp:{gra_sharp:.3f}, a:{a:.3f}, NegPearson:{np.mean(loss_rPPG_avg[-2000:]):.4f}, ' + f'\nb:{b:.3f}, kl:{np.mean(loss_kl_avg_test[-2000:]):.3f}, fre_CEloss:{np.mean(loss_peak_avg[-2000:]):.3f}, ' + f'hr_mae:{np.mean(loss_hr_mae[-2000:]):.3f}') + + # Append the current learning rate to the list + lrs.append(self.scheduler.get_last_lr()) + # Append the mean training loss for the epoch + mean_training_losses.append(np.mean(loss_rPPG_avg)) + self.save_model(epoch) + self.scheduler.step() + self.model.eval() + + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + mean_valid_losses.append(valid_loss) + print(f'Validation RMSE:{valid_loss:.3f}, batch:{idx+1}') + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format( + self.best_epoch, self.min_valid_loss)) + if self.config.TRAIN.PLOT_LOSSES_AND_LR: + self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) + + def valid(self, data_loader): + """ Runs the model on valid sets.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print(" ====Validating===") + self.optimizer.zero_grad() + with torch.no_grad(): + hrs = [] + vbar = tqdm(data_loader["valid"], ncols=80) + for val_idx, val_batch in enumerate(vbar): + data, label = val_batch[0].float().to(self.device), val_batch[1].float().to(self.device) + gra_sharp = 2.0 + rPPG, _, _, _ = self.model(data, gra_sharp) + rPPG = (rPPG-torch.mean(rPPG, axis=-1).view(-1, 1))/torch.std(rPPG).view(-1, 1) + for _1, _2 in zip(rPPG, label): + hrs.append((self.get_hr(_1.cpu().detach().numpy()), self.get_hr(_2.cpu().detach().numpy()))) + RMSE = np.mean([(i-j)**2 for i, j in hrs])**0.5 + return RMSE + + def test(self, data_loader): + """ Runs the model on test sets.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + + # Change chunk length to be test chunk length + self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH + + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH)) + print("Testing uses pretrained model!") + print(self.config.INFERENCE.MODEL_PATH) + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path)) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path)) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data, label = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + gra_sharp = 2.0 + pred_ppg_test, _, _, _ = self.model(data, gra_sharp) + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx] + labels[subj_index][sort_index] = label[idx] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) + + # HR calculation based on ground truth label + def get_hr(self, y, sr=30, min=30, max=180): + p, q = welch(y, sr, nfft=1e5/sr, nperseg=np.min((len(y)-1, 256))) + return p[(p>min/60)&(pmin/60)&(p 0: + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + if config.TOOLBOX_MODE == "train_and_test": + self.num_train_batches = len(data_loader["train"]) + self.criterion_Pearson = Neg_Pearson() + self.optimizer = optim.Adam( + self.model.parameters(), lr=config.TRAIN.LR, weight_decay = 0.0005) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif config.TOOLBOX_MODE == "only_test": + self.criterion_Pearson_test = Neg_Pearson() + pass + else: + raise ValueError("PhysNet trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + self.model.train() + loss_rPPG_avg = [] + running_loss = 0.0 + # Model Training + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % 
epoch) + data, labels = batch[0].float(), batch[1].float() + N, D, C, H, W = data.shape + + data = data.to(self.device) + labels = labels.to(self.device) + + self.optimizer.zero_grad() + pred_ppg = self.model(data) + + pred_ppg = (pred_ppg-torch.mean(pred_ppg, axis=-1).view(-1, 1))/torch.std(pred_ppg, axis=-1).view(-1, 1) # normalize + + labels = (labels - torch.mean(labels)) / \ + torch.std(labels) + loss = self.criterion_Pearson(pred_ppg, labels) + + loss.backward() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + self.optimizer.step() + self.scheduler.step() + tbar.set_postfix(loss=loss.item()) + + self.save_model(epoch) + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + torch.cuda.empty_cache() + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss)) + + def valid(self, data_loader): + """ Runs the model on valid sets.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print(" ====Validing===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + BVP_label = valid_batch[1].to( + torch.float32).to(self.device) + rPPG = self.model( + valid_batch[0].to(torch.float32).to(self.device)) + rPPG = (rPPG - torch.mean(rPPG)) / torch.std(rPPG) # normalize + BVP_label = (BVP_label - torch.mean(BVP_label)) / torch.std(BVP_label) # normalize + loss_ecg = self.criterion_Pearson(rPPG, BVP_label) + valid_loss.append(loss_ecg.item()) + valid_step += 1 + vbar.set_postfix(loss=loss_ecg.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Runs the model on test sets.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=torch.device("cpu"))) + print("Testing uses pretrained model!") + print(self.config.INFERENCE.MODEL_PATH) + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=torch.device("cpu"))) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device("cpu"))) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data, label = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + pred_ppg_test = self.model(data) + + if self.config.TEST.OUTPUT_SAVE_DIR: + label = label.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx] + labels[subj_index][sort_index] = label[idx] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) + + # HR calculation based on ground truth label + def get_hr(self, y, sr=30, min=30, max=180): + p, q = welch(y, sr, nfft=1e5/sr, nperseg=np.min((len(y)-1, 256))) + return p[(p>min/60)&(pmin/60)&(p 0: + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) + + if config.TOOLBOX_MODE == "train_and_test": + self.num_train_batches = len(data_loader["train"]) + self.criterion_Pearson = Neg_Pearson() + self.optimizer = optim.Adam( + self.model.parameters(), lr=config.TRAIN.LR, weight_decay = 0.0005) + # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html + self.scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, max_lr=config.TRAIN.LR, epochs=config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) + elif config.TOOLBOX_MODE == "only_test": + self.criterion_Pearson_test = Neg_Pearson() + pass + else: + raise ValueError("PhysNet trainer initialized in incorrect toolbox mode!") + + def train(self, data_loader): + """Training routine for model""" + if data_loader["train"] is None: + raise ValueError("No data for train") + + for epoch in range(self.max_epoch_num): + print('') + print(f"====Training Epoch: {epoch}====") + self.model.train() + loss_rPPG_avg = [] 
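# --- Annotation (not part of this patch): the training loop below z-normalizes the predicted and
# ground-truth waveforms and scores them with Neg_Pearson (imported from neural_methods.loss).
# A standalone sketch of that loss idea, for reference only (not the toolbox's implementation):
import torch

def neg_pearson_sketch(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """1 - Pearson correlation per waveform, averaged over a batch shaped [N, T]."""
    preds = (preds - preds.mean(dim=-1, keepdim=True)) / preds.std(dim=-1, keepdim=True)
    labels = (labels - labels.mean(dim=-1, keepdim=True)) / labels.std(dim=-1, keepdim=True)
    r = (preds * labels).mean(dim=-1)   # approximate per-sample correlation of the normalized signals
    return (1.0 - r).mean()             # highly correlated signals give a loss near 0

# e.g. neg_pearson_sketch(torch.randn(4, 128), torch.randn(4, 128)) is near 1 for uncorrelated noise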
+ running_loss = 0.0 + # Model Training + tbar = tqdm(data_loader["train"], ncols=80) + for idx, batch in enumerate(tbar): + tbar.set_description("Train epoch %s" % epoch) + data, labels = batch[0].float(), batch[1].float() + N, D, C, H, W = data.shape + + data = data.to(self.device) + labels = labels.to(self.device) + + self.optimizer.zero_grad() + pred_ppg = self.model(data) + + pred_ppg = (pred_ppg-torch.mean(pred_ppg, axis=-1).view(-1, 1))/torch.std(pred_ppg, axis=-1).view(-1, 1) # normalize + + labels = (labels - torch.mean(labels)) / \ + torch.std(labels) + loss = self.criterion_Pearson(pred_ppg, labels) + + loss.backward() + running_loss += loss.item() + if idx % 100 == 99: # print every 100 mini-batches + print( + f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') + running_loss = 0.0 + self.optimizer.step() + self.scheduler.step() + tbar.set_postfix(loss=loss.item()) + + self.save_model(epoch) + if not self.config.TEST.USE_LAST_EPOCH: + valid_loss = self.valid(data_loader) + print('validation loss: ', valid_loss) + if self.min_valid_loss is None: + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + elif (valid_loss < self.min_valid_loss): + self.min_valid_loss = valid_loss + self.best_epoch = epoch + print("Update best model! Best epoch: {}".format(self.best_epoch)) + torch.cuda.empty_cache() + if not self.config.TEST.USE_LAST_EPOCH: + print("best trained epoch: {}, min_val_loss: {}".format(self.best_epoch, self.min_valid_loss)) + + def valid(self, data_loader): + """ Runs the model on valid sets.""" + if data_loader["valid"] is None: + raise ValueError("No data for valid") + + print('') + print(" ====Validing===") + valid_loss = [] + self.model.eval() + valid_step = 0 + with torch.no_grad(): + vbar = tqdm(data_loader["valid"], ncols=80) + for valid_idx, valid_batch in enumerate(vbar): + vbar.set_description("Validation") + BVP_label = valid_batch[1].to( + torch.float32).to(self.device) + rPPG = self.model( + valid_batch[0].to(torch.float32).to(self.device)) + rPPG = (rPPG - torch.mean(rPPG)) / torch.std(rPPG) # normalize + BVP_label = (BVP_label - torch.mean(BVP_label)) / torch.std(BVP_label) # normalize + loss_ecg = self.criterion_Pearson(rPPG, BVP_label) + valid_loss.append(loss_ecg.item()) + valid_step += 1 + vbar.set_postfix(loss=loss_ecg.item()) + valid_loss = np.asarray(valid_loss) + return np.mean(valid_loss) + + def test(self, data_loader): + """ Runs the model on test sets.""" + if data_loader["test"] is None: + raise ValueError("No data for test") + + print('') + print("===Testing===") + predictions = dict() + labels = dict() + + if self.config.TOOLBOX_MODE == "only_test": + if not os.path.exists(self.config.INFERENCE.MODEL_PATH): + raise ValueError("Inference model path error! 
Please check INFERENCE.MODEL_PATH in your yaml.") + self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH)) + print("Testing uses pretrained model!") + print(self.config.INFERENCE.MODEL_PATH) + else: + if self.config.TEST.USE_LAST_EPOCH: + last_epoch_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') + print("Testing uses last epoch as non-pretrained model!") + print(last_epoch_model_path) + self.model.load_state_dict(torch.load(last_epoch_model_path)) + else: + best_model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') + print("Testing uses best epoch selected using model selection as non-pretrained model!") + print(best_model_path) + self.model.load_state_dict(torch.load(best_model_path)) + + self.model = self.model.to(self.config.DEVICE) + self.model.eval() + print("Running model evaluation on the testing dataset!") + with torch.no_grad(): + for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): + batch_size = test_batch[0].shape[0] + data, label = test_batch[0].to( + self.config.DEVICE), test_batch[1].to(self.config.DEVICE) + pred_ppg_test = self.model(data) + + if self.config.TEST.OUTPUT_SAVE_DIR: + label = label.cpu() + pred_ppg_test = pred_ppg_test.cpu() + + for idx in range(batch_size): + subj_index = test_batch[2][idx] + sort_index = int(test_batch[3][idx]) + if subj_index not in predictions.keys(): + predictions[subj_index] = dict() + labels[subj_index] = dict() + predictions[subj_index][sort_index] = pred_ppg_test[idx] + labels[subj_index][sort_index] = label[idx] + + print('') + calculate_metrics(predictions, labels, self.config) + if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs + self.save_test_outputs(predictions, labels, self.config) + + def save_model(self, index): + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + model_path = os.path.join( + self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') + torch.save(self.model.state_dict(), model_path) + print('Saved Model Path: ', model_path) + + # HR calculation based on ground truth label + def get_hr(self, y, sr=30, min=30, max=180): + p, q = welch(y, sr, nfft=1e5/sr, nperseg=np.min((len(y)-1, 256))) + return p[(p>min/60)&(pmin/60)&(p
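Note: the get_hr helper that closes the PhysFormer and PhysNet trainers above appears truncated; the "<" comparisons in its return expression look like they were stripped during extraction. A self-contained reconstruction consistent with the surviving welch(...) call, with min/max renamed to min_hr/max_hr only to avoid shadowing Python builtins, would be roughly:

import numpy as np
from scipy.signal import welch

def get_hr(y, sr=30, min_hr=30, max_hr=180):
    """Estimate heart rate (bpm) as the dominant Welch-PSD peak within [min_hr, max_hr]."""
    # Mirrors the surviving call: welch(y, sr, nfft=1e5/sr, nperseg=min(len(y)-1, 256))
    p, q = welch(y, sr, nfft=int(1e5 / sr), nperseg=np.min((len(y) - 1, 256)))
    band = (p > min_hr / 60) & (p < max_hr / 60)   # keep only plausible HR frequencies (Hz)
    return p[band][np.argmax(q[band])] * 60        # peak frequency in Hz -> beats per minute

# Example: a 10 s signal pulsing at 1.2 Hz should report roughly 72 bpm
t = np.arange(0, 10, 1 / 30)
print(get_hr(np.sin(2 * np.pi * 1.2 * t)))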