|
|
|
|
|
from typing import Tuple, Union

import numpy as np
import scipy
import scipy.linalg
import torch
|
|
|
|
|
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
    """Return the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean and covariance (via ``compute_stats``)
    and the distance between the two resulting Gaussians is computed as

        ||mu_f - mu_r||^2 + Tr(S_f + S_r - 2 * sqrt(S_f @ S_r))

    Args:
        feats_fake: Features of the generated samples, one row per sample.
        feats_real: Features of the reference samples, one row per sample.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: If the two feature sets do not share the same feature
            dimensionality.
    """
    mu_gen, sigma_gen = compute_stats(feats_fake)
    mu_real, sigma_real = compute_stats(feats_real)

    # Fail loudly instead of letting NumPy broadcast a size-1 feature axis
    # against a larger one and silently return a meaningless number.
    if np.shape(mu_gen) != np.shape(mu_real):
        raise ValueError(
            "feature dimensionality mismatch: "
            f"{np.shape(mu_gen)} (fake) vs {np.shape(mu_real)} (real)"
        )

    mean_term = np.square(mu_gen - mu_real).sum()

    # Called without `disp`: that keyword is deprecated in recent SciPy
    # releases, and the default call returns just the matrix on every version.
    covmean = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real))

    # The square root of a non-symmetric product may carry an imaginary
    # component; only the real part of the final scalar is kept.
    fid = np.real(mean_term + np.trace(sigma_gen + sigma_real - covmean * 2))

    return float(fid)
|
|
|
|
|
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and the feature covariance of ``feats``.

    Args:
        feats: 2-D array with one row per sample and one column per feature.

    Returns:
        ``(mu, sigma)`` where ``mu`` has shape ``(D,)`` and ``sigma`` has
        shape ``(D, D)`` for ``D`` feature columns.

    Raises:
        ValueError: If ``feats`` is not 2-D or holds fewer than two samples.
    """
    feats = np.asarray(feats)
    if feats.ndim != 2:
        raise ValueError(
            f"feats must be 2-D (samples, features), got shape {feats.shape}"
        )
    if feats.shape[0] < 2:
        # With a single row np.cov divides by zero degrees of freedom and
        # returns an all-NaN matrix accompanied only by a RuntimeWarning.
        raise ValueError(
            "at least two samples are required to estimate a covariance, "
            f"got {feats.shape[0]}"
        )

    mu = feats.mean(axis=0)
    # np.cov collapses to a 0-d array when there is a single feature column;
    # promote it so callers always receive a square matrix.
    sigma = np.atleast_2d(np.cov(feats, rowvar=False))

    return mu, sigma
|
|
|
|
|
@torch.no_grad()
def compute_our_fvd(
    videos_fake: Union[np.ndarray, torch.Tensor],
    videos_real: Union[np.ndarray, torch.Tensor],
    device: str = "cuda",
    i3d_path: str = "checkpoints/auxiliary/i3d_torchscript.pt",
) -> float:
    """Compute the FVD between two video batches using a TorchScript I3D model.

    Both batches are passed through the model loaded from ``i3d_path`` and the
    resulting features are compared with ``compute_fvd``.

    Args:
        videos_fake: Generated videos as a 5-D tensor or array. The last axis
            is moved to position 1 before the model is called, so the input
            is presumably (batch, frames, height, width, channels).
        videos_real: Reference videos, same layout as ``videos_fake``.
        device: Device on which the model is run.
        i3d_path: Path to the TorchScript I3D checkpoint.

    Returns:
        The FVD score as a Python float.

    Raises:
        OSError: If the checkpoint at ``i3d_path`` cannot be opened.
    """
    # NOTE(review): rescale=False / resize=False suggest the inputs must
    # already be sized and scaled the way the checkpoint expects - confirm.
    i3d_kwargs = dict(rescale=False, resize=False, return_features=True)

    with open(i3d_path, "rb") as f:
        # map_location remaps the stored tensors straight onto the target
        # device instead of the device the checkpoint was saved from.
        i3d_model = torch.jit.load(f, map_location=device).eval().to(device)

    def extract_features(videos):
        # as_tensor lets NumPy arrays through as well as tensors; the
        # original code only worked for tensors despite its annotations.
        batch = torch.as_tensor(videos).permute(0, 4, 1, 2, 3).to(device)
        return i3d_model(batch, **i3d_kwargs).cpu().numpy()

    # Embed one batch at a time so only one of them is resident on the
    # device while the model runs.
    feats_fake = extract_features(videos_fake)
    feats_real = extract_features(videos_real)

    return compute_fvd(feats_fake, feats_real)
|
|
|
|
|
def main():
    """Smoke-test the FVD pipeline on two batches of uniform-random videos."""
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script does not crash on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    videos_fake = torch.rand(10, 16, 224, 224, 3)
    videos_real = torch.rand(10, 16, 224, 224, 3)

    our_fvd_result = compute_our_fvd(videos_fake, videos_real, device=device)
    print(f"[FVD scores] Ours: {our_fvd_result}")


if __name__ == "__main__":
    main()
|
|