juwaeze commited on
Commit
e149e7f
·
verified ·
1 Parent(s): 2959999

Upload 13 files

Browse files
gca2.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stylegan2 import Generator, Encoder
2
+ from torch import nn, autograd, optim
3
+ import pandas as pd
4
+ from tqdm import tqdm
5
+ import torch
6
+ import cv2
7
+ import os
8
+ import random
9
+ from torchvision import transforms
10
+ from torchvision import utils
11
+ import numpy as np
12
+ from sklearn.svm import SVC
13
+ from sklearn.model_selection import train_test_split
14
+ from sklearn.metrics import classification_report, accuracy_score
15
+ from sklearn.pipeline import make_pipeline
16
+ from sklearn.svm import LinearSVC
17
+
18
+ def accumulate(model1, model2, decay=0.999):
19
+ par1 = dict(model1.named_parameters())
20
+ par2 = dict(model2.named_parameters())
21
+
22
+ for k in par1.keys():
23
+ par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
24
+ self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage) # load model checkpoint
25
+
26
+ class GCA():
27
+ def __init__(self, distributed=False, h_path = None):
28
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ print(f"Using device: {self.device}")
30
+ self.distributed = distributed
31
+ self.h_path = h_path # path to sex and age hyperplanes
32
+ self.size, self.n_mlp, self.channel_multiplier, self.cgan = 256, 8, 2, True
33
+ self.classifier_nof_classes, self.embedding_size, self.latent = 2, 10, 512
34
+ self.g_reg_every, self.lr, self.ckpt = 4, 0.002, 'results/000500.pt'
35
+ # load model checkpoints
36
+ self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage)
37
+ self.generator = Generator(self.size, self.latent, self.n_mlp, channel_multiplier=self.channel_multiplier,
38
+ conditional_gan=self.cgan, nof_classes=self.classifier_nof_classes,
39
+ embedding_size=self.embedding_size).to(self.device)
40
+ self.encoder = Encoder(self.size, channel_multiplier=self.channel_multiplier, output_channels=self.latent).to(self.device)
41
+ self.generator.load_state_dict(self.ckpt["g"]); self.encoder.load_state_dict(self.ckpt["e"]) # load checkpoints
42
+ if self.distributed: # use multiple gpus
43
+ local_rank = int(os.environ["LOCAL_RANK"])
44
+ self.generator = nn.parallel.DistributedDataParallel(
45
+ generator,
46
+ device_ids=[local_rank],
47
+ output_device=local_rank,
48
+ broadcast_buffers=False,
49
+ )
50
+ self.encoder = nn.parallel.DistributedDataParallel(
51
+ encoder,
52
+ device_ids=[local_rank],
53
+ output_device=local_rank,
54
+ broadcast_buffers=False,
55
+ )
56
+
57
+ self.transform = transforms.Compose(
58
+ [
59
+ transforms.ToTensor(),
60
+ transforms.Resize((256,256)),
61
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
62
+ ]
63
+ )
64
+ # Get SVM coefficients
65
+ self.sex_coeff, self.age_coeff = None, None
66
+ self.__get_hyperplanes__()
67
+ self.w_shape = None
68
+
69
+
70
+ def __load_image__(self, path):
71
+ img = cv2.imread(path) # Load image using cv2
72
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert to RGB
73
+ img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device) # Preprocess
74
+ return img_tensor
75
+
76
+ def __process_in_batches__(self, patients, batch_size):
77
+ style_vectors = []
78
+ for i in range(0, len(patients), batch_size):
79
+ batch_paths = patients.iloc[i : i + batch_size]["Path"].tolist()
80
+ batch_imgs = [self.__load_image__(path) for path in batch_paths]
81
+ batch_imgs_tensor = torch.cat(batch_imgs, dim=0) # Stack images in a batch
82
+ with torch.no_grad(): # Avoid tracking gradients to save memory
83
+ # Encode batch to latent vectors in Z space
84
+ w_latents = self.encoder(batch_imgs_tensor)
85
+ # Move to CPU to save memory and add to list
86
+ style_vectors.extend(w_latents.cpu())
87
+ del batch_imgs_tensor, w_latents # Cleanup and clear cache
88
+ torch.cuda.empty_cache() # Clear cache to free memory
89
+ return style_vectors
90
+
91
+ def __load_cxr_data__(self, df):
92
+ return self.__process_in_batches__(df, batch_size=16)
93
+
94
+ def __get_patient_data__(self, rsna_csv="../datasets/rsna_patients.csv", cxpt_csv="../chexpert/versions/1/train.csv"):
95
+ if os.path.exists(rsna_csv) and os.path.exists(cxpt_csv):
96
+ n_patients = 500
97
+ rsna_csv = pd.DataFrame(pd.read_csv(rsna_csv))
98
+ cxpt_csv = pd.DataFrame(pd.read_csv(cxpt_csv))
99
+ rsna_csv["Image Index"] = "../datasets/rsna/" + rsna_csv["Image Index"] # add prefix to path
100
+ rsna_csv.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)
101
+
102
+ # Load 500 latent vectors from each class
103
+ male = rsna_csv[rsna_csv["Sex"] == "M"][:500]
104
+ female = rsna_csv[rsna_csv["Sex"] == "F"][:500]
105
+ young = rsna_csv[rsna_csv["Age"] < 20][:500]
106
+ rsna = rsna_csv[rsna_csv["Age"] > 80][:250]
107
+ cxpt = cxpt_csv[cxpt_csv["Age"] > 80][:250]
108
+ old = pd.concat([rsna, cxpt], ignore_index=True)
109
+ return {"m": male, "f": female, "y": young, "o": old}
110
+ elif os.path.exists(rsna_csv):
111
+ n_patients = 500
112
+ rsna_csv = pd.DataFrame(pd.read_csv(rsna_csv))
113
+ rsna_csv["Image Index"] = "../datasets/rsna/" + rsna_csv["Image Index"] # add prefix to path
114
+ rsna_csv.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)
115
+
116
+ # Load 500 latent vectors from each class
117
+ male = rsna_csv[rsna_csv["Sex"] == "M"][:500]
118
+ female = rsna_csv[rsna_csv["Sex"] == "F"][:500]
119
+ young = rsna_csv[rsna_csv["Age"] < 20][:500]
120
+ old = rsna_csv[rsna_csv["Age"] > 80][:250]
121
+ return {"m": male, "f": female, "y": young, "o": old}
122
+ else:
123
+ print(f"The path '{path}' does not exist.")
124
+ return None
125
+
126
+ def __learn_linearSVM__(self, d1, d2, df1, df2, key="Sex"):
127
+ # prepare dataset
128
+ styles, labels = [], []
129
+ styles.extend(d1); labels.extend(list(df1["Sex"]))
130
+ styles.extend(d2); labels.extend(list(df2["Sex"]))
131
+ # Convert to NumPy arrays for sklearn compatibility
132
+ styles = np.array([style.numpy().flatten() for style in styles])
133
+ # styles = torch.stack(styles)
134
+ labels = np.array(labels)
135
+ # Shuffle dataset with the same seed
136
+ seed = 42
137
+ random.seed(seed)
138
+ np.random.seed(seed)
139
+ # Shuffle styles and labels together
140
+ indices = np.arange(len(styles))
141
+ np.random.shuffle(indices)
142
+ styles, labels = styles[indices], labels[indices]
143
+ self.w_shape = styles[0].shape # save style vector
144
+ # Split dataset into train and test sets (80/20 split)
145
+ X_train, X_test, y_train, y_test = train_test_split(styles, labels, test_size=0.2, random_state=seed)
146
+ # Initialize and train linear SVM
147
+ clf = make_pipeline(LinearSVC(random_state=0, tol=1e-5))
148
+ clf.fit(X_train, y_train)
149
+ # Predict on the test set
150
+ y_pred = clf.predict(X_test)
151
+ return clf
152
+
153
+ def __get_hyperplanes__(self):
154
+ if os.path.exists(self.h_path):
155
+ hyperplanes = torch.load(self.h_path)
156
+ self.sex_coeff, self.age_coeff = hyperplanes[:512], hyperplanes[512:]
157
+ else:
158
+ patient_data = self.__get_patient_data__()
159
+ image_data = {}
160
+ for key in tqdm(patient_data):
161
+ image_data[key] = self.__load_cxr_data__(patient_data[key])
162
+ sex = self.__learn_linearSVM__(image_data["m"], image_data["f"], patient_data["m"], patient_data["f"]).named_steps['linearsvc'].coef_[0].reshape((self.w_shape))
163
+ age = self.__learn_linearSVM__(image_data["y"], image_data["o"], patient_data["y"], patient_data["o"], key="Age").named_steps['linearsvc'].coef_[0].reshape((self.w_shape))
164
+ self.sex_coeff = (torch.from_numpy(sex).float()).to(self.device)
165
+ self.age_coeff = (torch.from_numpy(age).float()).to(self.device)
166
+ torch.save(torch.cat([self.sex_coeff, self.age_coeff], dim=0), "hyperplanes.pt") # save for next time
167
+ print("Sex and Age coefficient loaded!")
168
+
169
+ def __age__(self, w, step_size = -2, magnitude=1):
170
+ alpha = step_size * magnitude
171
+ # v = self.age_coeff.named_steps['linearsvc'].coef_[0].reshape((self.w_shape)) # get coefficients from hyperplane
172
+ # v = (torch.from_numpy(v).float()).to(self.device)
173
+ return w + alpha * self.age_coeff
174
+
175
+ def __sex__(self, w, step_size = 1, magnitude=1):
176
+ alpha = step_size * magnitude
177
+ # v = self.age_coeff.named_steps['linearsvc'].coef_[0].reshape((self.w_shape)) # get coefficients from hyperplane
178
+ # v = (torch.from_numpy(v).float()).to(self.device)
179
+ return w + alpha * self.sex_coeff
180
+
181
+ def augment_helper(self, embedding, rate=0.8): # p = augmentation rate
182
+ # sex, age = gca.sex_coeff.predict(embedding.clone().detach().cpu().numpy())[0],\
183
+ # gca.age_coeff.predict(embedding.clone().detach().cpu().numpy())[0]
184
+ np.random.seed(None); random.seed(None)
185
+ if np.random.choice([True, False], p=[rate, 1-rate]): # random 80% chance of augmentation
186
+ w_ = self.__sex__(embedding, magnitude=random.randint(-4,4))
187
+ w_ = self.__age__(w_, magnitude=random.randint(-2,2))
188
+ # if sex == "M":
189
+ # w_ = self.__sex__(embedding, magnitude=random.randint(-4,1))
190
+ # else:
191
+ # w_ = self.__sex__(embedding, magnitude=random.randint(-1,4))
192
+ # if age == "0-20":
193
+ # w_ = self.__age__(w_, magnitude=random.randint(-1,4))
194
+ # else:
195
+ # w_ = self.__age__(w_, magnitude=random.randint(-4,1))
196
+ synth, _ = self.generator([w_], input_is_latent=True) # reconstruct image
197
+ utils.save_image(synth, "real_samples_agesex.png", nrow=int(1 ** 2), normalize=True)
198
+ return synth
199
+ # synth, _ = self.generator([embedding], input_is_latent=True) # reconstruct image
200
+ return None
201
+
202
+ def augment(self, x, rate=0.8):
203
+ x = torch.unsqueeze(self.transform(x), 0).to(self.device)
204
+ embedding = self.encoder(x) # sample patient
205
+ aug_x = self.augment_helper(embedding, rate)
206
+ if aug_x is not None:
207
+ # convert to (none, 224, 224, 3) numpy array
208
+ im = utils.make_grid(aug_x)
209
+ # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
210
+ return im.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
211
+ im = utils.make_grid(x)
212
+ # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
213
+ return im.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
214
+
215
+ if __name__ == "__main__":
216
+ # initialize GCA
217
+ gca = GCA(h_path="hyperplanes.pt")
218
+ # load image
219
+ img = cv2.imread("../datasets/rsna/00000007_000.png")
220
+ gca.augment(img)
221
+
222
+
223
+ # save or return image embedding
op/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
op/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (266 Bytes). View file
 
op/__pycache__/conv2d_gradfix.cpython-38.pyc ADDED
Binary file (5.3 kB). View file
 
op/__pycache__/fused_act.cpython-38.pyc ADDED
Binary file (3.28 kB). View file
 
op/__pycache__/upfirdn2d.cpython-38.pyc ADDED
Binary file (4.34 kB). View file
 
op/conv2d_gradfix.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+
86
+ try:
87
+ parts = torch.__version__.split('.')
88
+ major = int(parts[0])
89
+ minor = int(parts[1])
90
+ if major >= 1 or (major == 1 and minor >= 7):
91
+ return True
92
+ except (ValueError):
93
+ pass
94
+ warnings.warn(
95
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
96
+ )
97
+
98
+ return False
99
+
100
+
101
+ def ensure_tuple(xs, ndim):
102
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
103
+
104
+ return xs
105
+
106
+
107
+ conv2d_gradfix_cache = dict()
108
+
109
+
110
+ def conv2d_gradfix(
111
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
112
+ ):
113
+ ndim = 2
114
+ weight_shape = tuple(weight_shape)
115
+ stride = ensure_tuple(stride, ndim)
116
+ padding = ensure_tuple(padding, ndim)
117
+ output_padding = ensure_tuple(output_padding, ndim)
118
+ dilation = ensure_tuple(dilation, ndim)
119
+
120
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
121
+ if key in conv2d_gradfix_cache:
122
+ return conv2d_gradfix_cache[key]
123
+
124
+ common_kwargs = dict(
125
+ stride=stride, padding=padding, dilation=dilation, groups=groups
126
+ )
127
+
128
+ def calc_output_padding(input_shape, output_shape):
129
+ if transpose:
130
+ return [0, 0]
131
+
132
+ return [
133
+ input_shape[i + 2]
134
+ - (output_shape[i + 2] - 1) * stride[i]
135
+ - (1 - 2 * padding[i])
136
+ - dilation[i] * (weight_shape[i + 2] - 1)
137
+ for i in range(ndim)
138
+ ]
139
+
140
+ class Conv2d(autograd.Function):
141
+ @staticmethod
142
+ def forward(ctx, input, weight, bias):
143
+ if not transpose:
144
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
145
+
146
+ else:
147
+ out = F.conv_transpose2d(
148
+ input=input,
149
+ weight=weight,
150
+ bias=bias,
151
+ output_padding=output_padding,
152
+ **common_kwargs,
153
+ )
154
+
155
+ ctx.save_for_backward(input, weight)
156
+
157
+ return out
158
+
159
+ @staticmethod
160
+ def backward(ctx, grad_output):
161
+ input, weight = ctx.saved_tensors
162
+ grad_input, grad_weight, grad_bias = None, None, None
163
+
164
+ if ctx.needs_input_grad[0]:
165
+ p = calc_output_padding(
166
+ input_shape=input.shape, output_shape=grad_output.shape
167
+ )
168
+ grad_input = conv2d_gradfix(
169
+ transpose=(not transpose),
170
+ weight_shape=weight_shape,
171
+ output_padding=p,
172
+ **common_kwargs,
173
+ ).apply(grad_output, weight, None)
174
+
175
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
176
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
177
+
178
+ if ctx.needs_input_grad[2]:
179
+ grad_bias = grad_output.sum((0, 2, 3))
180
+
181
+ return grad_input, grad_weight, grad_bias
182
+
183
+ class Conv2dGradWeight(autograd.Function):
184
+ @staticmethod
185
+ def forward(ctx, grad_output, input):
186
+ op = torch._C._jit_get_operation(
187
+ "aten::cudnn_convolution_backward_weight"
188
+ if not transpose
189
+ else "aten::cudnn_convolution_transpose_backward_weight"
190
+ )
191
+ flags = [
192
+ torch.backends.cudnn.benchmark,
193
+ torch.backends.cudnn.deterministic,
194
+ torch.backends.cudnn.allow_tf32,
195
+ ]
196
+ grad_weight = op(
197
+ weight_shape,
198
+ grad_output,
199
+ input,
200
+ padding,
201
+ stride,
202
+ dilation,
203
+ groups,
204
+ *flags,
205
+ )
206
+ ctx.save_for_backward(grad_output, input)
207
+
208
+ return grad_weight
209
+
210
+ @staticmethod
211
+ def backward(ctx, grad_grad_weight):
212
+ grad_output, input = ctx.saved_tensors
213
+ grad_grad_output, grad_grad_input = None, None
214
+
215
+ if ctx.needs_input_grad[0]:
216
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
217
+
218
+ if ctx.needs_input_grad[1]:
219
+ p = calc_output_padding(
220
+ input_shape=input.shape, output_shape=grad_output.shape
221
+ )
222
+ grad_grad_input = conv2d_gradfix(
223
+ transpose=(not transpose),
224
+ weight_shape=weight_shape,
225
+ output_padding=p,
226
+ **common_kwargs,
227
+ ).apply(grad_output, grad_grad_weight, None)
228
+
229
+ return grad_grad_output, grad_grad_input
230
+
231
+ conv2d_gradfix_cache[key] = Conv2d
232
+
233
+ return Conv2d
op/fused_act.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ if bias:
39
+ grad_bias = grad_input.sum(dim).detach()
40
+
41
+ else:
42
+ grad_bias = empty
43
+
44
+ return grad_input, grad_bias
45
+
46
+ @staticmethod
47
+ def backward(ctx, gradgrad_input, gradgrad_bias):
48
+ out, = ctx.saved_tensors
49
+ gradgrad_out = fused.fused_bias_act(
50
+ gradgrad_input.contiguous(),
51
+ gradgrad_bias,
52
+ out,
53
+ 3,
54
+ 1,
55
+ ctx.negative_slope,
56
+ ctx.scale,
57
+ )
58
+
59
+ return gradgrad_out, None, None, None, None
60
+
61
+
62
+ class FusedLeakyReLUFunction(Function):
63
+ @staticmethod
64
+ def forward(ctx, input, bias, negative_slope, scale):
65
+ empty = input.new_empty(0)
66
+
67
+ ctx.bias = bias is not None
68
+
69
+ if bias is None:
70
+ bias = empty
71
+
72
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73
+ ctx.save_for_backward(out)
74
+ ctx.negative_slope = negative_slope
75
+ ctx.scale = scale
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def backward(ctx, grad_output):
81
+ out, = ctx.saved_tensors
82
+
83
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
+ )
86
+
87
+ if not ctx.bias:
88
+ grad_bias = None
89
+
90
+ return grad_input, grad_bias, None, None
91
+
92
+
93
+ class FusedLeakyReLU(nn.Module):
94
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
+ super().__init__()
96
+
97
+ if bias:
98
+ self.bias = nn.Parameter(torch.zeros(channel))
99
+
100
+ else:
101
+ self.bias = None
102
+
103
+ self.negative_slope = negative_slope
104
+ self.scale = scale
105
+
106
+ def forward(self, input):
107
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
+
109
+
110
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
+ if input.device.type == "cpu":
112
+ if bias is not None:
113
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
+ return (
115
+ F.leaky_relu(
116
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
+ )
118
+ * scale
119
+ )
120
+
121
+ else:
122
+ return F.leaky_relu(input, negative_slope=0.2) * scale
123
+
124
+ else:
125
+ return FusedLeakyReLUFunction.apply(
126
+ input.contiguous(), bias, negative_slope, scale
127
+ )
op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
op/upfirdn2d.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ upfirdn2d_op = load(
12
+ "upfirdn2d",
13
+ sources=[
14
+ os.path.join(module_path, "upfirdn2d.cpp"),
15
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class UpFirDn2dBackward(Function):
21
+ @staticmethod
22
+ def forward(
23
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
+ ):
25
+
26
+ up_x, up_y = up
27
+ down_x, down_y = down
28
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
+
30
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
+
32
+ grad_input = upfirdn2d_op.upfirdn2d(
33
+ grad_output,
34
+ grad_kernel,
35
+ down_x,
36
+ down_y,
37
+ up_x,
38
+ up_y,
39
+ g_pad_x0,
40
+ g_pad_x1,
41
+ g_pad_y0,
42
+ g_pad_y1,
43
+ )
44
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
+
46
+ ctx.save_for_backward(kernel)
47
+
48
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
+
50
+ ctx.up_x = up_x
51
+ ctx.up_y = up_y
52
+ ctx.down_x = down_x
53
+ ctx.down_y = down_y
54
+ ctx.pad_x0 = pad_x0
55
+ ctx.pad_x1 = pad_x1
56
+ ctx.pad_y0 = pad_y0
57
+ ctx.pad_y1 = pad_y1
58
+ ctx.in_size = in_size
59
+ ctx.out_size = out_size
60
+
61
+ return grad_input
62
+
63
+ @staticmethod
64
+ def backward(ctx, gradgrad_input):
65
+ kernel, = ctx.saved_tensors
66
+
67
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
+
69
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
70
+ gradgrad_input,
71
+ kernel,
72
+ ctx.up_x,
73
+ ctx.up_y,
74
+ ctx.down_x,
75
+ ctx.down_y,
76
+ ctx.pad_x0,
77
+ ctx.pad_x1,
78
+ ctx.pad_y0,
79
+ ctx.pad_y1,
80
+ )
81
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
+ gradgrad_out = gradgrad_out.view(
83
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
+ )
85
+
86
+ return gradgrad_out, None, None, None, None, None, None, None, None
87
+
88
+
89
+ class UpFirDn2d(Function):
90
+ @staticmethod
91
+ def forward(ctx, input, kernel, up, down, pad):
92
+ up_x, up_y = up
93
+ down_x, down_y = down
94
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
+
96
+ kernel_h, kernel_w = kernel.shape
97
+ batch, channel, in_h, in_w = input.shape
98
+ ctx.in_size = input.shape
99
+
100
+ input = input.reshape(-1, in_h, in_w, 1)
101
+
102
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
+
104
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
+ ctx.out_size = (out_h, out_w)
107
+
108
+ ctx.up = (up_x, up_y)
109
+ ctx.down = (down_x, down_y)
110
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
+
112
+ g_pad_x0 = kernel_w - pad_x0 - 1
113
+ g_pad_y0 = kernel_h - pad_y0 - 1
114
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
+
117
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
+
119
+ out = upfirdn2d_op.upfirdn2d(
120
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
+ )
122
+ # out = out.view(major, out_h, out_w, minor)
123
+ out = out.view(-1, channel, out_h, out_w)
124
+
125
+ return out
126
+
127
+ @staticmethod
128
+ def backward(ctx, grad_output):
129
+ kernel, grad_kernel = ctx.saved_tensors
130
+
131
+ grad_input = None
132
+
133
+ if ctx.needs_input_grad[0]:
134
+ grad_input = UpFirDn2dBackward.apply(
135
+ grad_output,
136
+ kernel,
137
+ grad_kernel,
138
+ ctx.up,
139
+ ctx.down,
140
+ ctx.pad,
141
+ ctx.g_pad,
142
+ ctx.in_size,
143
+ ctx.out_size,
144
+ )
145
+
146
+ return grad_input, None, None, None, None
147
+
148
+
149
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
+ if not isinstance(up, abc.Iterable):
151
+ up = (up, up)
152
+
153
+ if not isinstance(down, abc.Iterable):
154
+ down = (down, down)
155
+
156
+ if len(pad) == 2:
157
+ pad = (pad[0], pad[1], pad[0], pad[1])
158
+
159
+ if input.device.type == "cpu":
160
+ out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
+
162
+ else:
163
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
+
165
+ return out
166
+
167
+
168
+ def upfirdn2d_native(
169
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
+ ):
171
+ _, channel, in_h, in_w = input.shape
172
+ input = input.reshape(-1, in_h, in_w, 1)
173
+
174
+ _, in_h, in_w, minor = input.shape
175
+ kernel_h, kernel_w = kernel.shape
176
+
177
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
178
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
+
181
+ out = F.pad(
182
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
+ )
184
+ out = out[
185
+ :,
186
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
+ :,
189
+ ]
190
+
191
+ out = out.permute(0, 3, 1, 2)
192
+ out = out.reshape(
193
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
+ )
195
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
+ out = F.conv2d(out, w)
197
+ out = out.reshape(
198
+ -1,
199
+ minor,
200
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
+ )
203
+ out = out.permute(0, 2, 3, 1)
204
+ out = out[:, ::down_y, ::down_x, :]
205
+
206
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
+
209
+ return out.view(-1, channel, out_h, out_w)
op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }