Upload 35 files
- FSC_pretrain.py +380 -0
- FSC_tain.py +532 -0
- FSC_test.py +352 -0
- LICENSE +21 -0
- README.md +100 -3
- __pycache__/models_crossvit.cpython-38.pyc +0 -0
- __pycache__/models_mae_cross.cpython-38.pyc +0 -0
- __pycache__/models_mae_noct.cpython-38.pyc +0 -0
- __pycache__/models_mae_noct.cpython-39.pyc +0 -0
- biclassify.py +163 -0
- datasetmake.py +53 -0
- figure.png +0 -0
- grounding_neg.py +188 -0
- grounding_pos.py +141 -0
- models_crossvit.py +155 -0
- models_mae_cross.py +253 -0
- models_mae_noct.py +234 -0
- requirements.txt +15 -0
- util/FSC147.py +524 -0
- util/__pycache__/FSC147.cpython-38.pyc +0 -0
- util/__pycache__/FSC147.cpython-39.pyc +0 -0
- util/__pycache__/FSC147_test.cpython-38.pyc +0 -0
- util/__pycache__/lr_sched.cpython-38.pyc +0 -0
- util/__pycache__/lr_sched.cpython-39.pyc +0 -0
- util/__pycache__/misc.cpython-38.pyc +0 -0
- util/__pycache__/misc.cpython-39.pyc +0 -0
- util/__pycache__/pos_embed.cpython-38.pyc +0 -0
- util/__pycache__/pos_embed.cpython-39.pyc +0 -0
- util/crop.py +42 -0
- util/datasets.py +65 -0
- util/lars.py +47 -0
- util/lr_decay.py +76 -0
- util/lr_sched.py +21 -0
- util/misc.py +624 -0
- util/pos_embed.py +97 -0
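The three FSC_* scripts below resolve all of their inputs relative to --data_path. As a rough sketch of the dataset layout they assume (names taken from the argparse defaults in the scripts; adjust for your local copy of FSC-147):

from pathlib import Path

# Sketch only: mirrors the argparse defaults used in FSC_pretrain.py / FSC_tain.py / FSC_test.py.
data_path = Path('./data/FSC147/')
anno_file = data_path / 'annotation_FSC147_384.json'          # point annotations + exemplar boxes
data_split_file = data_path / 'Train_Test_Val_FSC_147.json'   # train/val/test image id lists
im_dir = data_path / 'images_384_VarV2'                       # resized images
gt_dir = data_path / 'gt_density_map_adaptive_384_VarV2'      # .npy ground-truth density maps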
FSC_pretrain.py
ADDED
@@ -0,0 +1,380 @@
import argparse
import datetime
import json

import PIL.Image
import numpy as np
import os
import time
import random
from pathlib import Path
import math
import sys
from PIL import Image

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.utils.data import Dataset
import wandb
import timm

assert "0.4.5" <= timm.__version__ <= "0.4.9"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import util.lr_sched as lr_sched
from util.FSC147 import transform_pre_train
import models_mae_noct


def get_args_parser():
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    parser.add_argument('--batch_size', default=8, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--mask_ratio', default=0.5, type=float,
                        help='Masking ratio (percentage of removed patches).')

    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/FSC147/', type=str,
                        help='dataset path')
    parser.add_argument('--anno_file', default='annotation_FSC147_384.json', type=str,
                        help='annotation json file')
    parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
                        help='data split json file')
    parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
                        help='images directory')
    parser.add_argument('--gt_dir', default='gt_density_map_adaptive_384_VarV2', type=str,
                        help='ground truth directory')
    parser.add_argument('--output_dir', default='./data/out/pre_4_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda:5',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='./weights/mae_pretrain_vit_base_full.pth',  # mae_visualize_vit_base
                        help='resume from checkpoint')

    # Training parameters
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # Logging parameters
    parser.add_argument('--log_dir', default='./logs/pre_4_dir',
                        help='path where to tensorboard log')
    parser.add_argument("--title", default="CounTR_pretraining", type=str)
    parser.add_argument("--wandb", default="counting", type=str)
    parser.add_argument("--team", default="wsense", type=str)
    parser.add_argument("--wandb_id", default=None, type=str)
    parser.add_argument('--anno_file_negative', default='annotation_FSC147_negative1.json', type=str,
                        help='annotation json file')
    return parser


os.environ["CUDA_LAUNCH_BLOCKING"] = '5'


class TrainData(Dataset):
    def __init__(self):
        self.img = data_split['train']
        random.shuffle(self.img)
        self.img_dir = im_dir
        self.TransformPreTrain = transform_pre_train(data_path)

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        im_id = self.img[idx]
        anno = annotations[im_id]
        bboxes = anno['box_examples_coordinates']
        # box_coordinates = anno.get('box_examples_coordinates', {})  # get the bounding-box coordinates for this image
        # # print(box_coordinates)
        # # get the bounding-box list of the first category
        # first_category = next(iter(box_coordinates), None)
        # # print(first_category)
        # first_category_bboxes = box_coordinates[first_category]
        # if first_category_bboxes:
        #     # print(first_category_bboxes[0])
        #     bboxes = first_category_bboxes[0]
        # else:
        #     bboxes = []
        # # if first_category_bboxes:
        # #     bboxes = first_category_bboxes[0]
        # # else:
        # #     pass

        rects = list()
        for bbox in bboxes:
            x1 = bbox[0][0]
            y1 = bbox[0][1]
            x2 = bbox[2][0]
            y2 = bbox[2][1]
            rects.append([y1, x1, y2, x2])

        image = Image.open('{}/{}'.format(im_dir, im_id))
        image.load()
        density_path = gt_dir / (im_id.split(".jpg")[0] + ".npy")
        density = np.load(density_path).astype('float32')
        sample = {'image': image, 'lines_boxes': rects, 'gt_density': density}
        sample = self.TransformPreTrain(sample)
        return sample['image']


def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train = TrainData()
    print(dataset_train)

    if True:  # args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if global_rank == 0:
        if args.log_dir is not None:
            os.makedirs(args.log_dir, exist_ok=True)
            log_writer = SummaryWriter(log_dir=args.log_dir)
        else:
            log_writer = None
        if args.wandb is not None:
            wandb_run = wandb.init(
                config=args,
                resume="allow",
                project=args.wandb,
                name=args.title,
                # entity=args.team,
                tags=["CounTR", "pretraining"],
                id=args.wandb_id,
            )
        else:
            wandb_run = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    # define the model
    model = models_mae_noct.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

    model.to(device)

    model_without_ddp = model

    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        # train one epoch
        model.train(True)
        metric_logger = misc.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}]'.format(epoch)
        print_freq = 20
        accum_iter = args.accum_iter

        optimizer.zero_grad()

        if log_writer is not None:
            print('log_dir: {}'.format(log_writer.log_dir))

        model_ = getattr(models_mae_noct, args.model)()

        for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader_train, print_freq, header)):
            epoch_1000x = int((data_iter_step / len(data_loader_train) + epoch) * 1000)

            if data_iter_step % accum_iter == 0:
                lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)

            samples = samples.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                loss, pred, mask = model(samples, mask_ratio=args.mask_ratio)

            loss_value = loss.item()

            if data_iter_step % 2000 == 0:
                preds = model_.unpatchify(pred)
                preds = preds.float()
                preds = torch.einsum('nchw->nhwc', preds)
                preds = torch.clip(preds, 0, 1)

                if log_writer is not None:
                    log_writer.add_images('reconstruction', preds, int(epoch), dataformats='NHWC')

                if wandb_run is not None:
                    wandb_images = []
                    w_samples = torch.einsum('nchw->nhwc', samples.float()).clip(0, 1)
                    masks = F.interpolate(
                        mask.reshape(shape=(mask.shape[0], 1, int(mask.shape[1] ** .5), int(mask.shape[1] ** .5))),
                        size=(preds.shape[1], preds.shape[2]))
                    masks = torch.einsum('nchw->nhwc', masks.float())
                    combos = (w_samples + masks.repeat(1, 1, 1, 3)).clip(0, 1)
                    w_images = (torch.cat([w_samples, combos, preds], dim=2) * 255).detach().cpu()
                    print("w_images:", w_samples.shape, combos.shape, preds.shape, "-->", w_images.shape)

                    for i in range(w_images.shape[0]):
                        wi = w_images[i, :, :, :]
                        wandb_images += [wandb.Image(wi.numpy().astype(np.uint8),
                                                     caption=f"Prediction {i} at epoch {epoch}")]
                    wandb.log({f"reconstruction": wandb_images}, step=epoch_1000x, commit=False)

            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            loss /= accum_iter
            loss_scaler(loss, optimizer, parameters=model.parameters(),
                        update_grad=(data_iter_step + 1) % accum_iter == 0)
            if (data_iter_step + 1) % accum_iter == 0:
                optimizer.zero_grad()

            torch.cuda.synchronize()

            metric_logger.update(loss=loss_value)

            lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(lr=lr)

            loss_value_reduce = misc.all_reduce_mean(loss_value)
            if (data_iter_step + 1) % accum_iter == 0:
                if log_writer is not None:
                    """ We use epoch_1000x as the x-axis in tensorboard.
                    This calibrates different curves when batch size changes.
                    """
                    log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
                    log_writer.add_scalar('lr', lr, epoch_1000x)
                if wandb_run is not None:
                    log = {"train/loss": loss_value_reduce, "train/lr": lr}
                    wandb.log(log, step=epoch_1000x, commit=True if data_iter_step == 0 else False)

        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}

        # save train status and model
        if args.output_dir and (epoch % 100 == 0 or epoch + 1 == args.epochs):
            misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch=epoch, suffix=f"pretraining_{epoch}")

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, }

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    wandb.run.finish()


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    # load data
    data_path = Path(args.data_path)
    anno_file = data_path / args.anno_file
    data_split_file = data_path / args.data_split_file
    im_dir = data_path / args.im_dir
    gt_dir = data_path / args.gt_dir
    with open(anno_file) as f:
        annotations = json.load(f)
    with open(data_split_file) as f:
        data_split = json.load(f)

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
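FSC_pretrain.py only derives the absolute learning rate when --lr is left unset, using the linear scaling rule from the code above. A minimal sketch of that computation, plugging in the script's default values (these are defaults, not tuned settings):

# Sketch: linear lr scaling as done in FSC_pretrain.py when args.lr is None.
blr = 1e-3          # --blr default
batch_size = 8      # --batch_size default (per GPU)
accum_iter = 1      # --accum_iter default
world_size = 1      # single process
eff_batch_size = batch_size * accum_iter * world_size
lr = blr * eff_batch_size / 256   # -> 3.125e-05 with the defaults above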
FSC_tain.py
ADDED
@@ -0,0 +1,532 @@
import argparse
import datetime
import json
import numpy as np
import os
import time
import random
from pathlib import Path
import sys
from PIL import Image
import torch.nn.functional as F
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
import torchvision
import wandb
import timm
from tqdm import tqdm

assert "0.4.5" <= timm.__version__ <= "0.4.9"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import util.lr_sched as lr_sched
from util.FSC147 import transform_train, transform_val
import models_mae_cross


def get_args_parser():
    parser = argparse.ArgumentParser('MAE pre-training', add_help=True)
    parser.add_argument('--batch_size', default=26, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--mask_ratio', default=0.5, type=float,
                        help='Masking ratio (percentage of removed patches).')
    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/FSC147/', type=str,
                        help='dataset path')
    parser.add_argument('--anno_file', default='annotation_FSC147_pos.json', type=str,
                        help='annotation json file for positive samples')
    parser.add_argument('--anno_file_negative', default='./data/FSC147/annotation_FSC147_neg.json', type=str,
                        help='annotation json file for negative samples')
    parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
                        help='data split json file')
    parser.add_argument('--class_file', default='ImageClasses_FSC147.txt', type=str,
                        help='class json file')
    parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
                        help='images directory')
    parser.add_argument('--output_dir', default='./data/out/fim6_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='./data/checkpoint.pth',
                        help='resume from checkpoint')
    parser.add_argument('--do_resume', action='store_true',
                        help='Resume training (e.g. if crashed).')

    # Training parameters
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--do_aug', action='store_true',
                        help='Perform data augmentation.')
    parser.add_argument('--no_do_aug', action='store_false', dest='do_aug')
    parser.set_defaults(do_aug=True)

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # Logging parameters
    parser.add_argument("--title", default="count", type=str)
    parser.add_argument("--wandb", default="240227", type=str)
    parser.add_argument("--team", default="wsense", type=str)
    parser.add_argument("--wandb_id", default=None, type=str)

    return parser


os.environ["CUDA_LAUNCH_BLOCKING"] = '0'

class TrainData(Dataset):
    def __init__(self, args, split='train', do_aug=True):
        with open(args.anno_file) as f:
            annotations = json.load(f)
        # Load negative annotations
        with open(args.anno_file_negative) as f:
            neg_annotations = json.load(f)
        with open(args.data_split_file) as f:
            data_split = json.load(f)

        self.img = data_split[split]
        random.shuffle(self.img)
        self.split = split
        self.img_dir = im_dir
        self.TransformTrain = transform_train(args, do_aug=do_aug)
        self.TransformVal = transform_val(args)
        self.annotations = annotations
        self.neg_annotations = neg_annotations
        self.im_dir = im_dir

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        im_id = self.img[idx]
        anno = self.annotations[im_id]
        bboxes = anno['box_examples_coordinates']
        dots = np.array(anno['points'])

        # Load the negative exemplar boxes
        neg_anno = self.neg_annotations[im_id]  # assumes every image ID has an entry in the negative annotations
        neg_bboxes = neg_anno['box_examples_coordinates']

        rects = list()
        for bbox in bboxes:
            x1 = bbox[0][0]
            y1 = bbox[0][1]
            x2 = bbox[2][0]
            y2 = bbox[2][1]
            if x1 < 0:
                x1 = 0
            if x2 < 0:
                x2 = 0
            if y1 < 0:
                y1 = 0
            if y2 < 0:
                y2 = 0

            rects.append([y1, x1, y2, x2])
        neg_rects = list()
        for neg_bbox in neg_bboxes:
            x1 = neg_bbox[0][0]
            y1 = neg_bbox[0][1]
            x2 = neg_bbox[2][0]
            y2 = neg_bbox[2][1]
            if x1 < 0:
                x1 = 0
            if x2 < 0:
                x2 = 0
            if y1 < 0:
                y1 = 0
            if y2 < 0:
                y2 = 0

            neg_rects.append([y1, x1, y2, x2])

        image = Image.open('{}/{}'.format(self.im_dir, im_id))
        if image.mode == "RGBA":
            image = image.convert("RGB")
        image.load()
        m_flag = 0

        sample = {'image': image, 'lines_boxes': rects, 'neg_lines_boxes': neg_rects, 'dots': dots, 'id': im_id, 'm_flag': m_flag}
        sample = self.TransformTrain(sample) if self.split == "train" else self.TransformVal(sample)
        return sample['image'], sample['gt_density'], len(dots), sample['boxes'], sample['neg_boxes'], sample['pos'], sample['m_flag'], im_id


def main(args):
    wandb_run = None
    try:
        misc.init_distributed_mode(args)

        print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
        print("{}".format(args).replace(', ', ',\n'))

        device = torch.device(args.device)
        # if torch.cuda.is_available():
        #     device = torch.device("cuda:5")

        # fix the seed for reproducibility
        seed = args.seed + misc.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        cudnn.benchmark = True

        dataset_train = TrainData(args, do_aug=args.do_aug)
        dataset_val = TrainData(args, split='val')

        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

        if global_rank == 0:
            if args.wandb is not None:
                wandb_run = wandb.init(
                    config=args,
                    resume="allow",
                    project=args.wandb,
                    name=args.title,
                    # entity=args.team,
                    tags=["count", "finetuning"],
                    id=args.wandb_id,
                )

        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False,
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False,
        )

        # define the model
        model = models_mae_cross.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
        model.to(device)
        model_without_ddp = model
        # print("Model = %s" % str(model_without_ddp))

        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

        if args.lr is None:  # only base_lr is specified
            args.lr = args.blr * eff_batch_size / 256

        print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
        print("actual lr: %.2e" % args.lr)

        print("accumulate grad iterations: %d" % args.accum_iter)
        print("effective batch size: %d" % eff_batch_size)

        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        # following timm: set wd as 0 for bias and norm layers
        param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
        optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
        print(optimizer)

        loss_scaler = NativeScaler()

        min_MAE = 99999
        print_freq = 50
        save_freq = 50

        misc.load_model_FSC_full(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

        print(f"Start training for {args.epochs - args.start_epoch} epochs - rank {global_rank}")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                data_loader_train.sampler.set_epoch(epoch)

            # train one epoch
            model.train(True)
            accum_iter = args.accum_iter

            # some parameters in training
            train_mae = torch.tensor([0], dtype=torch.float64, device=device)
            train_mse = torch.tensor([0], dtype=torch.float64, device=device)
            val_mae = torch.tensor([0], dtype=torch.float64, device=device)
            val_mse = torch.tensor([0], dtype=torch.float64, device=device)
            val_nae = torch.tensor([0], dtype=torch.float64, device=device)

            optimizer.zero_grad()

            for data_iter_step, (samples, gt_density, _, pos_boxes, neg_boxes, pos, m_flag, im_names) in enumerate(
                    tqdm(data_loader_train, total=len(data_loader_train), desc=f"Train [e. {epoch} - r. {global_rank}]")):
                idx = data_iter_step + (epoch * len(data_loader_train))

                if data_iter_step % accum_iter == 0:
                    lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)

                samples = samples.to(device, non_blocking=True, dtype=torch.half)
                gt_density = gt_density.to(device, non_blocking=True, dtype=torch.half)
                pos_boxes = pos_boxes.to(device, non_blocking=True, dtype=torch.half)
                neg_boxes = neg_boxes.to(device, non_blocking=True, dtype=torch.half)

                # If at least one image in the batch uses Type 2 Mosaic, 0-shot is disabled.
                flag = 0
                for i in range(m_flag.shape[0]):
                    flag += m_flag[i].item()
                if flag == 0:
                    shot_num = random.randint(0, 3)
                else:
                    shot_num = random.randint(1, 3)

                with torch.cuda.amp.autocast():
                    pos_output = model(samples, pos_boxes, shot_num)  # positive-exemplar output

                # compute the positive-sample loss
                mask = np.random.binomial(n=1, p=0.8, size=[384, 384])
                masks = np.tile(mask, (pos_output.shape[0], 1))
                masks = masks.reshape(pos_output.shape[0], 384, 384)
                masks = torch.from_numpy(masks).to(device)
                pos_loss = ((pos_output - gt_density) ** 2)
                pos_loss = (pos_loss * masks / (384 * 384)).sum() / pos_output.shape[0]
                # negative-exemplar output

                with torch.cuda.amp.autocast():
                    neg_output = model(samples, neg_boxes, 1)  # negative-exemplar output

                cnt1 = 1-torch.exp(-(torch.abs(pos_output.sum()/60 - gt_density.sum()/60).mean()))
                if neg_output.shape[0] == 0:
                    cnt2 = 0
                else:
                    # cnt2 = torch.log(torch.abs((neg_output.sum() / neg_output.shape[0]) - 1).mean()+1)
                    cnt2 = 1-torch.exp(-(torch.abs((neg_output.sum() / (neg_output.shape[0]*60)) - 1).mean()))
                cnt = cnt1+cnt2

                # compute the negative-sample loss
                mask = np.random.binomial(n=1, p=0.8, size=[384, 384])
                masks = np.tile(mask, (neg_output.shape[0], 1))
                masks = masks.reshape(neg_output.shape[0], 384, 384)
                masks = torch.from_numpy(masks).to(device)
                neg_loss = ((neg_output - gt_density) ** 2)
                if neg_output.shape[0] == 0:
                    neg_loss = 1
                else:
                    neg_loss = (neg_loss * masks / (384 * 384)).sum() / neg_output.shape[0]
                margin = 0.5
                contrastive_loss = torch.relu(pos_loss - neg_loss + margin)
                total_loss = contrastive_loss + pos_loss

                # update MAE and RMSE
                with torch.no_grad():
                    pred_cnt = (pos_output.view(len(samples), -1)).sum(1) / 60
                    gt_cnt = (gt_density.view(len(samples), -1)).sum(1) / 60
                    cnt_err = torch.abs(pred_cnt - gt_cnt).float()
                    batch_mae = cnt_err.double().mean()
                    batch_mse = (cnt_err ** 2).double().mean()

                    train_mae += batch_mae
                    train_mse += batch_mse

                if not torch.isfinite(total_loss):
                    print("Loss is {}, stopping training".format(total_loss))
                    sys.exit(1)

                total_loss /= accum_iter
                loss_scaler(total_loss, optimizer, parameters=model.parameters(),
                            update_grad=(data_iter_step + 1) % accum_iter == 0)
                if (data_iter_step + 1) % accum_iter == 0:
                    optimizer.zero_grad()

                lr = optimizer.param_groups[0]["lr"]
                loss_value_reduce = misc.all_reduce_mean(total_loss)

                if (data_iter_step + 1) % (print_freq * accum_iter) == 0 and (data_iter_step + 1) != len(data_loader_train) and data_iter_step != 0:
                    if wandb_run is not None:
                        log = {"train/loss": loss_value_reduce,
                               "train/lr": lr,
                               "train/MAE": batch_mae,
                               "train/RMSE": batch_mse ** 0.5}
                        wandb.log(log, step=idx)

            # evaluation on Validation split
            for val_samples, val_gt_density, val_n_ppl, val_boxes, _, val_pos, _, val_im_names in \
                    tqdm(data_loader_val, total=len(data_loader_val),
                         desc=f"Val [e. {epoch} - r. {global_rank}]"):

                val_samples = val_samples.to(device, non_blocking=True, dtype=torch.half)
                val_gt_density = val_gt_density.to(device, non_blocking=True, dtype=torch.half)
                val_boxes = val_boxes.to(device, non_blocking=True, dtype=torch.half)
                val_n_ppl = val_n_ppl.to(device, non_blocking=True)
                shot_num = random.randint(0, 3)

                with torch.no_grad():
                    with torch.cuda.amp.autocast():
                        val_output = model(val_samples, val_boxes, shot_num)

                    val_pred_cnt = (val_output.view(len(val_samples), -1)).sum(1) / 60
                    val_gt_cnt = (val_gt_density.view(len(val_samples), -1)).sum(1) / 60
                    # print('val_pred_cnt', val_pred_cnt)
                    # print('val_gt_cnt', val_gt_cnt)
                    val_cnt_err = torch.abs(val_pred_cnt - val_gt_cnt).float()
                    # print('val_cnt_err', val_cnt_err.mean())
                    val_cnt_err[val_cnt_err == float('inf')] = 0
                    val_mae += val_cnt_err.double().mean()

                    # val_mae += val_cnt_err
                    # print('val_mae', val_mae.mean())
                    val_cnt_err[val_cnt_err == float('inf')] = 0
                    val_mse += (val_cnt_err ** 2).double().mean()

                    # val_mse += (val_cnt_err ** 2)
                    _val_nae = val_cnt_err / val_gt_cnt
                    _val_nae[_val_nae == float('inf')] = 0
                    val_nae += _val_nae.double().mean()
            # val_mae = val_mae/len(data_loader_val)
            # val_mse = val_mse/len(data_loader_val)
            # print('val_mae', val_mae)
            # print('val_mse', val_mse)
            # Output visualisation information to W&B
            if wandb_run is not None:
                train_wandb_densities = []
                train_wandb_bboxes = []
                val_wandb_densities = []
                val_wandb_bboxes = []
                black = torch.zeros([384, 384], device=device)

                for i in range(pos_output.shape[0]):
                    # gt and predicted density
                    w_d_map = torch.stack([pos_output[i], black, black])
                    gt_map = torch.stack([gt_density[i], black, black])
                    box_map = misc.get_box_map(samples[i], pos[i], device)
                    w_gt_density = samples[i] / 2 + gt_map + box_map
                    w_d_map_overlay = samples[i] / 2 + w_d_map
                    w_densities = torch.cat([w_gt_density, w_d_map, w_d_map_overlay], dim=2)
                    w_densities = torch.clamp(w_densities, 0, 1)
                    train_wandb_densities += [wandb.Image(torchvision.transforms.ToPILImage()(w_densities),
                                                          caption=f"[E#{epoch}] {im_names[i]} ({torch.sum(gt_density[i]).item()}, {torch.sum(pos_output[i]).item()})")]

                    # exemplars
                    w_boxes = torch.cat([pos_boxes[i][x, :, :, :] for x in range(pos_boxes[i].shape[0])], 2)
                    train_wandb_bboxes += [wandb.Image(torchvision.transforms.ToPILImage()(w_boxes),
                                                       caption=f"[E#{epoch}] {im_names[i]}")]

                for i in range(val_output.shape[0]):
                    # gt and predicted density
                    w_d_map = torch.stack([val_output[i], black, black])
                    gt_map = torch.stack([val_gt_density[i], black, black])
                    box_map = misc.get_box_map(val_samples[i], val_pos[i], device)
                    w_gt_density = val_samples[i] / 2 + gt_map + box_map
                    w_d_map_overlay = val_samples[i] / 2 + w_d_map
                    w_densities = torch.cat([w_gt_density, w_d_map, w_d_map_overlay], dim=2)
                    w_densities = torch.clamp(w_densities, 0, 1)
                    val_wandb_densities += [wandb.Image(torchvision.transforms.ToPILImage()(w_densities),
                                                        caption=f"[E#{epoch}] {val_im_names[i]} ({torch.sum(val_gt_density[i]).item()}, {torch.sum(val_output[i]).item()})")]

                    # exemplars
                    w_boxes = torch.cat([val_boxes[i][x, :, :, :] for x in range(val_boxes[i].shape[0])], 2)
                    val_wandb_bboxes += [wandb.Image(torchvision.transforms.ToPILImage()(w_boxes),
                                                     caption=f"[E#{epoch}] {val_im_names[i]}")]

                log = {"train/loss": loss_value_reduce,
                       "train/lr": lr,
                       "train/MAE": batch_mae,
                       "train/RMSE": batch_mse ** 0.5,
                       "val/MAE": val_mae / len(data_loader_val),
                       "val/RMSE": (val_mse / len(data_loader_val)) ** 0.5,
                       "val/NAE": val_nae / len(data_loader_val),
                       "train_densitss": train_wandb_densities,
                       "val_densites": val_wandb_densities,
                       "train_boxes": train_wandb_bboxes,
                       "val_boxes": val_wandb_bboxes}
                wandb.log(log, step=idx)

            # save train status and model
            if args.output_dir and (epoch % save_freq == 0 or epoch + 1 == args.epochs) and epoch != 0:
                misc.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, suffix=f"finetuning_{epoch}", upload=epoch % 100 == 0)
            elif True:
                misc.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, suffix=f"finetuning_last", upload=False)
            if args.output_dir and val_mae / len(data_loader_val) < min_MAE:
                min_MAE = val_mae / len(data_loader_val)
                misc.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, suffix="finetuning_minMAE")

            print(f'[Train Epoch #{epoch}] - MAE: {train_mae.item() / len(data_loader_train):5.2f}, RMSE: {(train_mse.item() / len(data_loader_train)) ** 0.5:5.2f}', flush=True)
            print(f'[Val Epoch #{epoch}] - MAE: {val_mae.item() / len(data_loader_val):5.2f}, RMSE: {(val_mse.item() / len(data_loader_val)) ** 0.5:5.2f}, NAE: {val_nae.item() / len(data_loader_val):5.2f}', flush=True)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    finally:
        if wandb_run is not None:
            wandb.run.finish()


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    data_path = Path(args.data_path)
    anno_file = data_path / args.anno_file
    data_split_file = data_path / args.data_split_file
    im_dir = data_path / args.im_dir

    if args.do_aug:
        class_file = data_path / args.class_file
    else:
        class_file = None

    args.anno_file = anno_file
    args.data_split_file = data_split_file
    args.im_dir = im_dir
    args.class_file = class_file

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
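The training loss above combines a masked MSE on the positive-exemplar density prediction with a hinge-style term that pushes the negative-exemplar loss above the positive one. A condensed, self-contained sketch of that combination (names mirror the script; the single shared mask and the 0.5 margin are simplifications of the values hard-coded above):

import torch

def counting_loss(pos_output, neg_output, gt_density, mask, margin=0.5):
    # Masked MSE between the positive-exemplar prediction and the GT density map.
    pos_loss = (((pos_output - gt_density) ** 2) * mask / (384 * 384)).sum() / pos_output.shape[0]
    # Same masked MSE for the negative-exemplar prediction, which should stay larger than pos_loss.
    neg_loss = (((neg_output - gt_density) ** 2) * mask / (384 * 384)).sum() / neg_output.shape[0]
    # Hinge term: penalise cases where pos_loss is not at least `margin` below neg_loss.
    contrastive = torch.relu(pos_loss - neg_loss + margin)
    return contrastive + pos_loss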
FSC_test.py
ADDED
@@ -0,0 +1,352 @@
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import scipy.ndimage as ndimage
|
9 |
+
import pandas as pd
|
10 |
+
import random
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.backends.cudnn as cudnn
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
import torchvision
|
16 |
+
from torchvision import transforms
|
17 |
+
import torchvision.transforms.functional as TF
|
18 |
+
import timm
|
19 |
+
from util.FSC147 import transform_train, transform_val
|
20 |
+
from tqdm import tqdm
|
21 |
+
assert "0.4.5" <= timm.__version__ <= "0.4.9" # version check
|
22 |
+
|
23 |
+
import util.misc as misc
|
24 |
+
import models_mae_cross
|
25 |
+
|
26 |
+
|
27 |
+
def get_args_parser():
|
28 |
+
parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
|
29 |
+
|
30 |
+
# Model parameters
|
31 |
+
parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
|
32 |
+
help='Name of model to train')
|
33 |
+
parser.add_argument('--mask_ratio', default=0.5, type=float,
|
34 |
+
help='Masking ratio (percentage of removed patches).')
|
35 |
+
parser.add_argument('--norm_pix_loss', action='store_true',
|
36 |
+
help='Use (per-patch) normalized pixels as targets for computing loss')
|
37 |
+
parser.set_defaults(norm_pix_loss=False)
|
38 |
+
|
39 |
+
# Dataset parameters
|
40 |
+
parser.add_argument('--data_path', default='./data/FSC147/', type=str,
|
41 |
+
help='dataset path')
|
42 |
+
parser.add_argument('--anno_file', default='annotation_FSC147_positive.json', type=str,
|
43 |
+
help='annotation json file')
|
44 |
+
parser.add_argument('--anno_file_negative', default='./data/FSC147/annotation_FSC147_neg2.json', type=str,
|
45 |
+
help='annotation json file')
|
46 |
+
parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
|
47 |
+
help='data split json file')
|
48 |
+
parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
|
49 |
+
help='images directory')
|
50 |
+
parser.add_argument('--output_dir', default='./Image',
|
51 |
+
help='path where to save, empty for no saving')
|
52 |
+
parser.add_argument('--device', default='cuda',
|
53 |
+
help='device to use for training / testing')
|
54 |
+
parser.add_argument('--seed', default=0, type=int)
|
55 |
+
parser.add_argument('--resume', default='./output_fim6_dir/checkpoint-0.pth',
|
56 |
+
help='resume from checkpoint')
|
57 |
+
parser.add_argument('--external', action='store_true',
|
58 |
+
help='Set this param for using external exemplars')
|
59 |
+
parser.add_argument('--box_bound', default=-1, type=int,
|
60 |
+
help='The max number of exemplars to be considered')
|
61 |
+
parser.add_argument('--split', default="test", type=str)
|
62 |
+
|
63 |
+
# Training parameters
|
64 |
+
parser.add_argument('--num_workers', default=0, type=int)
|
65 |
+
parser.add_argument('--pin_mem', action='store_true',
|
66 |
+
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
67 |
+
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
|
68 |
+
parser.set_defaults(pin_mem=True)
|
69 |
+
parser.add_argument('--normalization', default=True, help='Set to False to disable test-time normalization')
|
70 |
+
|
71 |
+
# Distributed training parameters
|
72 |
+
parser.add_argument('--world_size', default=1, type=int,
|
73 |
+
help='number of distributed processes')
|
74 |
+
parser.add_argument('--local_rank', default=-1, type=int)
|
75 |
+
parser.add_argument('--dist_on_itp', action='store_true')
|
76 |
+
parser.add_argument('--dist_url', default='env://',
|
77 |
+
help='url used to set up distributed training')
|
78 |
+
|
79 |
+
return parser
|
80 |
+
|
81 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = '5'
|
82 |
+
|
83 |
+
class TestData(Dataset):
|
84 |
+
def __init__(self, args, split='val', do_aug=True):
|
85 |
+
with open(data_path/args.anno_file) as f:
|
86 |
+
annotations = json.load(f)
|
87 |
+
# Load negative annotations
|
88 |
+
with open(args.anno_file_negative) as f:
|
89 |
+
neg_annotations = json.load(f)
|
90 |
+
with open(data_path/args.data_split_file) as f:
|
91 |
+
data_split = json.load(f)
|
92 |
+
|
93 |
+
self.img = data_split[split]
|
94 |
+
random.shuffle(self.img)
|
95 |
+
self.split = split
|
96 |
+
self.img_dir = im_dir
|
97 |
+
# self.TransformTrain = transform_train(args, do_aug=do_aug)
|
98 |
+
self.TransformVal = transform_val(args)
|
99 |
+
self.annotations = annotations
|
100 |
+
self.neg_annotations = neg_annotations
|
101 |
+
self.im_dir = im_dir
|
102 |
+
|
103 |
+
def __len__(self):
|
104 |
+
return len(self.img)
|
105 |
+
|
106 |
+
def __getitem__(self, idx):
|
107 |
+
im_id = self.img[idx]
|
108 |
+
anno = self.annotations[im_id]
|
109 |
+
bboxes = anno['box_examples_coordinates']
|
110 |
+
dots = np.array(anno['points'])
|
111 |
+
|
112 |
+
# 加载负样本的框
|
113 |
+
neg_anno = self.neg_annotations[im_id] # 假设每个图像ID在负样本注释中都有对应的条目
|
114 |
+
neg_bboxes = neg_anno['box_examples_coordinates']
|
115 |
+
|
116 |
+
rects = list()
|
117 |
+
for bbox in bboxes:
|
118 |
+
x1 = bbox[0][0]
|
119 |
+
y1 = bbox[0][1]
|
120 |
+
x2 = bbox[2][0]
|
121 |
+
y2 = bbox[2][1]
|
122 |
+
if x1 < 0:
|
123 |
+
x1 = 0
|
124 |
+
if x2 < 0:
|
125 |
+
x2 = 0
|
126 |
+
if y1 < 0:
|
127 |
+
y1 = 0
|
128 |
+
if y2 < 0:
|
129 |
+
y2 = 0
|
130 |
+
|
131 |
+
rects.append([y1, x1, y2, x2])
|
132 |
+
neg_rects = list()
|
133 |
+
for neg_bbox in neg_bboxes:
|
134 |
+
x1 = neg_bbox[0][0]
|
135 |
+
y1 = neg_bbox[0][1]
|
136 |
+
x2 = neg_bbox[2][0]
|
137 |
+
y2 = neg_bbox[2][1]
|
138 |
+
if x1 < 0:
|
139 |
+
x1 = 0
|
140 |
+
if x2 < 0:
|
141 |
+
x2 = 0
|
142 |
+
if y1 < 0:
|
143 |
+
y1 = 0
|
144 |
+
if y2 < 0:
|
145 |
+
y2 = 0
|
146 |
+
|
147 |
+
neg_rects.append([y1, x1, y2, x2])
|
148 |
+
|
149 |
+
image = Image.open('{}/{}'.format(self.im_dir, im_id))
|
150 |
+
if image.mode == "RGBA":
|
151 |
+
image = image.convert("RGB")
|
152 |
+
image.load()
|
153 |
+
m_flag = 0
|
154 |
+
|
155 |
+
sample = {'image': image, 'lines_boxes': rects,'neg_lines_boxes': neg_rects, 'dots': dots, 'id': im_id, 'm_flag': m_flag}
|
156 |
+
sample = self.TransformTrain(sample) if self.split == "train" else self.TransformVal(sample)
|
157 |
+
# if self.split == "train":
|
158 |
+
# sample = self.TransformTrain(sample)
|
159 |
+
# # print(sample.keys())
|
160 |
+
return sample['image'], sample['gt_density'], len(dots), sample['boxes'], sample['neg_boxes'], sample['pos'],sample['m_flag'], im_id
|
161 |
+
|
162 |
+
def batched_rmse(predictions, targets, batch_size=100):
|
163 |
+
"""
|
164 |
+
分批计算RMSE
|
165 |
+
:param predictions: 模型预测的值,一个PyTorch张量
|
166 |
+
:param targets: 真实的值,一个PyTorch张量,与predictions形状相同
|
167 |
+
:param batch_size: 每个批次的大小
|
168 |
+
:return: RMSE值
|
169 |
+
"""
|
170 |
+
total_mse = 0.0
|
171 |
+
total_count = 0
|
172 |
+
|
173 |
+
# 分批处理
|
174 |
+
    for i in range(0, len(predictions), batch_size):
        batch_predictions = predictions[i:i+batch_size]
        batch_targets = targets[i:i+batch_size]

        # Use float64 for the computation to improve precision
        batch_predictions = batch_predictions.double()
        batch_targets = batch_targets.double()

        # Compute the MSE for this batch
        difference = batch_predictions - batch_targets
        mse = torch.mean(difference ** 2)

        # Accumulate the MSE and the sample count
        total_mse += mse * len(batch_predictions)
        total_count += len(batch_predictions)

    # Compute the average MSE
    avg_mse = total_mse / total_count

    # Compute the RMSE
    rmse_val = torch.sqrt(avg_mse)

    return rmse_val

def batched_mae(predictions, targets, batch_size=100):
    """
    Compute the MAE in batches.
    :param predictions: model predictions, a PyTorch tensor
    :param targets: ground-truth values, a PyTorch tensor with the same shape as predictions
    :param batch_size: size of each batch
    :return: the MAE value
    """
    total_mae = 0.0
    total_count = 0

    # Process in batches
    for i in range(0, len(predictions), batch_size):
        batch_predictions = predictions[i:i+batch_size]
        batch_targets = targets[i:i+batch_size]

        # Compute the absolute errors for this batch
        absolute_errors = torch.abs(batch_predictions - batch_targets)

        # Accumulate the absolute errors and the sample count
        total_mae += torch.sum(absolute_errors)
        total_count += len(batch_predictions)

    # Compute the mean absolute error
    avg_mae = total_mae / total_count

    return avg_mae

def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # dataset_test = TestData(external=args.external, box_bound=args.box_bound, split=args.split)
    dataset_test = TestData(args, split='test')
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_test = torch.utils.data.DistributedSampler(
        dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, sampler=sampler_test,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    # define the model
    model = models_mae_cross.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
    model.to(device)
    model_without_ddp = model

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    misc.load_model_FSC(args=args, model_without_ddp=model_without_ddp)

    print("Start testing.")

    # test
    model.eval()

    # counters carried over from the training script
    train_mae = 0
    train_rmse = 0
    train_nae = 0
    tot_load_time = 0
    tot_infer_time = 0

    loss_array = []
    gt_array = []
    pred_arr = []
    name_arr = []
    empties = []

    total_mae = 0.0
    total_mse = 0.0
    total_nae = 0.0
    total_count = 0
    sub_batch_size = 50
    for val_samples, val_gt_density, val_n_ppl, val_boxes, neg_val_boxes, val_pos, _, val_im_names in tqdm(data_loader_test, total=len(data_loader_test), desc="Validation"):
        val_samples = val_samples.to(device, non_blocking=True, dtype=torch.float)  # use higher precision
        val_gt_density = val_gt_density.to(device, non_blocking=True, dtype=torch.float)
        val_boxes = val_boxes.to(device, non_blocking=True, dtype=torch.float)
        neg_val_boxes = neg_val_boxes.to(device, non_blocking=True, dtype=torch.float)
        num_samples = val_samples.size(0)
        total_count += num_samples

        for i in range(0, num_samples, sub_batch_size):
            sub_val_samples = val_samples[i:i+sub_batch_size]
            sub_val_gt_density = val_gt_density[i:i+sub_batch_size]

            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    sub_val_output = model(sub_val_samples, val_boxes[i:i+sub_batch_size], 3)
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    neg_sub_val_output = model(sub_val_samples, neg_val_boxes[i:i+sub_batch_size], 3)
            # output = torch.clamp((sub_val_output - neg_sub_val_output), min=0)
            sub_val_pred_cnt = torch.abs(sub_val_output.sum()) / 60
            # sub_val_pred_cnt = torch.abs(output.sum()) / 60
            # neg_sub_val_pred_cnt = torch.abs(neg_sub_val_output.sum()) / 60
            sub_val_gt_cnt = sub_val_gt_density.sum() / 60

            sub_val_cnt_err = torch.abs(sub_val_pred_cnt - sub_val_gt_cnt)

            # Accumulate item by item, skipping invalid values
            if not torch.isinf(sub_val_cnt_err) and not torch.isnan(sub_val_cnt_err):
                batch_mae = sub_val_cnt_err.item()
                batch_mse = sub_val_cnt_err.item() ** 2
                batch_nae = sub_val_cnt_err.item() / sub_val_gt_cnt.item() if sub_val_gt_cnt.item() != 0 else 0

                total_mae += batch_mae * sub_val_samples.size(0)
                total_mse += batch_mse * sub_val_samples.size(0)
                total_nae += batch_nae * sub_val_samples.size(0)
            sub_val_pred_cnt = sub_val_pred_cnt.int()
    final_mae = total_mae / total_count
    final_rmse = (total_mse / total_count) ** 0.5
    final_nae = total_nae / total_count

    print(f'MAE: {final_mae}, RMSE: {final_rmse}, NAE: {final_nae}')


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    # load data
    data_path = Path(args.data_path)
    anno_file = data_path / args.anno_file
    data_split_file = data_path / args.data_split_file
    im_dir = data_path / args.im_dir

    with open(anno_file) as f:
        annotations = json.load(f)

    with open(data_split_file) as f:
        data_split = json.load(f)

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
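To make the two metric helpers above easier to verify, here is a minimal usage sketch; it is not part of the uploaded files and assumes `batched_rmse` and `batched_mae` are available in the current session (for example, pasted alongside the definitions above).

```python
# Minimal sketch (not repository code): sanity-check the batched metrics on random tensors.
import torch

preds = torch.rand(1000) * 100      # fake predicted counts
gts = preds + torch.randn(1000)     # fake ground-truth counts with small noise

rmse = batched_rmse(preds, gts, batch_size=100)
mae = batched_mae(preds, gts, batch_size=100)
print(f"RMSE: {rmse.item():.4f}, MAE: {mae.item():.4f}")  # both should be close to 1.0 here
```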
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Chang Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,100 @@
# VA-Count
[ECCV 2024] Zero-shot Object Counting with Good Exemplars
[[paper](https://arxiv.org/abs/2407.04948)]

# Zero-shot Object Counting with Good Exemplars
## News🚀
* **2024.09.27**: Our code is released.
* **2024.09.26**: Our inference code has been updated; the exemplar-selection and training code will follow soon.
* **2024.07.02**: VA-Count is accepted by ECCV 2024.
## Overview
The proposed method has two main components: the Exemplar Enhancement Module (EEM), which improves exemplar quality through patch selection integrated with Grounding DINO, and the Noise Suppression Module (NSM), which separates positive from negative class samples using density maps. A contrastive loss refines the ability to distinguish target-class objects from other objects in an image.
## Environment
```
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install timm==0.3.2
pip install numpy
pip install matplotlib tqdm
pip install tensorboard
pip install scipy
pip install imgaug
pip install opencv-python
pip3 install hub
```
### For more information on Grounding DINO, please refer to the following link:
[GroundingDINO](https://github.com/IDEA-Research/GroundingDINO)
We are very grateful for the Grounding DINO approach, which has been instrumental in our work!

## Datasets

* [FSC147](https://github.com/cvlab-stonybrook/LearningToCountEverything)

* [CARPK](https://lafi.github.io/LPN/)

Prepare the datasets as follows:

```
./data/
|--FSC147
|  |--images_384_VarV2
|  |  |--2.jpg
|  |  |--3.jpg
|  |--gt_density_map_adaptive_384_VarV2
|  |  |--2.npy
|  |  |--3.npy
|  |--annotation_FSC147_384.json
|  |--Train_Test_Val_FSC_147.json
|  |--ImageClasses_FSC147.txt
|  |--train.txt
|  |--test.txt
|  |--val.txt
|--CARPK/
|  |--Annotations/
|  |--Images/
|  |--ImageSets/
```
## Inference
+ For inference, you can download the model from [Baidu-Disk](https://pan.baidu.com/s/11sbdDYLDfTOIPx5pZvBpmw?pwd=paeh), password: paeh
```
python FSC_test.py --output_dir ./data/out/results_base --resume ./data/checkpoint_FSC.pth
```
## Single and Multiple Object Classifier Training
```
python datasetmake.py
python biclassify.py
```
+ You can also directly download the classifier from [Baidu-Disk](https://pan.baidu.com/s/1fOF0giI3yQpvGTiNFUI7cQ?pwd=psum), password: psum. Save it in ./data/out/classify/
## Generate exemplars
```
python grounding_pos.py --root_path ./data/FSC147/
python grounding_neg.py --root_path ./data/FSC147/
```

## Train

```
CUDA_VISIBLE_DEVICES=0 python FSC_pretrain.py \
    --epochs 500 \
    --warmup_epochs 10 \
    --blr 1.5e-4 --weight_decay 0.05
```
+ You can also directly download the pre-trained model from [Baidu-Disk](https://pan.baidu.com/s/1_-w_9I4bPA66pMZkHTrdrg?pwd=xynw), password: xynw. Save it in ./data/
```
CUDA_VISIBLE_DEVICES=0 python FSC_train.py --epochs 1000 --batch_size 8 --lr 1e-5 --output_dir ./data/out/
```

## Citation

```
@inproceedings{zhu2024zero,
  title={Zero-shot Object Counting with Good Exemplars},
  author={Zhu, Huilin and Yuan, Jingling and Yang, Zhengwei and Guo, Yu and Wang, Zheng and Zhong, Xian and He, Shengfeng},
  booktitle={Proceedings of the European Conference on Computer Vision},
  year={2024}
}
```

## Acknowledgement
This project is based on the implementation from [CounTR](https://github.com/Verg-Avesta/CounTR); we are very grateful for that work and for [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO).

#### If you have any questions, please get in touch with me ([email protected]).
__pycache__/models_crossvit.cpython-38.pyc
ADDED
Binary file (6.28 kB).
__pycache__/models_mae_cross.cpython-38.pyc
ADDED
Binary file (6.69 kB).
__pycache__/models_mae_noct.cpython-38.pyc
ADDED
Binary file (7.03 kB).
__pycache__/models_mae_noct.cpython-39.pyc
ADDED
Binary file (6.96 kB).
biclassify.py
ADDED
@@ -0,0 +1,163 @@
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import clip
import re
import torchvision.models as models

# 1. Read the data and preprocess
def read_label_file(file_path):
    data = []
    with open(file_path, 'r') as f:
        for line in f.readlines():
            image_name, label = line.strip().split(',')
            data.append([image_name, 1 if label == 'one' else 0])
    return pd.DataFrame(data, columns=['image', 'label'])

# Read the image names listed in train.txt
with open('./data/FSC147/train.txt', 'r') as file:
    a_txt_images = file.read().splitlines()

# Extract the number before ".jpg"
a_txt_numbers = set([name.split('.')[0] for name in a_txt_images])

# Read image names and labels from labels.txt
with open('./data/FSC147/one/labels.txt', 'r') as file:
    label_txt_lines = file.read().splitlines()

# Keep only the images that also appear in train.txt
filtered_images = []
for line in label_txt_lines:
    image_name, label = line.strip().split(',')
    # Match the leading digits with a regular expression
    match = re.match(r'(\d+)', image_name)
    if match:
        image_number = match.group(1)
        if image_number in a_txt_numbers:
            # Convert the 'label' value
            label_value = 1 if label == 'one' else 0
            filtered_images.append([image_name, label_value])  # a list, to match the output of read_label_file

# Convert the filtered images and labels to a DataFrame with the same columns as read_label_file
df_filtered = pd.DataFrame(filtered_images, columns=['image', 'label'])

# Custom Dataset class
class CustomDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.dataframe.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        label = self.dataframe.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        return image, label

# 2. Dataset split
data_folder = './data/FSC147/one'
label_file = os.path.join(data_folder, 'labels.txt')
# df = read_label_file(label_file)
df = df_filtered
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# 3. Data loading
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

train_dataset = CustomDataset(train_df, data_folder, transform=transform)
test_dataset = CustomDataset(test_df, data_folder, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 4. Model definition
class ClipClassifier(nn.Module):
    def __init__(self, clip_model, embed_dim=512):
        super(ClipClassifier, self).__init__()
        self.clip_model = clip_model
        # Freeze the CLIP model parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False
        self.fc = nn.Linear(clip_model.visual.output_dim, embed_dim)
        self.classifier = nn.Linear(embed_dim, 2)  # binary classification

    def forward(self, images):
        with torch.no_grad():
            image_features = self.clip_model.encode_image(images).float()
        x = self.fc(image_features)
        x = F.relu(x)
        logits = self.classifier(x)
        return logits

class ResNetClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetClassifier, self).__init__()
        # Load a pretrained ResNet50
        self.resnet50 = models.resnet50(pretrained=True)
        # Freeze all pretrained layers
        for param in self.resnet50.parameters():
            param.requires_grad = False
        # Replace the final fully connected layer for the binary classification task
        num_ftrs = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, images):
        return self.resnet50(images)

# 5. Training and testing
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
clip_model, _ = clip.load("ViT-B/32", device=device)
# model = ClipClassifier(clip_model).to(device)
model = ResNetClassifier().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')
    return 100. * correct / len(test_loader.dataset)

best_accuracy = 0.0
for epoch in range(1, 11):
    train(model, device, train_loader, optimizer, epoch)
    accuracy = test(model, device, test_loader)
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), './data/out/classify/best_model.pth')
        print(f'Best model saved with accuracy: {best_accuracy:.2f}%')
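As a minimal sketch of how the saved checkpoint could be reused later (not part of the uploaded files), the snippet below reloads `best_model.pth` into the `ResNetClassifier` defined above and scores a single crop; it assumes the `ResNetClassifier` class and `transform` from biclassify.py are available in the session, and the image path is a placeholder.

```python
# Minimal sketch (assumption, not repository code): reload the saved binary
# classifier and score one image crop. Class index 1 corresponds to the
# 'one' (single-object) label used when building labels.txt.
import torch
from PIL import Image

clf = ResNetClassifier()
clf.load_state_dict(torch.load('./data/out/classify/best_model.pth', map_location='cpu'))
clf.eval()

crop = Image.open('./data/FSC147/one/2.jpg').convert('RGB')  # placeholder path
with torch.no_grad():
    probs = torch.softmax(clf(transform(crop).unsqueeze(0)), dim=1)
print('P(single object) =', probs[0, 1].item())
```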
datasetmake.py
ADDED
@@ -0,0 +1,53 @@
from PIL import Image
import os
import random

def is_image_file(filename):
    """Check whether a file is an image file."""
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']  # supported image file extensions
    return any(filename.lower().endswith(ext) for ext in image_extensions)

def random_crop(img, size=(256, 256)):
    """Randomly crop a region of the given size from the image."""
    width, height = img.size
    crop_width, crop_height = size

    if width < crop_width or height < crop_height:
        return None  # return None if the image is smaller than the crop size

    x_left = random.randint(0, width - crop_width)
    y_upper = random.randint(0, height - crop_height)

    return img.crop((x_left, y_upper, x_left + crop_width, y_upper + crop_height))

# Folder paths (adjust to your setup)
single_object_folder = './data/FSC147/box'
multiple_objects_folder = './data/FSC147/images_384_VarV2'
output_folder = './data/FSC147/one'

# Make sure the output folder exists
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

output_txt_path = os.path.join(output_folder, 'labels.txt')
with open(output_txt_path, 'w') as f:
    for folder, label in [(single_object_folder, 'one'), (multiple_objects_folder, 'more')]:
        for filename in os.listdir(folder):
            if is_image_file(filename):  # only process image files
                img_path = os.path.join(folder, filename)
                img = Image.open(img_path)

                # Save the original image and record it in the txt file
                original_img_output_path = os.path.join(output_folder, filename)
                img.save(original_img_output_path)
                f.write(f"{filename},{label}\n")

                # Randomly crop the original image and save the crops
                for size in [(256, 384), (256, 256), (384, 384), (128, 256), (256, 128)]:
                    img_cropped = random_crop(img, size=size)
                    if img_cropped:
                        cropped_img_output_path = os.path.join(output_folder, f"{filename[:-4]}_random_{size[0]}x{size[1]}.jpg")
                        img_cropped.save(cropped_img_output_path)
                        f.write(f"{filename[:-4]}_random_{size[0]}x{size[1]}.jpg,{label}\n")

print("Dataset preparation complete.")
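A minimal sketch of `random_crop` in isolation (not part of the uploaded files), assuming the function defined above is available in the session; it uses a synthetic blank image so it runs without the dataset.

```python
# Minimal sketch (assumption, not repository code): exercise random_crop on a blank image.
from PIL import Image

img = Image.new('RGB', (512, 384), color=(128, 128, 128))
crop = random_crop(img, size=(256, 256))
print(crop.size)                            # (256, 256)

too_big = random_crop(img, size=(1024, 1024))
print(too_big)                              # None: requested crop larger than the image
```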
figure.png
ADDED
grounding_neg.py
ADDED
@@ -0,0 +1,188 @@
import torch
import os
import inflect
import argparse
from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict
from PIL import Image
import numpy as np
from torchvision.ops import box_convert
import json
import torch.nn as nn
import torch.nn.functional as F
import clip

# Global settings
device = "cuda" if torch.cuda.is_available() else "cpu"

# Thresholds
BOX_THRESHOLD = 0.02
TEXT_THRESHOLD = 0.02
BOX_THRESHOLD_class = 0.01
TEXT_THRESHOLD_class = 0.01

# Initialize the inflect engine
p = inflect.engine()

# Convert a word to its singular form
def to_singular(word):
    singular_word = p.singular_noun(word)
    return singular_word if singular_word else word

# ClipClassifier definition
class ClipClassifier(nn.Module):
    def __init__(self, clip_model, embed_dim=512):
        super(ClipClassifier, self).__init__()
        self.clip_model = clip_model.to(device)
        for param in self.clip_model.parameters():
            param.requires_grad = False
        self.fc = nn.Linear(clip_model.visual.output_dim, embed_dim)
        self.classifier = nn.Linear(embed_dim, 2)  # binary classification

    def forward(self, images):
        with torch.no_grad():
            image_features = self.clip_model.encode_image(images).float().to(device)
        x = self.fc(image_features)
        x = F.relu(x)
        logits = self.classifier(x)
        return logits

# Initialize and load the binary classifier
clip_model, preprocess = clip.load("ViT-B/32", device)
binary_classifier = ClipClassifier(clip_model).to(device)

# Load the saved weights
model_weights_path = './data/out/classify/best_model.pth'
binary_classifier.load_state_dict(torch.load(model_weights_path, map_location=device))

# Make sure the model is in evaluation mode
binary_classifier.eval()

# Compute the IoU of two bounding boxes
def calculate_iou(box1, box2):
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2

    intersection_x1 = max(x1, x2)
    intersection_y1 = max(y1, y2)
    intersection_x2 = min(x1 + w1, x2 + w2)
    intersection_y2 = min(y1 + h1, y2 + h2)

    intersection_area = max(intersection_x2 - intersection_x1, 0) * max(intersection_y2 - intersection_y1, 0)
    box1_area = w1 * h1
    box2_area = w2 * h2
    union_area = box1_area + box2_area - intersection_area
    iou = intersection_area / union_area if union_area > 0 else 0

    return iou

# Check whether a patch is valid
def is_valid_patch(patch, binary_classifier, preprocess, device):
    if patch.size[0] <= 0 or patch.size[1] <= 0:
        return False

    patch_tensor = preprocess(patch).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = binary_classifier(patch_tensor)
        probabilities = torch.softmax(logits, dim=1)
        prob_label_1 = probabilities[0, 1]
    return prob_label_1.item() > 0.8

# Main image-processing function
def process_images(text_file_path, dataset_path, model, preprocess, binary_classifier, output_folder, device='cpu'):
    boxes_dict = {}

    with open(text_file_path, 'r') as f:
        for line in f:
            image_name, class_name = line.strip().split('\t')
            print(f"Processing image: {image_name}")
            text_prompt = class_name + ' .'
            object_prompt = "object ."
            image_path = os.path.join(dataset_path, image_name)
            img = Image.open(image_path).convert("RGB")
            image_source, image = load_image(image_path)
            h, w, _ = image_source.shape
            boxes_object, logits_object, _ = predict(model, image, object_prompt, BOX_THRESHOLD, TEXT_THRESHOLD)
            boxes_class, logits_class, _ = predict(model, image, text_prompt, BOX_THRESHOLD_class, TEXT_THRESHOLD_class)

            patches_object = box_convert(boxes_object, in_fmt="cxcywh", out_fmt="xyxy")
            patches_class = box_convert(boxes_class, in_fmt="cxcywh", out_fmt="xyxy")

            top_patches = []
            iou_matrix = np.zeros((len(boxes_object), len(boxes_class)))

            for j, box_class in enumerate(patches_class):
                box_object_class = box_class.cpu().numpy() * np.array([w, h, w, h], dtype=np.float32)
                x1_, y1_, x2_, y2_ = box_object_class.astype(int)
                x1_, y1_, x2_, y2_ = max(x1_, 0), max(y1_, 0), min(x2_, w), min(y2_, h)
                patch_ = img.crop((x1_, y1_, x2_, y2_))
                if x2_ - x1_ > w / 2 or y2_ - y1_ > h / 2 or not is_valid_patch(patch_, binary_classifier, preprocess, device):
                    print(f"Skipping patch at box {box_class}")
                    continue
                for i, box_object in enumerate(patches_object):
                    iou_matrix[i][j] = calculate_iou(box_object.cpu().numpy(), box_class.cpu().numpy())

            for i, box_object in enumerate(patches_object):
                max_iou = np.max(iou_matrix[i])
                if max_iou < 0.5:
                    box_object = box_object.cpu().numpy() * np.array([w, h, w, h], dtype=np.float32)
                    x1, y1, x2, y2 = box_object.astype(int)
                    x1, y1, x2, y2 = max(x1, 0), max(y1, 0), min(x2, w), min(y2, h)
                    patch = img.crop((x1, y1, x2, y2))
                    if patch.size == (0, 0) or not is_valid_patch(patch, binary_classifier, preprocess, device) or x2 - x1 > w / 2 or y2 - y1 > h / 2 or y2 - y1 < 5 or x2 - x1 < 5:
                        print(f"Skipping patch at box {box_object}")
                        continue
                    patch_logits = logits_object[i]
                    top_patches.append((i, patch_logits.item()))

            top_patches.sort(key=lambda x: x[1], reverse=True)
            top_3_indices = [patch[0] for patch in top_patches[:3]]

            while len(top_3_indices) < 3:
                if len(top_3_indices) > 0:
                    top_3_indices.append(top_3_indices[-1])
                else:
                    default_box = torch.tensor([0, 0, 20 / w, 20 / h]).unsqueeze(0)
                    patches_object = torch.cat((patches_object, default_box.to(boxes_object.device)), dim=0)
                    top_3_indices.append(len(patches_object) - 1)

            boxes_dict[image_name] = [patches_object[idx].cpu().numpy().tolist() * np.array([w, h, w, h], dtype=np.float32) for idx in top_3_indices]

    return boxes_dict

def main(args):
    # Fixed default paths
    model_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    model_weights = "GroundingDINO/weights/groundingdino_swint_ogc.pth"

    # Paths derived from root_path
    text_file_path = os.path.join(args.root_path, "ImageClasses_FSC147.txt")
    dataset_path = os.path.join(args.root_path, "images_384_VarV2")
    input_json_path = os.path.join(args.root_path, "annotation_FSC147_384.json")
    output_json_path = os.path.join(args.root_path, "annotation_FSC147_neg.json")
    output_folder = os.path.join(args.root_path, "annotated_images_n")

    os.makedirs(output_folder, exist_ok=True)

    # Load the GroundingDINO model
    model = load_model(model_config, model_weights, device=device)

    # Process the images and generate bounding boxes
    boxes_dict = process_images(text_file_path, dataset_path, model, preprocess, binary_classifier, output_folder, device=device)

    # Update the JSON file
    with open(input_json_path, 'r') as f:
        data = json.load(f)

    for image_name, boxes in boxes_dict.items():
        if image_name in data:
            new_boxes = [[[x1, y1], [x1, y2], [x2, y2], [x2, y1]] for x1, y1, x2, y2 in boxes]
            data[image_name]["box_examples_coordinates"] = new_boxes

    with open(output_json_path, 'w') as f:
        json.dump(data, f, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Processing Script")
    parser.add_argument("--root_path", type=str, required=True, help="Root path to the dataset and output files")
    args = parser.parse_args()
    main(args)
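A minimal sketch (not part of the uploaded files) that checks `calculate_iou` from the script above on hand-computed boxes; it assumes the function is available in the session and uses the (x, y, w, h) box layout its signature unpacks.

```python
# Minimal sketch (assumption, not repository code): quick check of calculate_iou
# with boxes given as (x, y, w, h).
box_a = (0, 0, 10, 10)    # 10x10 box at the origin
box_b = (5, 5, 10, 10)    # same size, shifted by 5 in x and y

print(calculate_iou(box_a, box_b))  # intersection 25 / union 175 ≈ 0.1429
print(calculate_iou(box_a, box_a))  # identical boxes -> 1.0
```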
grounding_pos.py
ADDED
@@ -0,0 +1,141 @@
import torch
import os
import clip
import inflect
import argparse
from torchvision.ops import box_convert
from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict
from PIL import Image
import numpy as np
import json
import torch.nn as nn
import torch.nn.functional as F

# Global settings
device = "cuda" if torch.cuda.is_available() else "cpu"
BOX_THRESHOLD = 0.05
TEXT_THRESHOLD = 0.05

# Initialize the inflect engine
p = inflect.engine()

# ClipClassifier definition
class ClipClassifier(nn.Module):
    def __init__(self, clip_model, embed_dim=512):
        super(ClipClassifier, self).__init__()
        self.clip_model = clip_model.to(device)
        for param in self.clip_model.parameters():
            param.requires_grad = False
        self.fc = nn.Linear(clip_model.visual.output_dim, embed_dim)
        self.classifier = nn.Linear(embed_dim, 2)  # binary classification

    def forward(self, images):
        with torch.no_grad():
            image_features = self.clip_model.encode_image(images).float().to(device)
        x = self.fc(image_features)
        x = F.relu(x)
        logits = self.classifier(x)
        return logits

# Load the CLIP model
clip_model, preprocess = clip.load("ViT-B/32", device)
clip_model.eval()

# Initialize and load the binary classifier
binary_classifier = ClipClassifier(clip_model).to(device)
model_weights_path = './data/out/classify/best_model.pth'
binary_classifier.load_state_dict(torch.load(model_weights_path, map_location=device))
binary_classifier.eval()

# Check whether a patch is valid
def is_valid_patch(patch, binary_classifier, preprocess, device):
    if patch.size[0] <= 0 or patch.size[1] <= 0:
        return False
    patch_tensor = preprocess(patch).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = binary_classifier(patch_tensor)
        probabilities = torch.softmax(logits, dim=1)
        prob_label_1 = probabilities[0, 1]
    return prob_label_1.item() > 0.8

# Main image-processing function
def process_images(text_file_path, dataset_path, model, preprocess, clip_model, output_folder, device='cpu'):
    boxes_dict = {}
    with open(text_file_path, 'r') as f:
        for line in f:
            image_name, class_name = line.strip().split('\t')
            print(f"Processing image: {image_name}")
            text_prompt = class_name + ' .'
            image_path = os.path.join(dataset_path, image_name)
            img = Image.open(image_path).convert("RGB")
            image_source, image = load_image(image_path)
            h, w, _ = image_source.shape
            boxes, logits, _ = predict(model, image, text_prompt, BOX_THRESHOLD, TEXT_THRESHOLD)
            patches = box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")

            top_patches = []
            for i, (box, logit) in enumerate(zip(patches, logits)):
                box = box.cpu().numpy() * np.array([w, h, w, h], dtype=np.float32)
                x1, y1, x2, y2 = box.astype(int)
                x1, y1, x2, y2 = max(x1, 0), max(y1, 0), min(x2, w), min(y2, h)
                patch = img.crop((x1, y1, x2, y2))

                if patch.size == (0, 0) or not is_valid_patch(patch, binary_classifier, preprocess, device) or x2 - x1 > w / 2 or y2 - y1 > h / 2 or y2 - y1 < 5 or x2 - x1 < 5:
                    print(f"Skipping patch due to binary classifier at box {box}")
                    continue
                top_patches.append((i, logit))

            top_patches.sort(key=lambda x: x[1], reverse=True)
            top_3_indices = [patch[0] for patch in top_patches[:3]]

            # Make sure every image has three bounding boxes
            while len(top_3_indices) < 3:
                if len(top_3_indices) > 0:
                    top_3_indices.append(top_3_indices[-1])
                else:
                    default_box = torch.tensor([0, 0, 20 / w, 20 / h]).unsqueeze(0)
                    patches = torch.cat((patches, default_box.to(boxes.device)), dim=0)
                    top_3_indices.append(len(patches) - 1)

            boxes_dict[image_name] = [patches[idx].cpu().numpy().tolist() * np.array([w, h, w, h], dtype=np.float32) for idx in top_3_indices]

    return boxes_dict

# Main entry point
def main(args):
    # Fixed default paths
    model_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    model_weights = "GroundingDINO/weights/groundingdino_swint_ogc.pth"
    output_folder = os.path.join(args.root_path, "annotated_images")

    # Paths derived from root_path
    text_file_path = os.path.join(args.root_path, "ImageClasses_FSC147.txt")
    dataset_path = os.path.join(args.root_path, "images_384_VarV2")
    input_json_path = os.path.join(args.root_path, "annotation_FSC147_384_old.json")
    output_json_path = os.path.join(args.root_path, "annotation_FSC147_pos.json")

    os.makedirs(output_folder, exist_ok=True)

    # Load the GroundingDINO model
    model = load_model(model_config, model_weights, device=device)

    # Process the images and generate bounding boxes
    boxes_dict = process_images(text_file_path, dataset_path, model, preprocess, clip_model, output_folder, device=device)

    # Update the JSON file
    with open(input_json_path, 'r') as f:
        data = json.load(f)

    for image_name, boxes in boxes_dict.items():
        if image_name in data:
            new_boxes = [[[x1, y1], [x1, y2], [x2, y2], [x2, y1]] for x1, y1, x2, y2 in boxes]
            data[image_name]["box_examples_coordinates"] = new_boxes

    with open(output_json_path, 'w') as f:
        json.dump(data, f, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Processing Script")
    parser.add_argument("--root_path", type=str, required=True, help="Root path to the dataset and output files")
    args = parser.parse_args()
    main(args)
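A minimal sketch (not part of the uploaded files) of reading the JSON written by grounding_pos.py and turning each stored corner quadruple back into an (x1, y1, x2, y2) box; the path assumes the script above was run with --root_path ./data/FSC147/.

```python
# Minimal sketch (assumption, not repository code): inspect the generated exemplar boxes.
import json

with open('./data/FSC147/annotation_FSC147_pos.json') as f:
    anno = json.load(f)

image_name = next(iter(anno))  # any image key
for quad in anno[image_name]["box_examples_coordinates"]:
    # stored as [[x1, y1], [x1, y2], [x2, y2], [x2, y1]]
    (x1, y1), _, (x2, y2), _ = quad
    print(image_name, (x1, y1, x2, y2))
```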
models_crossvit.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.hub
|
5 |
+
from itertools import repeat
|
6 |
+
import collections.abc
|
7 |
+
|
8 |
+
|
9 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
|
10 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
11 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
12 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
13 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
14 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
15 |
+
'survival rate' as the argument.
|
16 |
+
"""
|
17 |
+
if drop_prob == 0. or not training:
|
18 |
+
return x
|
19 |
+
keep_prob = 1 - drop_prob
|
20 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
21 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
22 |
+
if keep_prob > 0.0 and scale_by_keep:
|
23 |
+
random_tensor.div_(keep_prob)
|
24 |
+
return x * random_tensor
|
25 |
+
|
26 |
+
class DropPath(nn.Module):
|
27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
28 |
+
"""
|
29 |
+
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
30 |
+
super(DropPath, self).__init__()
|
31 |
+
self.drop_prob = drop_prob
|
32 |
+
self.scale_by_keep = scale_by_keep
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
36 |
+
|
37 |
+
def _ntuple(n):
|
38 |
+
def parse(x):
|
39 |
+
if isinstance(x, collections.abc.Iterable):
|
40 |
+
return x
|
41 |
+
return tuple(repeat(x, n))
|
42 |
+
return parse
|
43 |
+
|
44 |
+
to_2tuple = _ntuple(2)
|
45 |
+
|
46 |
+
class Mlp(nn.Module):
|
47 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
48 |
+
"""
|
49 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
50 |
+
super().__init__()
|
51 |
+
out_features = out_features or in_features
|
52 |
+
hidden_features = hidden_features or in_features
|
53 |
+
drop_probs = to_2tuple(drop)
|
54 |
+
|
55 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
56 |
+
self.act = act_layer()
|
57 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
58 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
59 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
x = self.fc1(x)
|
63 |
+
x = self.act(x)
|
64 |
+
x = self.drop1(x)
|
65 |
+
x = self.fc2(x)
|
66 |
+
x = self.drop2(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
class Attention(nn.Module):
|
70 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
71 |
+
super().__init__()
|
72 |
+
self.num_heads = num_heads
|
73 |
+
head_dim = dim // num_heads
|
74 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
75 |
+
self.scale = qk_scale or head_dim ** -0.5
|
76 |
+
|
77 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
78 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
79 |
+
self.proj = nn.Linear(dim, dim)
|
80 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
B, N, C = x.shape
|
84 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
85 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
86 |
+
|
87 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
88 |
+
attn = attn.softmax(dim=-1)
|
89 |
+
attn = self.attn_drop(attn)
|
90 |
+
|
91 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
92 |
+
x = self.proj(x)
|
93 |
+
x = self.proj_drop(x)
|
94 |
+
return x
|
95 |
+
|
96 |
+
class CrossAttention(nn.Module):
|
97 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
98 |
+
super().__init__()
|
99 |
+
self.num_heads = num_heads
|
100 |
+
head_dim = dim // num_heads
|
101 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
102 |
+
self.scale = qk_scale or head_dim ** -0.5
|
103 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
104 |
+
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
|
105 |
+
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
|
106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
107 |
+
self.proj = nn.Linear(dim, dim)
|
108 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
109 |
+
|
110 |
+
def forward(self, x, y):
|
111 |
+
B, Nx, C = x.shape
|
112 |
+
Ny = y.shape[1]
|
113 |
+
# BNxC -> BNxH(C/H) -> BHNx(C/H)
|
114 |
+
q = self.wq(x).reshape(B, Nx, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
115 |
+
# BNyC -> BNyH(C/H) -> BHNy(C/H)
|
116 |
+
k = self.wk(y).reshape(B, Ny, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
117 |
+
# BNyC -> BNyH(C/H) -> BHNy(C/H)
|
118 |
+
v = self.wv(y).reshape(B, Ny, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
119 |
+
|
120 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # BHNx(C/H) @ BH(C/H)Ny -> BHNxNy
|
121 |
+
attn = attn.softmax(dim=-1)
|
122 |
+
attn = self.attn_drop(attn)
|
123 |
+
|
124 |
+
x = (attn @ v).transpose(1, 2).reshape(B, Nx, C) # (BHNxNy @ BHNy(C/H)) -> BHNx(C/H) -> BNxH(C/H) -> BNxC
|
125 |
+
x = self.proj(x)
|
126 |
+
x = self.proj_drop(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
class CrossAttentionBlock(nn.Module):
|
130 |
+
|
131 |
+
def __init__(
|
132 |
+
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
133 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
self.norm0 = norm_layer(dim)
|
137 |
+
self.selfattn = Attention(
|
138 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
139 |
+
self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
140 |
+
|
141 |
+
self.norm1 = norm_layer(dim)
|
142 |
+
self.attn = CrossAttention(
|
143 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
144 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
145 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
146 |
+
|
147 |
+
self.norm2 = norm_layer(dim)
|
148 |
+
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
149 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
150 |
+
|
151 |
+
def forward(self, x, y):
|
152 |
+
x = x + self.drop_path0(self.selfattn(self.norm0(x)))
|
153 |
+
x = x + self.drop_path1(self.attn(self.norm1(x), y))
|
154 |
+
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
155 |
+
return x
|
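A minimal sketch (not part of the uploaded files) that runs a `CrossAttentionBlock` from models_crossvit.py on random tensors; the dimensions mirror the 512-dim, 16-head decoder configuration used elsewhere in the repository, and the token counts are illustrative.

```python
# Minimal sketch (assumption, not repository code): exercise CrossAttentionBlock
# with random query tokens x and exemplar tokens y.
import torch
from models_crossvit import CrossAttentionBlock

block = CrossAttentionBlock(dim=512, num_heads=16, mlp_ratio=4., qkv_bias=True)
x = torch.randn(2, 576, 512)   # image tokens (e.g. 24x24 patches), dim 512
y = torch.randn(2, 3, 512)     # three exemplar tokens per image
out = block(x, y)
print(out.shape)               # torch.Size([2, 576, 512])
```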
models_mae_cross.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from functools import partial
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torchvision.utils
|
12 |
+
|
13 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
14 |
+
from models_crossvit import CrossAttentionBlock
|
15 |
+
|
16 |
+
from util.pos_embed import get_2d_sincos_pos_embed
|
17 |
+
|
18 |
+
class SupervisedMAE(nn.Module):
|
19 |
+
def __init__(self, img_size=384, patch_size=16, in_chans=3,
|
20 |
+
embed_dim=1024, depth=24, num_heads=16,
|
21 |
+
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
|
22 |
+
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
# --------------------------------------------------------------------------
|
26 |
+
# MAE encoder specifics
|
27 |
+
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
|
28 |
+
num_patches = self.patch_embed.num_patches
|
29 |
+
|
30 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) # fixed sin-cos embedding
|
31 |
+
|
32 |
+
self.blocks = nn.ModuleList([
|
33 |
+
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
|
34 |
+
for i in range(depth)])
|
35 |
+
self.norm = norm_layer(embed_dim)
|
36 |
+
# --------------------------------------------------------------------------
|
37 |
+
|
38 |
+
# --------------------------------------------------------------------------
|
39 |
+
# MAE decoder specifics
|
40 |
+
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
|
41 |
+
|
42 |
+
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
|
43 |
+
|
44 |
+
self.shot_token = nn.Parameter(torch.zeros(512))
|
45 |
+
|
46 |
+
# Exemplar encoder with CNN
|
47 |
+
self.decoder_proj1 = nn.Sequential(
|
48 |
+
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
|
49 |
+
nn.InstanceNorm2d(64),
|
50 |
+
nn.ReLU(inplace=True),
|
51 |
+
nn.MaxPool2d(2) #[3,64,64]->[64,32,32]
|
52 |
+
)
|
53 |
+
self.decoder_proj2 = nn.Sequential(
|
54 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
|
55 |
+
nn.InstanceNorm2d(128),
|
56 |
+
nn.ReLU(inplace=True),
|
57 |
+
nn.MaxPool2d(2) #[64,32,32]->[128,16,16]
|
58 |
+
)
|
59 |
+
self.decoder_proj3 = nn.Sequential(
|
60 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
61 |
+
nn.InstanceNorm2d(256),
|
62 |
+
nn.ReLU(inplace=True),
|
63 |
+
nn.MaxPool2d(2) # [128,16,16]->[256,8,8]
|
64 |
+
)
|
65 |
+
self.decoder_proj4 = nn.Sequential(
|
66 |
+
nn.Conv2d(256, decoder_embed_dim, kernel_size=3, stride=1, padding=1),
|
67 |
+
nn.InstanceNorm2d(512),
|
68 |
+
nn.ReLU(inplace=True),
|
69 |
+
nn.AdaptiveAvgPool2d((1,1))
|
70 |
+
# [256,8,8]->[512,1,1]
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
self.decoder_blocks = nn.ModuleList([
|
75 |
+
CrossAttentionBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
|
76 |
+
for i in range(decoder_depth)])
|
77 |
+
|
78 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
79 |
+
# Density map regresssion module
|
80 |
+
self.decode_head0 = nn.Sequential(
|
81 |
+
nn.Conv2d(decoder_embed_dim, 256, kernel_size=3, stride=1, padding=1),
|
82 |
+
nn.GroupNorm(8, 256),
|
83 |
+
nn.ReLU(inplace=True)
|
84 |
+
)
|
85 |
+
self.decode_head1 = nn.Sequential(
|
86 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
87 |
+
nn.GroupNorm(8, 256),
|
88 |
+
nn.ReLU(inplace=True)
|
89 |
+
)
|
90 |
+
self.decode_head2 = nn.Sequential(
|
91 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
92 |
+
nn.GroupNorm(8, 256),
|
93 |
+
nn.ReLU(inplace=True)
|
94 |
+
)
|
95 |
+
self.decode_head3 = nn.Sequential(
|
96 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
97 |
+
nn.GroupNorm(8, 256),
|
98 |
+
nn.ReLU(inplace=True),
|
99 |
+
nn.Conv2d(256, 1, kernel_size=1, stride=1)
|
100 |
+
)
|
101 |
+
|
102 |
+
# --------------------------------------------------------------------------
|
103 |
+
|
104 |
+
self.norm_pix_loss = norm_pix_loss
|
105 |
+
|
106 |
+
self.initialize_weights()
|
107 |
+
|
108 |
+
def initialize_weights(self):
|
109 |
+
# initialization
|
110 |
+
# initialize (and freeze) pos_embed by sin-cos embedding
|
111 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
|
112 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
113 |
+
|
114 |
+
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
|
115 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
116 |
+
|
117 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
118 |
+
w = self.patch_embed.proj.weight.data
|
119 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
120 |
+
|
121 |
+
torch.nn.init.normal_(self.shot_token, std=.02)
|
122 |
+
|
123 |
+
# initialize nn.Linear and nn.LayerNorm
|
124 |
+
self.apply(self._init_weights)
|
125 |
+
|
126 |
+
def _init_weights(self, m):
|
127 |
+
if isinstance(m, nn.Linear):
|
128 |
+
# we use xavier_uniform following official JAX ViT:
|
129 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
130 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
131 |
+
nn.init.constant_(m.bias, 0)
|
132 |
+
elif isinstance(m, nn.LayerNorm):
|
133 |
+
nn.init.constant_(m.bias, 0)
|
134 |
+
nn.init.constant_(m.weight, 1.0)
|
135 |
+
|
136 |
+
def forward_encoder(self, x):
|
137 |
+
# embed patches
|
138 |
+
x = self.patch_embed(x)
|
139 |
+
|
140 |
+
# add pos embed w/o cls token
|
141 |
+
x = x + self.pos_embed
|
142 |
+
|
143 |
+
# apply Transformer blocks
|
144 |
+
for blk in self.blocks:
|
145 |
+
x = blk(x)
|
146 |
+
x = self.norm(x)
|
147 |
+
|
148 |
+
return x
|
149 |
+
|
150 |
+
def forward_decoder(self, x, y_, shot_num=3):
|
151 |
+
# embed tokens
|
152 |
+
x = self.decoder_embed(x)
|
153 |
+
# add pos embed
|
154 |
+
x = x + self.decoder_pos_embed
|
155 |
+
|
156 |
+
# Exemplar encoder
|
157 |
+
y_ = y_.transpose(0,1) # y_ [N,3,3,64,64]->[3,N,3,64,64]
|
158 |
+
y1=[]
|
159 |
+
C=0
|
160 |
+
N=0
|
161 |
+
cnt = 0
|
162 |
+
for yi in y_:
|
163 |
+
cnt+=1
|
164 |
+
if cnt > shot_num:
|
165 |
+
break
|
166 |
+
yi = self.decoder_proj1(yi)
|
167 |
+
yi = self.decoder_proj2(yi)
|
168 |
+
yi = self.decoder_proj3(yi)
|
169 |
+
yi = self.decoder_proj4(yi)
|
170 |
+
N, C,_,_ = yi.shape
|
171 |
+
y1.append(yi.squeeze(-1).squeeze(-1)) # yi [N,C,1,1]->[N,C]
|
172 |
+
|
173 |
+
if shot_num > 0:
|
174 |
+
y = torch.cat(y1,dim=0).reshape(shot_num,N,C).to(x.device)
|
175 |
+
else:
|
176 |
+
y = self.shot_token.repeat(y_.shape[1],1).unsqueeze(0).to(x.device)
|
177 |
+
y = y.transpose(0,1) # y [3,N,C]->[N,3,C]
|
178 |
+
|
179 |
+
# apply Transformer blocks
|
180 |
+
for blk in self.decoder_blocks:
|
181 |
+
x = blk(x, y)
|
182 |
+
x = self.decoder_norm(x)
|
183 |
+
|
184 |
+
# Density map regression
|
185 |
+
n, hw, c = x.shape
|
186 |
+
h = w = int(math.sqrt(hw))
|
187 |
+
x = x.transpose(1, 2).reshape(n, c, h, w)
|
188 |
+
|
189 |
+
x = F.interpolate(
|
190 |
+
self.decode_head0(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
|
191 |
+
x = F.interpolate(
|
192 |
+
self.decode_head1(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
|
193 |
+
x = F.interpolate(
|
194 |
+
self.decode_head2(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
|
195 |
+
x = F.interpolate(
|
196 |
+
self.decode_head3(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
|
197 |
+
x = x.squeeze(-3)
|
198 |
+
|
199 |
+
return x
|
200 |
+
|
201 |
+
def forward(self, imgs, boxes, shot_num):
|
202 |
+
# if boxes.nelement() > 0:
|
203 |
+
# torchvision.utils.save_image(boxes[0], f"data/out/crops/box_{time.time()}_{random.randint(0, 99999):>5}.png")
|
204 |
+
with torch.no_grad():
|
205 |
+
latent = self.forward_encoder(imgs)
|
206 |
+
pred = self.forward_decoder(latent, boxes, shot_num) # [N, 384, 384]
|
207 |
+
return pred
|
208 |
+
|
209 |
+
|
210 |
+
def mae_vit_base_patch16_dec512d8b(**kwargs):
|
211 |
+
model = SupervisedMAE(
|
212 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
213 |
+
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
|
214 |
+
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
215 |
+
return model
|
216 |
+
|
217 |
+
|
218 |
+
def mae_vit_large_patch16_dec512d8b(**kwargs):
|
219 |
+
model = SupervisedMAE(
|
220 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
|
221 |
+
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
|
222 |
+
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
223 |
+
return model
|
224 |
+
|
225 |
+
|
226 |
+
def mae_vit_huge_patch14_dec512d8b(**kwargs):
|
227 |
+
model = SupervisedMAE(
|
228 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16,
|
229 |
+
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
|
230 |
+
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
231 |
+
return model
|
232 |
+
|
233 |
+
def mae_vit_base_patch16_fim4(**kwargs):
|
234 |
+
model = SupervisedMAE(
|
235 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
236 |
+
decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16,
|
237 |
+
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
238 |
+
return model
|
239 |
+
|
240 |
+
def mae_vit_base_patch16_fim6(**kwargs):
|
241 |
+
model = SupervisedMAE(
|
242 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
243 |
+
decoder_embed_dim=512, decoder_depth=6, decoder_num_heads=16,
|
244 |
+
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
245 |
+
return model
|
246 |
+
|
247 |
+
|
248 |
+
# set recommended archs
|
249 |
+
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
|
250 |
+
mae_vit_base4_patch16 = mae_vit_base_patch16_fim4 # decoder: 4 blocks
|
251 |
+
mae_vit_base6_patch16 = mae_vit_base_patch16_fim6 # decoder: 6 blocks
|
252 |
+
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
|
253 |
+
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b
|
models_mae_noct.py
ADDED
@@ -0,0 +1,234 @@
from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViTNoCT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=384, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = x_  # no cls token to re-append in this variant

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        # For mean loss on all patches
        N, L = mask.shape
        mask_s = torch.ones([N, L], device=imgs.device)
        loss = (loss * mask_s).sum() / mask_s.sum()

        # loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViTNoCT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViTNoCT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViTNoCT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
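For orientation, a short sketch (not part of the commit) of how `MaskedAutoencoderViTNoCT` is typically driven during pre-training; the batch is random noise purely to show shapes at the 384x384 default.

```python
import torch
import models_mae_noct

model = models_mae_noct.mae_vit_base_patch16(norm_pix_loss=False)
imgs = torch.randn(2, 3, 384, 384)              # dummy batch
loss, pred, mask = model(imgs, mask_ratio=0.5)  # pred: [2, 576, 16*16*3], mask: [2, 576]
loss.backward()
```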
requirements.txt
ADDED
@@ -0,0 +1,15 @@
--extra-index-url https://download.pytorch.org/whl/cu116

torch==1.13.1+cu116
torchvision==0.14.1+cu116
timm==0.4.9
numpy==1.23.4
scipy==1.10.1
imgaug==0.4.0
pillow==9.3.0
matplotlib==3.6.3
hub==3.0.1
pandas==1.5.2
six==1.16.0
wandb
tqdm
util/FSC147.py
ADDED
@@ -0,0 +1,524 @@
from argparse import Namespace
import json
from pathlib import Path

import numpy as np
import random
from torchvision import transforms
import torch
import cv2
import torchvision.transforms.functional as TF
import scipy.ndimage as ndimage
from PIL import Image
import argparse
import imgaug.augmenters as iaa
from imgaug.augmentables import Keypoint, KeypointsOnImage

MAX_HW = 384
IM_NORM_MEAN = [0.485, 0.456, 0.406]
IM_NORM_STD = [0.229, 0.224, 0.225]


def get_args_parser():
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    parser.add_argument('--batch_size', default=8, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--mask_ratio', default=0.5, type=float,
                        help='Masking ratio (percentage of removed patches).')
    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/FSC147/', type=str,
                        help='dataset path')
    parser.add_argument('--anno_file', default='annotation_FSC147_384.json', type=str,
                        help='annotation json file')
    parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
                        help='data split json file')
    parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
                        help='images directory')
    parser.add_argument('--gt_dir', default='./data/FSC147/gt_density_map_adaptive_384_VarV2', type=str,
                        help='ground truth directory')
    parser.add_argument('--output_dir', default='./data/out/pre_4_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='./weights/mae_pretrain_vit_base_full.pth',  # mae_visualize_vit_base
                        help='resume from checkpoint')

    # Training parameters
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # Logging parameters
    parser.add_argument('--log_dir', default='./logs/pre_4_dir',
                        help='path where to tensorboard log')
    parser.add_argument("--title", default="CounTR_pretraining", type=str)
    parser.add_argument("--wandb", default="counting", type=str)
    parser.add_argument("--team", default="wsense", type=str)
    parser.add_argument("--wandb_id", default=None, type=str)
    parser.add_argument("--do_aug", default=True, type=bool)
    parser.add_argument('--class_file', default='./data/FSC147/ImageClasses_FSC147.txt', type=str,
                        help='class file')
    return parser


args = get_args_parser()
args = args.parse_args()


class ResizeSomeImage(object):
    def __init__(self, args):
        args = get_args_parser()
        args = args.parse_args()
        # print(dir(args.im_dir.as_posix()))
        self.data_path = Path(args.data_path)
        self.im_dir = self.data_path/args.im_dir
        anno_file = self.data_path/args.anno_file
        data_split_file = self.data_path/args.data_split_file

        with open(anno_file) as f:
            self.annotations = json.load(f)

        with open(data_split_file) as f:
            data_split = json.load(f)

        self.train_set = data_split['train']

        self.class_dict = {}
        if args.do_aug:
            with open(args.class_file) as f:
                for line in f:
                    key = line.split()[0]
                    val = line.split()[1:]
                    self.class_dict[key] = val


class ResizePreTrainImage(ResizeSomeImage):
    """
    Resize the image so that:
    1. The image is 384 * 384
    2. The new height and new width are divisible by 16
    3. The aspect ratio is preserved
    Density and box correctness are not preserved (crop and horizontal flip).
    """

    def __init__(self, args, MAX_HW=384):
        super().__init__(args)
        self.max_hw = MAX_HW

    def __call__(self, sample):
        image, lines_boxes, density = sample['image'], sample['lines_boxes'], sample['gt_density']

        W, H = image.size

        new_H = 16 * int(H / 16)
        new_W = 16 * int(W / 16)
        resized_image = transforms.Resize((new_H, new_W))(image)
        resized_density = cv2.resize(density, (new_W, new_H))
        orig_count = np.sum(density)
        new_count = np.sum(resized_density)

        if new_count > 0:
            resized_density = resized_density * (orig_count / new_count)

        boxes = list()
        for box in lines_boxes:
            box2 = [int(k) for k in box]
            y1, x1, y2, x2 = box2[0], box2[1], box2[2], box2[3]
            boxes.append([0, y1, x1, y2, x2])

        boxes = torch.Tensor(boxes).unsqueeze(0)
        resized_image = PreTrainNormalize(resized_image)
        resized_density = torch.from_numpy(resized_density).unsqueeze(0).unsqueeze(0)
        sample = {'image': resized_image, 'boxes': boxes, 'gt_density': resized_density}
        return sample


class ResizeTrainImage(ResizeSomeImage):
    """
    Resize the image so that:
    1. The image is 384 * 384
    2. The new height and new width are divisible by 16
    3. The aspect ratio is possibly preserved
    The density map is cropped to have the same size (and position) as the cropped image.
    Exemplar boxes may be outside the cropped area.
    Augmentation: Gaussian noise, color jitter, Gaussian blur, random affine, random horizontal flip and mosaic (or random crop if no mosaic).
    """

    def __init__(self, args, MAX_HW=384, do_aug=True):
        super().__init__(args)
        self.max_hw = MAX_HW
        self.do_aug = do_aug

    def __call__(self, sample):
        image, lines_boxes, neg_lines_boxes, dots, im_id, m_flag = sample['image'], sample['lines_boxes'], sample['neg_lines_boxes'], \
            sample['dots'], sample['id'], sample['m_flag']

        W, H = image.size

        new_H = 16 * int(H / 16)
        new_W = 16 * int(W / 16)
        scale_factor_h = float(new_H) / H
        scale_factor_w = float(new_W) / W
        resized_image = transforms.Resize((new_H, new_W))(image)
        resized_image = TTensor(resized_image)
        resized_density = np.zeros((new_H, new_W), dtype='float32')

        # Augmentation probability
        aug_flag = self.do_aug
        mosaic_flag = random.random() < 0.25

        if aug_flag:
            # Gaussian noise
            noise = np.random.normal(0, 0.1, resized_image.size())
            noise = torch.from_numpy(noise)
            re_image = resized_image + noise
            re_image = torch.clamp(re_image, 0, 1)

            # Color jitter and Gaussian blur
            re_image = Augmentation(re_image)

            # Random affine
            re1_image = re_image.transpose(0, 1).transpose(1, 2).numpy()
            keypoints = []
            for i in range(dots.shape[0]):
                keypoints.append(Keypoint(x=min(new_W - 1, int(dots[i][0] * scale_factor_w)), y=min(new_H - 1, int(dots[i][1] * scale_factor_h))))
            kps = KeypointsOnImage(keypoints, re1_image.shape)

            seq = iaa.Sequential([
                iaa.Affine(
                    rotate=(-15, 15),
                    scale=(0.8, 1.2),
                    shear=(-10, 10),
                    translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}
                )
            ])
            re1_image, kps_aug = seq(image=re1_image, keypoints=kps)

            # Produce dot annotation map
            resized_density = np.zeros((resized_density.shape[0], resized_density.shape[1]), dtype='float32')
            for i in range(len(kps.keypoints)):
                if (int(kps_aug.keypoints[i].y) <= new_H - 1 and int(kps_aug.keypoints[i].x) <= new_W - 1) and not \
                        kps_aug.keypoints[i].is_out_of_image(re1_image):
                    resized_density[int(kps_aug.keypoints[i].y)][int(kps_aug.keypoints[i].x)] = 1
            resized_density = torch.from_numpy(resized_density)

            re_image = TTensor(re1_image)

            # Random horizontal flip
            flip_p = random.random()
            if flip_p > 0.5:
                re_image = TF.hflip(re_image)
                resized_density = TF.hflip(resized_density)

            # Random self mosaic
            if mosaic_flag:
                image_array = []
                map_array = []
                blending_l = random.randint(10, 20)
                resize_l = 192 + 2 * blending_l
                if dots.shape[0] >= 70:
                    for i in range(4):
                        length = random.randint(150, 384)
                        start_W = random.randint(0, new_W - length)
                        start_H = random.randint(0, new_H - length)
                        reresized_image1 = TF.crop(resized_image, start_H, start_W, length, length)
                        reresized_image1 = transforms.Resize((resize_l, resize_l))(reresized_image1)
                        reresized_density1 = np.zeros((resize_l, resize_l), dtype='float32')
                        for i in range(dots.shape[0]):
                            if start_H <= min(new_H - 1, int(dots[i][1] * scale_factor_h)) < start_H + length and start_W <= min(new_W - 1, int(dots[i][0] * scale_factor_w)) < start_W + length:
                                reresized_density1[min(resize_l-1, int((min(new_H-1, int(dots[i][1] * scale_factor_h))-start_H)*resize_l/length))][min(resize_l-1, int((min(new_W-1, int(dots[i][0] * scale_factor_w))-start_W)*resize_l/length))] = 1
                        reresized_density1 = torch.from_numpy(reresized_density1)
                        image_array.append(reresized_image1)
                        map_array.append(reresized_density1)
                else:
                    m_flag = 1
                    prob = random.random()
                    if prob > 0.25:
                        gt_pos = random.randint(0, 3)
                    else:
                        gt_pos = random.randint(0, 4)  # 5% 0 objects
                    for i in range(4):
                        if i == gt_pos:
                            Tim_id = im_id
                            r_image = resized_image
                            Tdots = dots
                            new_TH = new_H
                            new_TW = new_W
                            Tscale_factor_w = scale_factor_w
                            Tscale_factor_h = scale_factor_h
                        else:
                            Tim_id = self.train_set[random.randint(0, len(self.train_set) - 1)]
                            Tdots = np.array(self.annotations[Tim_id]['points'])
                            Timage = Image.open('{}/{}'.format(self.im_dir, Tim_id))
                            Timage.load()
                            new_TW = 16 * int(Timage.size[0] / 16)
                            new_TH = 16 * int(Timage.size[1] / 16)
                            Tscale_factor_w = float(new_TW) / Timage.size[0]
                            Tscale_factor_h = float(new_TH) / Timage.size[1]
                            r_image = TTensor(transforms.Resize((new_TH, new_TW))(Timage))

                        length = random.randint(250, 384)
                        start_W = random.randint(0, new_TW - length)
                        start_H = random.randint(0, new_TH - length)
                        r_image1 = TF.crop(r_image, start_H, start_W, length, length)
                        r_image1 = transforms.Resize((resize_l, resize_l))(r_image1)
                        r_density1 = np.zeros((resize_l, resize_l), dtype='float32')
                        # try:
                        #     class_value = self.class_dict[im_id]
                        #     Tim_value = self.class_dict[Tim_id]
                        # except KeyError:
                        #     # Handle the case when the key doesn't exist
                        #     class_value = None  # Or any appropriate default value
                        #     Tim_value = None  # Or any appropriate default value
                        if self.class_dict[im_id] == self.class_dict[Tim_id]:
                            # if class_value == Tim_value:
                            # if im_id in self.class_dict and Tim_id in self.class_dict:
                            #     class_value = self.class_dict[im_id]
                            #     Tim_value = self.class_dict[Tim_id]
                            #     # Proceed with your comparison and processing here
                            #     if class_value == Tim_value:
                            for i in range(Tdots.shape[0]):
                                if start_H <= min(new_TH - 1, int(Tdots[i][1] * Tscale_factor_h)) < start_H + length and start_W <= min(new_TW - 1, int(Tdots[i][0] * Tscale_factor_w)) < start_W + length:
                                    r_density1[min(resize_l-1, int((min(new_TH-1, int(Tdots[i][1] * Tscale_factor_h))-start_H)*resize_l/length))][min(resize_l-1, int((min(new_TW-1, int(Tdots[i][0] * Tscale_factor_w))-start_W)*resize_l/length))] = 1
                        r_density1 = torch.from_numpy(r_density1)
                        image_array.append(r_image1)
                        map_array.append(r_density1)

                reresized_image5 = torch.cat((image_array[0][:, blending_l:resize_l-blending_l], image_array[1][:, blending_l:resize_l-blending_l]), 1)
                reresized_density5 = torch.cat((map_array[0][blending_l:resize_l-blending_l], map_array[1][blending_l:resize_l-blending_l]), 0)
                for i in range(blending_l):
                    reresized_image5[:, 192+i] = image_array[0][:, resize_l-1-blending_l+i] * (blending_l-i)/(2*blending_l) + reresized_image5[:, 192+i] * (i+blending_l)/(2*blending_l)
                    reresized_image5[:, 191-i] = image_array[1][:, blending_l-i] * (blending_l-i)/(2*blending_l) + reresized_image5[:, 191-i] * (i+blending_l)/(2*blending_l)
                reresized_image5 = torch.clamp(reresized_image5, 0, 1)

                reresized_image6 = torch.cat((image_array[2][:, blending_l:resize_l-blending_l], image_array[3][:, blending_l:resize_l-blending_l]), 1)
                reresized_density6 = torch.cat((map_array[2][blending_l:resize_l-blending_l], map_array[3][blending_l:resize_l-blending_l]), 0)
                for i in range(blending_l):
                    reresized_image6[:, 192+i] = image_array[2][:, resize_l-1-blending_l+i] * (blending_l-i)/(2*blending_l) + reresized_image6[:, 192+i] * (i+blending_l)/(2*blending_l)
                    reresized_image6[:, 191-i] = image_array[3][:, blending_l-i] * (blending_l-i)/(2*blending_l) + reresized_image6[:, 191-i] * (i+blending_l)/(2*blending_l)
                reresized_image6 = torch.clamp(reresized_image6, 0, 1)

                reresized_image = torch.cat((reresized_image5[:, :, blending_l:resize_l-blending_l], reresized_image6[:, :, blending_l:resize_l-blending_l]), 2)
                reresized_density = torch.cat((reresized_density5[:, blending_l:resize_l-blending_l], reresized_density6[:, blending_l:resize_l-blending_l]), 1)
                for i in range(blending_l):
                    reresized_image[:, :, 192+i] = reresized_image5[:, :, resize_l-1-blending_l+i] * (blending_l-i)/(2*blending_l) + reresized_image[:, :, 192+i] * (i+blending_l)/(2*blending_l)
                    reresized_image[:, :, 191-i] = reresized_image6[:, :, blending_l-i] * (blending_l-i)/(2*blending_l) + reresized_image[:, :, 191-i] * (i+blending_l)/(2*blending_l)
                reresized_image = torch.clamp(reresized_image, 0, 1)

            else:
                # Random 384*384 crop in a new_W*384 image and 384*new_W density map
                start = random.randint(0, new_W - 1 - 383)
                reresized_image = TF.crop(re_image, 0, start, 384, 384)
                reresized_density = resized_density[:, start:start + 384]

        else:
            # Random 384*384 crop in a new_W*384 image and 384*new_W density map
            for i in range(dots.shape[0]):
                resized_density[min(new_H - 1, int(dots[i][1] * scale_factor_h))] \
                    [min(new_W - 1, int(dots[i][0] * scale_factor_w))] = 1
            resized_density = torch.from_numpy(resized_density)
            start = random.randint(0, new_W - self.max_hw)
            reresized_image = TF.crop(resized_image, 0, start, self.max_hw, self.max_hw)
            reresized_density = resized_density[0:self.max_hw, start:start + self.max_hw]

        # Gaussian distribution density map
        reresized_density = ndimage.gaussian_filter(reresized_density.numpy(), sigma=(1, 1), order=0)

        # Density map scale up
        reresized_density = reresized_density * 60
        reresized_density = torch.from_numpy(reresized_density)

        # Crop bboxes and resize as 64x64
        boxes = list()
        rects = list()
        cnt = 0
        for box in lines_boxes:
            cnt += 1
            if cnt > 3:
                break
            box2 = [int(k) for k in box]
            y1 = int(box2[0] * scale_factor_h)
            x1 = int(box2[1] * scale_factor_w)
            y2 = int(box2[2] * scale_factor_h)
            x2 = int(box2[3] * scale_factor_w)
            # print(y1,x1,y2,x2)
            if not aug_flag:
                rects.append(torch.tensor([y1, max(0, x1-start), y2, min(self.max_hw, x2-start)]))
            bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
            bbox = transforms.Resize((64, 64))(bbox)
            boxes.append(bbox)
        boxes = torch.stack(boxes)
        neg_boxes = list()
        neg_rects = list()
        cnt = 0
        for box in neg_lines_boxes:
            cnt += 1
            if cnt > 3:
                break
            box2 = [int(k) for k in box]
            y1 = int(box2[0] * scale_factor_h)
            x1 = int(box2[1] * scale_factor_w)
            y2 = int(box2[2] * scale_factor_h)
            x2 = int(box2[3] * scale_factor_w)
            # print(y1,x1,y2,x2)
            if not aug_flag:
                neg_rects.append(torch.tensor([y1, max(0, x1-start), y2, min(self.max_hw, x2-start)]))
            neg_bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
            neg_bbox = transforms.Resize((64, 64))(neg_bbox)
            neg_boxes.append(neg_bbox)
        neg_boxes = torch.stack(neg_boxes)
        # if len(boxes) > 0:
        #     boxes = torch.stack(boxes)  # if boxes is non-empty, run torch.stack as usual
        #     boxes1 = boxes
        # else:
        #     boxes = boxes1
        #     pass
        #     # if boxes is empty, either skip this sample or provide a default bounding box,
        #     # e.g. a default box covering the whole image
        #     default_box = torch.tensor([[0, 0], [0, 0], 0, 0])  # example default box; exact values depend on the application
        #     boxes = default_box.unsqueeze(0)  # add a dimension to match what torch.stack expects
        #     # pass
        if aug_flag:
            pos = torch.tensor([])
        else:
            pos = torch.stack(rects)

        # boxes shape [3,3,64,64], image shape [3,384,384], density shape [384,384]
        sample = {'image': reresized_image, 'boxes': boxes, 'neg_boxes': neg_boxes, 'pos': pos, 'gt_density': reresized_density, 'm_flag': m_flag}

        return sample


class ResizeValImage(ResizeSomeImage):
    def __init__(self, args, MAX_HW=384):
        super().__init__(args)
        self.max_hw = MAX_HW

    def __call__(self, sample):
        image, dots, m_flag, lines_boxes, neg_lines_boxes = sample['image'], sample['dots'], sample['m_flag'], sample['lines_boxes'], sample['neg_lines_boxes']

        W, H = image.size

        new_H = new_W = self.max_hw
        scale_factor_h = float(new_H) / H
        scale_factor_w = float(new_W) / W
        resized_image = transforms.Resize((new_H, new_W))(image)
        resized_image = TTensor(resized_image)

        # Resize density map
        resized_density = np.zeros((new_H, new_W), dtype='float32')
        for i in range(dots.shape[0]):
            resized_density[min(new_H - 1, int(dots[i][1] * scale_factor_h))] \
                [min(new_W - 1, int(dots[i][0] * scale_factor_w))] = 1
        # resized_density = ndimage.gaussian_filter(resized_density, sigma=4, radius=7, order=0)
        resized_density = ndimage.gaussian_filter(resized_density, sigma=4, order=0)
        resized_density = torch.from_numpy(resized_density) * 60

        # Crop bboxes and resize as 64x64
        boxes = list()
        rects = list()
        cnt = 0
        for box in lines_boxes:
            cnt += 1
            if cnt > 3:
                break
            box2 = [int(k) for k in box]
            y1 = int(box2[0] * scale_factor_h)
            x1 = int(box2[1] * scale_factor_w)
            y2 = int(box2[2] * scale_factor_h)
            x2 = int(box2[3] * scale_factor_w)
            rects.append(torch.tensor([y1, x1, y2, x2]))
            bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
            bbox = transforms.Resize((64, 64))(bbox)
            boxes.append(bbox)
        boxes = torch.stack(boxes)
        pos = torch.stack(rects)
        neg_boxes = list()
        neg_rects = list()
        cnt = 0
        for box in neg_lines_boxes:
            cnt += 1
            if cnt > 3:
                break
            box2 = [int(k) for k in box]
            y1 = int(box2[0] * scale_factor_h)
            x1 = int(box2[1] * scale_factor_w)
            y2 = int(box2[2] * scale_factor_h)
            x2 = int(box2[3] * scale_factor_w)
            neg_rects.append(torch.tensor([y1, x1, y2, x2]))
            neg_bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
            neg_bbox = transforms.Resize((64, 64))(neg_bbox)
            neg_boxes.append(neg_bbox)
        neg_boxes = torch.stack(neg_boxes)
        # boxes shape [3,3,64,64], image shape [3,384,384], density shape [384,384]
        sample = {'image': resized_image, 'boxes': boxes, 'neg_boxes': neg_boxes, 'pos': pos, 'gt_density': resized_density, 'm_flag': m_flag}
        return sample


PreTrainNormalize = transforms.Compose([
    transforms.RandomResizedCrop(MAX_HW, scale=(0.2, 1.0), interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=IM_NORM_MEAN, std=IM_NORM_STD)
])

TTensor = transforms.Compose([
    transforms.ToTensor(),
])

Augmentation = transforms.Compose([
    transforms.ColorJitter(brightness=0.25, contrast=0.15, saturation=0.15, hue=0.15),
    transforms.GaussianBlur(kernel_size=(7, 9))
])

Normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IM_NORM_MEAN, std=IM_NORM_STD)
])


def transform_train(args: Namespace, do_aug=True):
    return transforms.Compose([ResizeTrainImage(args, MAX_HW, do_aug)])


def transform_val(args: Namespace):
    return transforms.Compose([ResizeValImage(args, MAX_HW)])


def transform_pre_train(args: Namespace):
    return transforms.Compose([ResizePreTrainImage(args, MAX_HW)])
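A hedged sketch of wiring these transforms into a training script; it assumes the FSC147 annotation, split and class files exist under the default `--data_path`, since `ResizeSomeImage.__init__` reads them eagerly, and it passes no CLI flags.

```python
from util.FSC147 import get_args_parser, transform_train, transform_val

args = get_args_parser().parse_args([])        # defaults only; real scripts pass CLI flags
train_tf = transform_train(args, do_aug=True)  # noise / jitter / affine / flip / mosaic pipeline
val_tf = transform_val(args)                   # deterministic 384x384 resize + Gaussian density map
```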
util/__pycache__/FSC147.cpython-38.pyc
ADDED
Binary file (15.3 kB).
util/__pycache__/FSC147.cpython-39.pyc
ADDED
Binary file (14.4 kB).
util/__pycache__/FSC147_test.cpython-38.pyc
ADDED
Binary file (16.6 kB).
util/__pycache__/lr_sched.cpython-38.pyc
ADDED
Binary file (628 Bytes).
util/__pycache__/lr_sched.cpython-39.pyc
ADDED
Binary file (628 Bytes).
util/__pycache__/misc.cpython-38.pyc
ADDED
Binary file (19.5 kB).
util/__pycache__/misc.cpython-39.pyc
ADDED
Binary file (19.4 kB).
util/__pycache__/pos_embed.cpython-38.pyc
ADDED
Binary file (2.41 kB).
util/__pycache__/pos_embed.cpython-39.pyc
ADDED
Binary file (2.39 kB).
util/crop.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
util/datasets.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
util/lars.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
util/lr_decay.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import json


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}

    num_layers = len(model.blocks) + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    else:
        return num_layers
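A small sketch (not from this commit) of feeding these groups to AdamW; the only assumption is a timm-style ViT exposing `.blocks`. The extra `lr_scale` key in each group is ignored by the optimizer itself and is consumed later by `lr_sched.adjust_learning_rate`.

```python
import torch
import timm
from util.lr_decay import param_groups_lrd

vit = timm.create_model('vit_base_patch16_224', pretrained=False)
groups = param_groups_lrd(vit, weight_decay=0.05,
                          no_weight_decay_list=['cls_token', 'pos_embed'],
                          layer_decay=0.75)
optimizer = torch.optim.AdamW(groups, lr=1e-3)  # per-group lr_scale rides along in each dict
```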
util/lr_sched.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
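A minimal, self-contained sketch of the warmup-then-cosine schedule above; the hyperparameters mirror the FSC147 parser defaults, and the single-parameter SGD optimizer exists only to show the per-group update.

```python
from argparse import Namespace
import torch
from util.lr_sched import adjust_learning_rate

args = Namespace(lr=1e-3, min_lr=0.0, warmup_epochs=10, epochs=200)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for epoch in (0, 5, 10, 100, 199):
    print(epoch, adjust_learning_rate(optimizer, epoch, args))  # linear warmup, then half-cycle cosine
```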
util/misc.py
ADDED
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
10 |
+
# --------------------------------------------------------
|
11 |
+
|
12 |
+
import builtins
|
13 |
+
import datetime
|
14 |
+
import os
|
15 |
+
import time
|
16 |
+
import json
|
17 |
+
from collections import defaultdict, deque
|
18 |
+
from pathlib import Path
|
19 |
+
# from typing import Union
|
20 |
+
|
21 |
+
import pandas as pd
|
22 |
+
import torch
|
23 |
+
import torch.distributed as dist
|
24 |
+
import wandb
|
25 |
+
# from torch._six import inf
|
26 |
+
from torch import inf
|
27 |
+
import matplotlib.pyplot as plt
|
28 |
+
from torchvision import transforms
|
29 |
+
import cv2
|
30 |
+
from tqdm import tqdm
|
31 |
+
from typing import Union, List
|
32 |
+
|
33 |
+
class SmoothedValue(object):
|
34 |
+
"""Track a series of values and provide access to smoothed values over a
|
35 |
+
window or the global series average.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, window_size=20, fmt=None):
|
39 |
+
if fmt is None:
|
40 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
41 |
+
self.deque = deque(maxlen=window_size)
|
42 |
+
self.total = 0.0
|
43 |
+
self.count = 0
|
44 |
+
self.fmt = fmt
|
45 |
+
|
46 |
+
def update(self, value, n=1):
|
47 |
+
self.deque.append(value)
|
48 |
+
self.count += n
|
49 |
+
self.total += value * n
|
50 |
+
|
51 |
+
def synchronize_between_processes(self):
|
52 |
+
"""
|
53 |
+
Warning: does not synchronize the deque!
|
54 |
+
"""
|
55 |
+
if not is_dist_avail_and_initialized():
|
56 |
+
return
|
57 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
58 |
+
dist.barrier()
|
59 |
+
dist.all_reduce(t)
|
60 |
+
t = t.tolist()
|
61 |
+
self.count = int(t[0])
|
62 |
+
self.total = t[1]
|
63 |
+
|
64 |
+
@property
|
65 |
+
def median(self):
|
66 |
+
d = torch.tensor(list(self.deque))
|
67 |
+
return d.median().item()
|
68 |
+
|
69 |
+
@property
|
70 |
+
def avg(self):
|
71 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
72 |
+
return d.mean().item()
|
73 |
+
|
74 |
+
@property
|
75 |
+
def global_avg(self):
|
76 |
+
if self.count == 0:
|
77 |
+
# Return a default value or handle the zero count scenario
|
78 |
+
return 0 # Or any other default value or handling mechanism
|
79 |
+
else:
|
80 |
+
return self.total / self.count
|
81 |
+
# return self.total / self.count
|
82 |
+
|
83 |
+
@property
|
84 |
+
def max(self):
|
85 |
+
return max(self.deque)
|
86 |
+
|
87 |
+
@property
|
88 |
+
def value(self):
|
89 |
+
return self.deque[-1]
|
90 |
+
|
91 |
+
def __str__(self):
|
92 |
+
return self.fmt.format(
|
93 |
+
median=self.median,
|
94 |
+
avg=self.avg,
|
95 |
+
global_avg=self.global_avg,
|
96 |
+
max=self.max,
|
97 |
+
value=self.value)
|
98 |
+
|
99 |
+
|
100 |
+
class MetricLogger(object):
|
101 |
+
def __init__(self, delimiter="\t"):
|
102 |
+
self.meters = defaultdict(SmoothedValue)
|
103 |
+
self.delimiter = delimiter
|
104 |
+
|
105 |
+
def update(self, **kwargs):
|
106 |
+
for k, v in kwargs.items():
|
107 |
+
if v is None:
|
108 |
+
continue
|
109 |
+
if isinstance(v, torch.Tensor):
|
110 |
+
v = v.item()
|
111 |
+
assert isinstance(v, (float, int))
|
112 |
+
self.meters[k].update(v)
|
113 |
+
|
114 |
+
def __getattr__(self, attr):
|
115 |
+
if attr in self.meters:
|
116 |
+
return self.meters[attr]
|
117 |
+
if attr in self.__dict__:
|
118 |
+
return self.__dict__[attr]
|
119 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
120 |
+
type(self).__name__, attr))
|
121 |
+
|
122 |
+
def __str__(self):
|
123 |
+
loss_str = []
|
124 |
+
for name, meter in self.meters.items():
|
125 |
+
loss_str.append(
|
126 |
+
"{}: {}".format(name, str(meter))
|
127 |
+
)
|
128 |
+
return self.delimiter.join(loss_str)
|
129 |
+
|
130 |
+
def synchronize_between_processes(self):
|
131 |
+
for meter in self.meters.values():
|
132 |
+
meter.synchronize_between_processes()
|
133 |
+
|
134 |
+
def add_meter(self, name, meter):
|
135 |
+
self.meters[name] = meter
|
136 |
+
|
137 |
+
def log_every(self, iterable, print_freq, header=None):
|
138 |
+
i = 0
|
139 |
+
if not header:
|
140 |
+
header = ''
|
141 |
+
start_time = time.time()
|
142 |
+
end = time.time()
|
143 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
144 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
145 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
146 |
+
log_msg = [
|
147 |
+
header,
|
148 |
+
'[{0' + space_fmt + '}/{1}]',
|
149 |
+
'eta: {eta}',
|
150 |
+
'{meters}',
|
151 |
+
'time: {time}',
|
152 |
+
'data: {data}'
|
153 |
+
]
|
154 |
+
if torch.cuda.is_available():
|
155 |
+
log_msg.append('max mem: {memory:.0f}')
|
156 |
+
log_msg = self.delimiter.join(log_msg)
|
157 |
+
MB = 1024.0 * 1024.0
|
158 |
+
for obj in iterable:
|
159 |
+
data_time.update(time.time() - end)
|
160 |
+
yield obj
|
161 |
+
iter_time.update(time.time() - end)
|
162 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
163 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
164 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
165 |
+
if torch.cuda.is_available():
|
166 |
+
print(log_msg.format(
|
167 |
+
i, len(iterable), eta=eta_string,
|
168 |
+
meters=str(self),
|
169 |
+
time=str(iter_time), data=str(data_time),
|
170 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
171 |
+
else:
|
172 |
+
print(log_msg.format(
|
173 |
+
i, len(iterable), eta=eta_string,
|
174 |
+
meters=str(self),
|
175 |
+
time=str(iter_time), data=str(data_time)))
|
176 |
+
i += 1
|
177 |
+
end = time.time()
|
178 |
+
total_time = time.time() - start_time
|
179 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
180 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
181 |
+
header, total_time_str, total_time / len(iterable)))
|
182 |
+
|
183 |
+
|
184 |
+
def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)

def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

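A minimal launch sketch for the helper above; the argparse flags mirror the attributes it reads (dist_on_itp, dist_url) and the script name is a placeholder. Under a launcher such as torchrun the RANK/WORLD_SIZE/LOCAL_RANK branch is taken; without those variables it falls back to single-process mode:

import argparse
from util.misc import init_distributed_mode

parser = argparse.ArgumentParser()
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://')
args = parser.parse_args()

# e.g. `torchrun --nproc_per_node=4 train.py` sets RANK/WORLD_SIZE/LOCAL_RANK;
# without them the function prints "Not using distributed mode" and returns.
init_distributed_mode(args)
print("distributed:", args.distributed)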
class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm

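A sketch of a single AMP optimisation step using the scaler above; model, samples, and optimizer are placeholders, and the model is assumed to return a scalar loss:

import torch
from util.misc import NativeScalerWithGradNormCount

loss_scaler = NativeScalerWithGradNormCount()

def amp_step(model, samples, optimizer, max_norm=None):
    with torch.cuda.amp.autocast():
        loss = model(samples)  # assumed to be a scalar loss
    # backward + optional clipping + optimizer step; returns the gradient norm
    grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                            parameters=model.parameters(), update_grad=True)
    optimizer.zero_grad()
    return loss.item(), grad_norm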
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, suffix="", upload=True):
    if suffix:
        suffix = f"__{suffix}"
    output_dir = Path(args.output_dir)
    ckpt_name = f"checkpoint{suffix}.pth"
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ckpt_name]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            save_on_master(to_save, checkpoint_path)
            if upload and is_main_process():
                log_wandb_model(f"checkpoint{suffix}", checkpoint_path, epoch)
                print("checkpoint sent to W&B (if)")
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag=ckpt_name, client_state=client_state)
        if upload and is_main_process():
            log_wandb_model(f"checkpoint{suffix}", output_dir / ckpt_name, epoch)
            print("checkpoint sent to W&B (else)")


def log_wandb_model(title, path, epoch):
    artifact = wandb.Artifact(title, type="model")
    artifact.add_file(path)
    artifact.metadata["epoch"] = epoch
    wandb.log_artifact(artifact_or_path=artifact, name=title)

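A minimal checkpointing sketch for save_model; the tiny model and output directory are placeholders, and upload=False skips the W&B artifact step so no wandb run is required:

import argparse
from pathlib import Path
import torch
from util.misc import save_model, NativeScalerWithGradNormCount

args = argparse.Namespace(output_dir="./output_dir")
Path(args.output_dir).mkdir(parents=True, exist_ok=True)

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

# writes ./output_dir/checkpoint__demo.pth on the main process only
save_model(args, epoch=0, model=model, model_without_ddp=model,
           optimizer=optimizer, loss_scaler=loss_scaler, suffix="demo", upload=False)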
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != model_without_ddp.state_dict()['pos_embed'].shape:
            print("Removing key pos_embed from pretrained checkpoint")
            del checkpoint['model']['pos_embed']

        if 'decoder_pos_embed' in checkpoint['model'] and checkpoint['model']['decoder_pos_embed'].shape != model_without_ddp.state_dict()['decoder_pos_embed'].shape:
            print("Removing key decoder_pos_embed from pretrained checkpoint")
            del checkpoint['model']['decoder_pos_embed']

        model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")

def load_model_FSC(args, model_without_ddp):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != model_without_ddp.state_dict()['pos_embed'].shape:
            print("Removing key pos_embed from pretrained checkpoint")
            del checkpoint['model']['pos_embed']

        model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        print(f"Resume checkpoint {args.resume} ({checkpoint['epoch']})")


def load_model_FSC1(args, model_without_ddp):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        # model = timm.create_model('vit_base_patch16_224', pretrained=True)
        # torch.save(model.state_dict(), './output_abnopre_dir/checkpoint-6657.pth')
        checkpoint1 = torch.load('./output_abnopre_dir/checkpoint-6657.pth', map_location='cpu')

        if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != model_without_ddp.state_dict()['pos_embed'].shape:
            print("Removing key pos_embed from pretrained checkpoint")
            del checkpoint['model']['pos_embed']

        del checkpoint1['cls_token'], checkpoint1['pos_embed']

        model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        model_without_ddp.load_state_dict(checkpoint1, strict=False)
        print("Resume checkpoint %s" % args.resume)


def load_model_FSC_full(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != \
                model_without_ddp.state_dict()['pos_embed'].shape:
            print("Removing key pos_embed from pretrained checkpoint")
            del checkpoint['model']['pos_embed']

        model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        print("Resume checkpoint %s" % args.resume)

        if 'optimizer' in checkpoint and 'epoch' in checkpoint and args.do_resume:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & scheduler!")

def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x

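A small sketch of all_reduce_mean: each worker passes its local value and gets back the mean over all processes; a single-process run returns the value unchanged:

from util.misc import all_reduce_mean

loss_value = 0.42  # placeholder per-process metric
print(all_reduce_mean(loss_value))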
def plot_counts(res_csv: Union[str, List[str]], output_dir: str, suffix: str = "", smooth: bool = False):
    if suffix:
        suffix = f"_{suffix}"
    if smooth:
        suffix = f"_smooth{suffix}"
    if isinstance(res_csv, str):
        res_csv = [res_csv]

    plt.figure(figsize=(15, 5))

    for res in res_csv:
        name = Path(res).parent.name
        df = pd.read_csv(res)
        print(df)

        df.sort_values(by="name", inplace=True)
        df.reset_index(drop=True, inplace=True)
        df.index += 1
        print(df)

        if smooth:
            time_arr = df.index[5:-5]
            smooth_pred_mean = df['prediction'].iloc[5:-5].rolling(25).mean()
            smooth_pred_std = df['prediction'].iloc[5:-5].rolling(25).std()
            plt.plot(time_arr, smooth_pred_mean, label=name)
            plt.fill_between(time_arr, smooth_pred_mean + smooth_pred_std, smooth_pred_mean - smooth_pred_std, alpha=.2)
            plt.xlabel('Frame')
            plt.ylabel('Count')
        else:
            plt.plot(df.index, df['prediction'], label=name)

    plt.legend()
    plt.savefig(os.path.join(output_dir, f'counts{suffix}.png'), dpi=300)

def write_zeroshot_annotations(p: Path):
    with open(p / 'annotations.json', 'a') as split:
        split.write('{\n')
        for img in p.iterdir():
            if img.is_file():
                split.write(f' "{img.name}": {{\n' \
                            ' "H": 960,\n' \
                            ' "W": 1280,\n' \
                            ' "box_examples_coordinates": [],\n' \
                            ' "points": []\n' \
                            ' },\n')
        split.write("}")

    with open(p / 'split.json', 'a') as split:
        split.write('{\n "test":\n [\n')
        for img in p.iterdir():
            if img.is_file():
                split.write(f' "{img.name}",\n')
        split.write(" ]\n}")

def make_grid(imgs, h, w):
    assert len(imgs) == 9
    rows = []
    for i in range(0, 9, 3):
        row = torch.cat((imgs[i], imgs[i + 1], imgs[i + 2]), -1)
        rows += [row]
    grid = torch.cat((rows[0], rows[1], rows[2]), 0)
    grid = transforms.Resize((h, w))(grid.unsqueeze(0))
    return grid.squeeze(0)


def min_max(t):
    t_shape = t.shape
    t = t.view(t_shape[0], -1)
    t -= t.min(1, keepdim=True)[0]
    t /= t.max(1, keepdim=True)[0]
    t = t.view(*t_shape)
    return t


def min_max_np(v, new_min=0, new_max=1):
    v_min, v_max = v.min(), v.max()
    return (v - v_min) / (v_max - v_min) * (new_max - new_min) + new_min

def get_box_map(sample, pos, device, external=False):
    box_map = torch.zeros([sample.shape[1], sample.shape[2]], device=device)
    if external is False:
        for rect in pos:
            for i in range(rect[2] - rect[0]):
                box_map[min(rect[0] + i, sample.shape[1] - 1), min(rect[1], sample.shape[2] - 1)] = 10
                box_map[min(rect[0] + i, sample.shape[1] - 1), min(rect[3], sample.shape[2] - 1)] = 10
            for i in range(rect[3] - rect[1]):
                box_map[min(rect[0], sample.shape[1] - 1), min(rect[1] + i, sample.shape[2] - 1)] = 10
                box_map[min(rect[2], sample.shape[1] - 1), min(rect[1] + i, sample.shape[2] - 1)] = 10
        box_map = box_map.unsqueeze(0).repeat(3, 1, 1)
    return box_map


timerfunc = time.perf_counter

class measure_time(object):
    def __enter__(self):
        self.start = timerfunc()
        return self

    def __exit__(self, typ, value, traceback):
        self.duration = timerfunc() - self.start

    def __add__(self, other):
        return self.duration + other.duration

    def __sub__(self, other):
        return self.duration - other.duration

    def __str__(self):
        return str(self.duration)

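A sketch of the timing context manager above; the elapsed wall-clock time ends up in .duration, and two timers can be added or subtracted:

import torch
from util.misc import measure_time

with measure_time() as t:
    torch.randn(512, 512) @ torch.randn(512, 512)
print("matmul took", t, "s")  # __str__ prints t.duration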
def log_test_results(test_dir):
    test_dir = Path(test_dir)
    logs = []
    for d in test_dir.iterdir():
        if d.is_dir() and (d / "log.txt").exists():
            print(d.name)
            with open(d / "log.txt") as f:
                last = f.readlines()[-1]
            j = json.loads(last)
            j['name'] = d.name
            logs.append(j)
    df = pd.DataFrame(logs)

    df.sort_values('name', inplace=True, ignore_index=True)
    cols = list(df.columns)
    cols = cols[-1:] + cols[:-1]
    df = df[cols]

    df.to_csv(test_dir / "logs.csv", index=False)

COLORS = {
    'muted blue': '#1f77b4',
    'safety orange': '#ff7f0e',
    'cooked asparagus green': '#2ca02c',
    'brick red': '#d62728',
    'muted purple': '#9467bd',
    'chestnut brown': '#8c564b',
    'raspberry yogurt pink': '#e377c2',
    'middle gray': '#7f7f7f',
    'curry yellow-green': '#bcbd22',
    'blue-teal': '#17becf',
    'muted blue light': '#419ede',
    'safety orange light': '#ffa85b',
    'cooked asparagus green light': '#4bce4b',
    'brick red light': '#e36667'
}

def plot_test_results(test_dir):
    import plotly.graph_objects as go

    test_dir = Path(test_dir)
    df = pd.read_csv(test_dir / "logs.csv")
    df.sort_values('name', inplace=True)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df['name'], y=df['MAE'], line_color=COLORS['muted blue'],
                             mode='lines', name='MAE'))
    fig.add_trace(go.Scatter(x=df['name'], y=df['RMSE'], line_color=COLORS['safety orange'],
                             mode='lines', name='RMSE'))
    fig.add_trace(go.Scatter(x=df['name'], y=df['NAE'], line_color=COLORS['cooked asparagus green'],
                             mode='lines', name='NAE'))

    fig.update_yaxes(type="log")
    fig.write_image(test_dir / "plot.jpeg", scale=4)
    fig.write_html(test_dir / "plot.html", auto_open=False)

def frames2vid(input_dir: str, output_file: str, pattern: str, fps: int, h=720, w=1280):
    input_dir = Path(input_dir)
    files = sorted(input_dir.glob(pattern))
    video_file = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for img in tqdm(files, total=len(files)):
        frame = cv2.imread(str(img))
        frame = cv2.resize(frame, (w, h))
        video_file.write(frame)

    video_file.release()
util/pos_embed.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # omega = np.arange(embed_dim // 2, dtype=np.float)
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

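A sketch of building the fixed sin-cos table for a 14x14 patch grid (a ViT-B/16 at 224 px, used here only as an illustrative size) and turning it into a tensor:

import torch
from util.pos_embed import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 14
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
print(pos_embed.shape)  # (1 + 14 * 14, 768)

# typical use: copy into a frozen positional-embedding parameter of shape (1, N + 1, D)
pos_embed_t = torch.from_numpy(pos_embed).float().unsqueeze(0)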
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
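A sketch of resizing a checkpoint's position embedding before loading it into a model built for a different input resolution; the timm models below are stand-ins for any ViT exposing patch_embed.num_patches and pos_embed:

import timm
import torch
from util.pos_embed import interpolate_pos_embed

src = timm.create_model('vit_base_patch16_224', pretrained=False)  # 14x14 grid
dst = timm.create_model('vit_base_patch16_384', pretrained=False)  # 24x24 grid

checkpoint_model = src.state_dict()
interpolate_pos_embed(dst, checkpoint_model)  # rewrites checkpoint_model['pos_embed'] in place
msg = dst.load_state_dict(checkpoint_model, strict=False)
print(dst.pos_embed.shape)  # torch.Size([1, 577, 768])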