R-Detect / feature_ref_loader.py
songyiliao's picture
feat: initial commit (#1)
1244519 verified
raw
history blame contribute delete
564 Bytes
import os

import numpy as np
import torch

from utils import get_device, config
# Target device resolved once, at import time, by the project helper
# `get_device` (imported from utils). feature_ref_loader uses it as the
# device that loaded reference features are moved onto.
DEVICE = get_device()
def feature_ref_loader(feature_ref_file_name, num_ref=5000, device=None):
    """Load a saved feature tensor and return a random subset of its rows.

    The file is read with ``torch.load`` and mapped onto ``device``. Rows
    (entries along dim 0) are shuffled with ``np.random.permutation``, and the
    first ``num_ref`` of them are kept. For a non-negative ``num_ref`` this
    samples ``min(num_ref, N)`` rows without replacement, where ``N`` is the
    size of dim 0. If the file holds fewer than ``num_ref`` rows, all of them
    are returned in shuffled order. Because the permutation comes from NumPy's
    global RNG, ``np.random.seed`` makes the sample reproducible.

    Args:
        feature_ref_file_name: Path of a file written by ``torch.save``.
            Presumably it holds a single tensor, since ``.to`` and ``.shape``
            are used on the loaded object (TODO confirm for other callers).
        num_ref: Maximum number of rows to keep. Defaults to 5000.
        device: Device to load the tensor onto. ``None`` (the default) uses
            the module-level ``DEVICE``.

    Returns:
        A tensor on ``device`` containing the sampled rows.
    """
    if device is None:
        device = DEVICE
    print("Feature Ref Loader load: ", feature_ref_file_name)
    # NOTE(review): torch.load unpickles its input (fully so on older torch
    # versions without weights_only-by-default); only load trusted files.
    load_ref_data = torch.load(feature_ref_file_name, map_location=device)
    # map_location already places tensors on `device`. This call is a cheap
    # no-op in that case and is kept only as a safeguard.
    load_ref_data = load_ref_data.to(device)
    # Truncate the *indices* before gathering. This selects exactly the same
    # rows as `data[perm][:num_ref]`, but copies only num_ref rows instead of
    # materialising a shuffled copy of the whole tensor first.
    ref_indices = np.random.permutation(load_ref_data.shape[0])[:num_ref]
    return load_ref_data[ref_indices]
# Reference feature bank for the two-sample tester. It is loaded once, at
# import time, so importing this module requires the .pt file to be present.
# The CWD-relative path is tried first (the original behaviour). Only if it
# does not exist is a copy next to this module used, so that imports from
# other working directories can still succeed when that copy is present.
_FEATURE_REF_FILE = "./feature_ref_for_test.pt"
if not os.path.exists(_FEATURE_REF_FILE):
    _module_local_ref = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "feature_ref_for_test.pt"
    )
    if os.path.exists(_module_local_ref):
        _FEATURE_REF_FILE = _module_local_ref
feature_two_sample_tester_ref = feature_ref_loader(_FEATURE_REF_FILE)