|
|
|
import os
import pickle
import random
from typing import List, Optional, Tuple

import numpy as np
import torch
|
|
|
|
def seed_everything(seed):
    '''
    Seed all sources of randomness for reproducible runs.

    Covers Python's `random`, the `PYTHONHASHSEED` environment variable,
    NumPy, and PyTorch (CPU and CUDA).

    Args:
        seed: Integer seed to apply everywhere.
    '''
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Fix: seed every CUDA device, not just the current one, so multi-GPU
    # runs are reproducible as well.
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
class FeatureStats:
    '''
    Class to store statistics of features, including all features and mean/covariance.

    Storage is allocated lazily on the first append, once the feature
    dimensionality is known.

    Args:
        capture_all: Whether to store all the features.
        capture_mean_cov: Whether to store mean and covariance.
        max_items: Maximum number of items to store; None means unlimited.
    '''
    def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: Optional[int] = None):
        '''
        Initialize an empty accumulator; see the class docstring for the
        meaning of each flag.
        '''
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0          # number of feature rows recorded so far
        self.num_features = None    # set lazily by set_num_features()
        self.all_features = None    # list of float32 arrays (if capture_all)
        self.raw_mean = None        # un-normalized sum of features
        self.raw_cov = None         # un-normalized sum of outer products

    def set_num_features(self, num_features: int):
        '''
        Set the number of feature dimensions and allocate the accumulators.
        Subsequent calls must pass the same value.

        Args:
            num_features: Number of feature dimensions.
        '''
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            # float64 accumulators to limit round-off over many batches.
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self) -> bool:
        '''
        Check if the maximum number of samples is reached.

        Returns:
            True if the storage is full, False otherwise.
        '''
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x: np.ndarray):
        '''
        Add the newly computed features to the list. Update the mean and covariance.
        Rows beyond max_items are silently dropped.

        Args:
            x: New features to record, shape [num_samples, num_features].
        '''
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            # Keep only as many rows as still fit under max_items.
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            x64 = x.astype(np.float64)
            # Accumulate raw (un-normalized) first and second moments;
            # normalization happens in get_mean_cov().
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
        '''
        Add the newly computed PyTorch features to the list. Update the mean and covariance.

        With num_gpus > 1 the per-rank batches are gathered via broadcast so
        every rank records the same interleaved set of rows.
        NOTE(review): the multi-GPU path requires torch.distributed to be
        initialized by the caller.

        Args:
            x: New features to record.
            rank: Rank of the current GPU.
            num_gpus: Total number of GPUs.
        '''
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            # Interleave rows rank-by-rank so ordering is deterministic.
            x = torch.stack(ys, dim=1).flatten(0, 1)
        self.append(x.cpu().numpy())

    def get_all(self) -> np.ndarray:
        '''
        Get all the stored features as NumPy Array.

        Returns:
            Concatenation of the stored features.
        '''
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self) -> torch.Tensor:
        '''
        Get all the stored features as PyTorch Tensor.

        Returns:
            Concatenation of the stored features.
        '''
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Get the mean and covariance of the stored features.

        Returns:
            Mean and covariance of the stored features.
        '''
        assert self.capture_mean_cov
        # Guard against a silent NaN result (division by zero) when nothing
        # has been recorded yet.
        assert self.num_items > 0, 'no features recorded'
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file: str):
        '''
        Save the features and statistics to a pickle file.

        Args:
            pkl_file: Path to the pickle file.
        '''
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file: str) -> 'FeatureStats':
        '''
        Load the features and statistics from a pickle file.

        SECURITY: pickle.load can execute arbitrary code; only load trusted files.

        Args:
            pkl_file: Path to the pickle file.

        Returns:
            A FeatureStats instance restored from the file.
        '''
        with open(pkl_file, 'rb') as f:
            s = pickle.load(f)
        # Fix: forward capture_mean_cov as well (it was previously dropped and
        # only masked by the __dict__.update below).
        obj = FeatureStats(capture_all=s['capture_all'], capture_mean_cov=s['capture_mean_cov'], max_items=s['max_items'])
        obj.__dict__.update(s)
        print('Loaded %d features from %s' % (obj.num_items, pkl_file))
        return obj
|
|
|