from stylegan2 import Generator, Encoder
from torch import nn, autograd, optim
import pandas as pd
from tqdm import tqdm
import torch
import cv2
import os
import random
from torchvision import transforms
from torchvision import utils
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.pipeline import make_pipeline
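
# GCA: latent-space augmentation for chest X-rays. An encoder maps an image into the
# StyleGAN2 latent space, linear SVMs fitted on sex/age cohorts provide editing
# directions, and the generator decodes the shifted latent into an augmented image.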
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
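
# `accumulate` blends model2's parameters into model1 with an exponential moving average,
# the way StyleGAN2 training typically maintains a smoothed g_ema copy of the generator.
# A minimal usage sketch (the `g_ema` / `generator` names are assumptions, not defined here):
#
#     import copy
#     g_ema = copy.deepcopy(generator)
#     accumulate(g_ema, generator, decay=0.999)  # call after every optimizer step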
class GCA():
    def __init__(self, distributed=False, h_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.distributed = distributed
        self.h_path = h_path  # path to sex and age hyperplanes
        self.size, self.n_mlp, self.channel_multiplier, self.cgan = 256, 8, 2, True
        self.classifier_nof_classes, self.embedding_size, self.latent = 2, 10, 512
        self.g_reg_every, self.lr, self.ckpt = 4, 0.002, 'results/000500.pt'
        # load model checkpoint
        self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage)
        self.generator = Generator(self.size, self.latent, self.n_mlp, channel_multiplier=self.channel_multiplier,
                                   conditional_gan=self.cgan, nof_classes=self.classifier_nof_classes,
                                   embedding_size=self.embedding_size).to(self.device)
        self.encoder = Encoder(self.size, channel_multiplier=self.channel_multiplier, output_channels=self.latent).to(self.device)
        # load checkpoint weights
        self.generator.load_state_dict(self.ckpt["g"])
        self.encoder.load_state_dict(self.ckpt["e"])
        if self.distributed:  # wrap models for multi-GPU training
            local_rank = int(os.environ["LOCAL_RANK"])
            self.generator = nn.parallel.DistributedDataParallel(
                self.generator,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
            )
            self.encoder = nn.parallel.DistributedDataParallel(
                self.encoder,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
            )
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((256, 256)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
            ]
        )
        # Get SVM coefficients (initialize fields before fitting/loading the hyperplanes)
        self.sex_coeff, self.age_coeff = None, None
        self.w_shape = None
        self.__get_hyperplanes__()
    def __load_image__(self, path):
        img = cv2.imread(path)  # Load image using cv2
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device)  # Preprocess
        return img_tensor
    def __process_in_batches__(self, patients, batch_size):
        style_vectors = []
        for i in range(0, len(patients), batch_size):
            batch_paths = patients.iloc[i : i + batch_size]["Path"].tolist()
            batch_imgs = [self.__load_image__(path) for path in batch_paths]
            batch_imgs_tensor = torch.cat(batch_imgs, dim=0)  # Stack images in a batch
            with torch.no_grad():  # Avoid tracking gradients to save memory
                # Encode batch to latent style vectors (W space)
                w_latents = self.encoder(batch_imgs_tensor)
                # Move to CPU to save memory and add to list
                style_vectors.extend(w_latents.cpu())
            del batch_imgs_tensor, w_latents  # Cleanup
            torch.cuda.empty_cache()  # Clear cache to free memory
        return style_vectors
    def __load_cxr_data__(self, df):
        return self.__process_in_batches__(df, batch_size=16)
    def __get_patient_data__(self, rsna_csv="../datasets/rsna_patients.csv", cxpt_csv="../chexpert/versions/1/train.csv"):
        if os.path.exists(rsna_csv) and os.path.exists(cxpt_csv):
            n_patients = 500
            rsna_csv = pd.read_csv(rsna_csv)
            cxpt_csv = pd.read_csv(cxpt_csv)
            rsna_csv["Image Index"] = "../datasets/rsna/" + rsna_csv["Image Index"]  # add prefix to path
            rsna_csv.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)
            # Select 500 patients per class (old = 250 RSNA + 250 CheXpert)
            male = rsna_csv[rsna_csv["Sex"] == "M"][:n_patients]
            female = rsna_csv[rsna_csv["Sex"] == "F"][:n_patients]
            young = rsna_csv[rsna_csv["Age"] < 20][:n_patients]
            rsna = rsna_csv[rsna_csv["Age"] > 80][:n_patients // 2]
            cxpt = cxpt_csv[cxpt_csv["Age"] > 80][:n_patients // 2]
            old = pd.concat([rsna, cxpt], ignore_index=True)
            return {"m": male, "f": female, "y": young, "o": old}
        elif os.path.exists(rsna_csv):
            n_patients = 500
            rsna_csv = pd.read_csv(rsna_csv)
            rsna_csv["Image Index"] = "../datasets/rsna/" + rsna_csv["Image Index"]  # add prefix to path
            rsna_csv.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)
            # Select patients per class (RSNA only)
            male = rsna_csv[rsna_csv["Sex"] == "M"][:n_patients]
            female = rsna_csv[rsna_csv["Sex"] == "F"][:n_patients]
            young = rsna_csv[rsna_csv["Age"] < 20][:n_patients]
            old = rsna_csv[rsna_csv["Age"] > 80][:n_patients // 2]
            return {"m": male, "f": female, "y": young, "o": old}
        else:
            print(f"The path '{rsna_csv}' does not exist.")
            return None
    def __learn_linearSVM__(self, d1, d2, df1, df2, key="Sex"):
        # prepare dataset: style vectors from the two cohorts, labels taken from `key`
        styles, labels = [], []
        styles.extend(d1); labels.extend(list(df1[key]))
        styles.extend(d2); labels.extend(list(df2[key]))
        # Convert to NumPy arrays for sklearn compatibility
        styles = np.array([style.numpy().flatten() for style in styles])
        labels = np.array(labels)
        if key == "Age":
            # binarize ages so the SVM learns a young-vs-old hyperplane
            # (the cohorts are <20 and >80, so any mid-range threshold separates them)
            labels = np.where(labels < 50, "young", "old")
        # Shuffle dataset with a fixed seed
        seed = 42
        random.seed(seed)
        np.random.seed(seed)
        # Shuffle styles and labels together
        indices = np.arange(len(styles))
        np.random.shuffle(indices)
        styles, labels = styles[indices], labels[indices]
        self.w_shape = styles[0].shape  # save style vector shape
        # Split dataset into train and test sets (80/20 split)
        X_train, X_test, y_train, y_test = train_test_split(styles, labels, test_size=0.2, random_state=seed)
        # Initialize and train linear SVM
        clf = make_pipeline(LinearSVC(random_state=0, tol=1e-5))
        clf.fit(X_train, y_train)
        # Evaluate on the held-out test set
        y_pred = clf.predict(X_test)
        print(f"{key} hyperplane test accuracy: {accuracy_score(y_test, y_pred):.3f}")
        return clf
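
    # The weight vector of the fitted LinearSVC (clf.named_steps['linearsvc'].coef_[0]) is the
    # normal of the separating hyperplane in latent space; adding a multiple of it to a style
    # vector w moves the sample across the decision boundary, i.e. w' = w + alpha * v.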
    def __get_hyperplanes__(self):
        if self.h_path is not None and os.path.exists(self.h_path):
            hyperplanes = torch.load(self.h_path)
            self.sex_coeff, self.age_coeff = hyperplanes[:512], hyperplanes[512:]
        else:
            patient_data = self.__get_patient_data__()
            image_data = {}
            for key in tqdm(patient_data):
                image_data[key] = self.__load_cxr_data__(patient_data[key])
            sex = self.__learn_linearSVM__(image_data["m"], image_data["f"], patient_data["m"],
                                           patient_data["f"]).named_steps['linearsvc'].coef_[0].reshape(self.w_shape)
            age = self.__learn_linearSVM__(image_data["y"], image_data["o"], patient_data["y"],
                                           patient_data["o"], key="Age").named_steps['linearsvc'].coef_[0].reshape(self.w_shape)
            self.sex_coeff = torch.from_numpy(sex).float().to(self.device)
            self.age_coeff = torch.from_numpy(age).float().to(self.device)
            torch.save(torch.cat([self.sex_coeff, self.age_coeff], dim=0), "hyperplanes.pt")  # cache for next time
        print("Sex and Age coefficients loaded!")
    def __age__(self, w, step_size=-2, magnitude=1):
        alpha = step_size * magnitude
        return w + alpha * self.age_coeff

    def __sex__(self, w, step_size=1, magnitude=1):
        alpha = step_size * magnitude
        return w + alpha * self.sex_coeff
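
    # Both edits are latent traversals of the form w' = w + alpha * v, where v is the SVM
    # hyperplane normal and alpha = step_size * magnitude sets direction and strength.
    # A minimal sketch (assuming `w` is an encoder output and `gca` is a GCA instance):
    #
    #     w = gca.encoder(img_tensor)
    #     w_aged = gca.__age__(w, magnitude=2)   # shift along the age direction
    #     w_sexed = gca.__sex__(w, magnitude=-3)  # shift along the sex direction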
    def augment_helper(self, embedding, rate=0.8):  # rate = augmentation probability
        np.random.seed(None); random.seed(None)
        if np.random.choice([True, False], p=[rate, 1 - rate]):  # augment with probability `rate`
            w_ = self.__sex__(embedding, magnitude=random.randint(-4, 4))
            w_ = self.__age__(w_, magnitude=random.randint(-2, 2))
            # Direction-aware alternative (kept for reference): predict the sex/age of the
            # embedding and restrict the traversal range accordingly, e.g.
            #   if sex == "M":
            #       w_ = self.__sex__(embedding, magnitude=random.randint(-4, 1))
            #   else:
            #       w_ = self.__sex__(embedding, magnitude=random.randint(-1, 4))
            #   if age == "0-20":
            #       w_ = self.__age__(w_, magnitude=random.randint(-1, 4))
            #   else:
            #       w_ = self.__age__(w_, magnitude=random.randint(-4, 1))
            synth, _ = self.generator([w_], input_is_latent=True)  # reconstruct image
            utils.save_image(synth, "real_samples_agesex.png", nrow=1, normalize=True)
            return synth
        return None
    def augment(self, x, rate=0.8):
        x = torch.unsqueeze(self.transform(x), 0).to(self.device)
        embedding = self.encoder(x)  # encode patient image into latent space
        aug_x = self.augment_helper(embedding, rate)
        if aug_x is not None:
            # convert to an HWC uint8 numpy array
            im = utils.make_grid(aug_x)
            # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
            return im.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        # no augmentation applied: return the preprocessed input instead
        im = utils.make_grid(x)
        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
        return im.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
if __name__ == "__main__":
    # initialize GCA
    gca = GCA(h_path="hyperplanes.pt")
    # load an image and run one augmentation pass
    img = cv2.imread("../datasets/rsna/00000007_000.png")
    aug = gca.augment(img)
    # save or return image embedding
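    # A minimal follow-up sketch, assuming the HWC uint8 array returned by `augment`
    # can be written directly with OpenCV ("augmented.png" is an arbitrary output path):
    #
    #     cv2.imwrite("augmented.png", aug)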