HopooLinZ committed on
Commit dc066a6 · verified · 1 Parent(s): 1d588ce

Upload 35 files

FSC_pretrain.py ADDED
@@ -0,0 +1,380 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+
5
+ import PIL.Image
6
+ import numpy as np
7
+ import os
8
+ import time
9
+ import random
10
+ from pathlib import Path
11
+ import math
12
+ import sys
13
+ from PIL import Image
14
+
15
+ import torch
16
+ import torch.backends.cudnn as cudnn
17
+ from torch.utils.tensorboard import SummaryWriter
18
+ import torch.nn.functional as F
19
+ from torch.utils.data import Dataset
20
+ import wandb
21
+ import timm
22
+
23
+ assert "0.4.5" <= timm.__version__ <= "0.4.9" # version check
24
+ import timm.optim.optim_factory as optim_factory
25
+
26
+ import util.misc as misc
27
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
28
+ import util.lr_sched as lr_sched
29
+ from util.FSC147 import transform_pre_train
30
+ import models_mae_noct
31
+
32
+
33
+ def get_args_parser():
34
+ parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
35
+ parser.add_argument('--batch_size', default=8, type=int,
36
+ help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
37
+ parser.add_argument('--epochs', default=200, type=int)
38
+ parser.add_argument('--accum_iter', default=1, type=int,
39
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
40
+
41
+ # Model parameters
42
+ parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
43
+ help='Name of model to train')
44
+
45
+ parser.add_argument('--mask_ratio', default=0.5, type=float,
46
+ help='Masking ratio (percentage of removed patches).')
47
+
48
+ parser.add_argument('--norm_pix_loss', action='store_true',
49
+ help='Use (per-patch) normalized pixels as targets for computing loss')
50
+ parser.set_defaults(norm_pix_loss=False)
51
+
52
+ # Optimizer parameters
53
+ parser.add_argument('--weight_decay', type=float, default=0.05,
54
+ help='weight decay (default: 0.05)')
55
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
56
+ help='learning rate (absolute lr)')
57
+ parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
58
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
59
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
60
+ help='lower lr bound for cyclic schedulers that hit 0')
61
+ parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
62
+ help='epochs to warmup LR')
63
+
64
+ # Dataset parameters
65
+ parser.add_argument('--data_path', default='./data/FSC147/', type=str,
66
+ help='dataset path')
67
+ parser.add_argument('--anno_file', default='annotation_FSC147_384.json', type=str,
68
+ help='annotation json file')
69
+ parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
70
+ help='data split json file')
71
+ parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
72
+ help='images directory')
73
+ parser.add_argument('--gt_dir', default='gt_density_map_adaptive_384_VarV2', type=str,
74
+ help='ground truth directory')
75
+ parser.add_argument('--output_dir', default='./data/out/pre_4_dir',
76
+ help='path where to save, empty for no saving')
77
+ parser.add_argument('--device', default='cuda:5',
78
+ help='device to use for training / testing')
79
+ parser.add_argument('--seed', default=0, type=int)
80
+ parser.add_argument('--resume', default='./weights/mae_pretrain_vit_base_full.pth', # mae_visualize_vit_base
81
+ help='resume from checkpoint')
82
+
83
+ # Training parameters
84
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
85
+ help='start epoch')
86
+ parser.add_argument('--num_workers', default=10, type=int)
87
+ parser.add_argument('--pin_mem', action='store_true',
88
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
89
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
90
+ parser.set_defaults(pin_mem=True)
91
+
92
+ # Distributed training parameters
93
+ parser.add_argument('--world_size', default=1, type=int,
94
+ help='number of distributed processes')
95
+ parser.add_argument('--local_rank', default=-1, type=int)
96
+ parser.add_argument('--dist_on_itp', action='store_true')
97
+ parser.add_argument('--dist_url', default='env://',
98
+ help='url used to set up distributed training')
99
+
100
+ # Logging parameters
101
+ parser.add_argument('--log_dir', default='./logs/pre_4_dir',
102
+ help='path where to tensorboard log')
103
+ parser.add_argument("--title", default="CounTR_pretraining", type=str)
104
+ parser.add_argument("--wandb", default="counting", type=str)
105
+ parser.add_argument("--team", default="wsense", type=str)
106
+ parser.add_argument("--wandb_id", default=None, type=str)
107
+ parser.add_argument('--anno_file_negative', default='annotation_FSC147_negative1.json', type=str,
108
+ help='annotation json file')
109
+ return parser
110
+
111
+
112
+ os.environ["CUDA_LAUNCH_BLOCKING"] = '5'
113
+
114
+
115
+ class TrainData(Dataset):
116
+ def __init__(self):
117
+ self.img = data_split['train']
118
+ random.shuffle(self.img)
119
+ self.img_dir = im_dir
120
+ self.TransformPreTrain = transform_pre_train(data_path)
121
+
122
+ def __len__(self):
123
+ return len(self.img)
124
+
125
+ def __getitem__(self, idx):
126
+ im_id = self.img[idx]
127
+ anno = annotations[im_id]
128
+ bboxes = anno['box_examples_coordinates']
129
+ # box_coordinates = anno.get('box_examples_coordinates', {}) # 获取图像的边界框坐标信息
130
+ # # print(box_coordinates)
131
+ # # 获取第一个类别的边界框坐标列表
132
+ # first_category = next(iter(box_coordinates), None)
133
+ # # print(first_category)
134
+ # first_category_bboxes = box_coordinates[first_category]
135
+ # if first_category_bboxes:
136
+ # # print(first_category_bboxes[0])
137
+ # bboxes = first_category_bboxes[0]
138
+ # else:
139
+ # bboxes = []
140
+ # # if first_category_bboxes:
141
+ # # bboxes = first_category_bboxes[0]
142
+ # # else:
143
+ # # pass
144
+
145
+
146
+ rects = list()
147
+ for bbox in bboxes:
148
+ x1 = bbox[0][0]
149
+ y1 = bbox[0][1]
150
+ x2 = bbox[2][0]
151
+ y2 = bbox[2][1]
152
+ rects.append([y1, x1, y2, x2])
153
+
154
+ image = Image.open('{}/{}'.format(im_dir, im_id))
155
+ image.load()
156
+ density_path = gt_dir / (im_id.split(".jpg")[0] + ".npy")
157
+ density = np.load(density_path).astype('float32')
158
+ sample = {'image': image, 'lines_boxes': rects, 'gt_density': density}
159
+ sample = self.TransformPreTrain(sample)
160
+ return sample['image']
161
+
162
+
163
+ def main(args):
164
+ misc.init_distributed_mode(args)
165
+
166
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
167
+ print("{}".format(args).replace(', ', ',\n'))
168
+
169
+ device = torch.device(args.device)
170
+
171
+ # fix the seed for reproducibility
172
+ seed = args.seed + misc.get_rank()
173
+ torch.manual_seed(seed)
174
+ np.random.seed(seed)
175
+
176
+ cudnn.benchmark = True
177
+
178
+ dataset_train = TrainData()
179
+ print(dataset_train)
180
+
181
+ if True: # args.distributed:
182
+ num_tasks = misc.get_world_size()
183
+ global_rank = misc.get_rank()
184
+ sampler_train = torch.utils.data.DistributedSampler(
185
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
186
+ )
187
+ print("Sampler_train = %s" % str(sampler_train))
188
+ else:
189
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
190
+
191
+ if global_rank == 0:
192
+ if args.log_dir is not None:
193
+ os.makedirs(args.log_dir, exist_ok=True)
194
+ log_writer = SummaryWriter(log_dir=args.log_dir)
195
+ else:
196
+ log_writer = None
197
+ if args.wandb is not None:
198
+ wandb_run = wandb.init(
199
+ config=args,
200
+ resume="allow",
201
+ project=args.wandb,
202
+ name=args.title,
203
+ # entity=args.team,
204
+ tags=["CounTR", "pretraining"],
205
+ id=args.wandb_id,
206
+ )
207
+ else:
208
+ wandb_run = None
209
+
210
+ data_loader_train = torch.utils.data.DataLoader(
211
+ dataset_train, sampler=sampler_train,
212
+ batch_size=args.batch_size,
213
+ num_workers=args.num_workers,
214
+ pin_memory=args.pin_mem,
215
+ drop_last=False,
216
+ )
217
+
218
+ # define the model
219
+ model = models_mae_noct.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
220
+
221
+ model.to(device)
222
+
223
+ model_without_ddp = model
224
+
225
+ print("Model = %s" % str(model_without_ddp))
226
+
227
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
228
+
229
+ if args.lr is None: # only base_lr is specified
230
+ args.lr = args.blr * eff_batch_size / 256
231
+
232
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
233
+ print("actual lr: %.2e" % args.lr)
234
+
235
+ print("accumulate grad iterations: %d" % args.accum_iter)
236
+ print("effective batch size: %d" % eff_batch_size)
237
+
238
+ if args.distributed:
239
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
240
+ model_without_ddp = model.module
241
+
242
+ # following timm: set wd as 0 for bias and norm layers
243
+ param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
244
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
245
+ print(optimizer)
246
+ loss_scaler = NativeScaler()
247
+
248
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
249
+
250
+ print(f"Start training for {args.epochs} epochs")
251
+ start_time = time.time()
252
+ for epoch in range(args.start_epoch, args.epochs):
253
+ if args.distributed:
254
+ data_loader_train.sampler.set_epoch(epoch)
255
+
256
+ # train one epoch
257
+ model.train(True)
258
+ metric_logger = misc.MetricLogger(delimiter=" ")
259
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
260
+ header = 'Epoch: [{}]'.format(epoch)
261
+ print_freq = 20
262
+ accum_iter = args.accum_iter
263
+
264
+ optimizer.zero_grad()
265
+
266
+ if log_writer is not None:
267
+ print('log_dir: {}'.format(log_writer.log_dir))
268
+
269
+ model_ = getattr(models_mae_noct, args.model)()
270
+
271
+ for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader_train, print_freq, header)):
272
+ epoch_1000x = int((data_iter_step / len(data_loader_train) + epoch) * 1000)
273
+
274
+ if data_iter_step % accum_iter == 0:
275
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)
276
+
277
+ samples = samples.to(device, non_blocking=True)
278
+
279
+ with torch.cuda.amp.autocast():
280
+ loss, pred, mask = model(samples, mask_ratio=args.mask_ratio)
281
+
282
+ loss_value = loss.item()
283
+
284
+ if data_iter_step % 2000 == 0:
285
+ preds = model_.unpatchify(pred)
286
+ preds = preds.float()
287
+ preds = torch.einsum('nchw->nhwc', preds)
288
+ preds = torch.clip(preds, 0, 1)
289
+
290
+ if log_writer is not None:
291
+ log_writer.add_images('reconstruction', preds, int(epoch), dataformats='NHWC')
292
+
293
+ if wandb_run is not None:
294
+ wandb_images = []
295
+ w_samples = torch.einsum('nchw->nhwc', samples.float()).clip(0, 1)
296
+ masks = F.interpolate(
297
+ mask.reshape(shape=(mask.shape[0], 1, int(mask.shape[1] ** .5), int(mask.shape[1] ** .5))),
298
+ size=(preds.shape[1], preds.shape[2]))
299
+ masks = torch.einsum('nchw->nhwc', masks.float())
300
+ combos = (w_samples + masks.repeat(1, 1, 1, 3)).clip(0, 1)
301
+ w_images = (torch.cat([w_samples, combos, preds], dim=2) * 255).detach().cpu()
302
+ print("w_images:", w_samples.shape, combos.shape, preds.shape, "-->", w_images.shape)
303
+
304
+ for i in range(w_images.shape[0]):
305
+ wi = w_images[i, :, :, :]
306
+ wandb_images += [wandb.Image(wi.numpy().astype(np.uint8),
307
+ caption=f"Prediction {i} at epoch {epoch}")]
308
+ wandb.log({f"reconstruction": wandb_images}, step=epoch_1000x, commit=False)
309
+
310
+ if not math.isfinite(loss_value):
311
+ print("Loss is {}, stopping training".format(loss_value))
312
+ sys.exit(1)
313
+
314
+ loss /= accum_iter
315
+ loss_scaler(loss, optimizer, parameters=model.parameters(),
316
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
317
+ if (data_iter_step + 1) % accum_iter == 0:
318
+ optimizer.zero_grad()
319
+
320
+ torch.cuda.synchronize()
321
+
322
+ metric_logger.update(loss=loss_value)
323
+
324
+ lr = optimizer.param_groups[0]["lr"]
325
+ metric_logger.update(lr=lr)
326
+
327
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
328
+ if (data_iter_step + 1) % accum_iter == 0:
329
+ if log_writer is not None:
330
+ """ We use epoch_1000x as the x-axis in tensorboard.
331
+ This calibrates different curves when batch size changes.
332
+ """
333
+ log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
334
+ log_writer.add_scalar('lr', lr, epoch_1000x)
335
+ if wandb_run is not None:
336
+ log = {"train/loss": loss_value_reduce, "train/lr": lr}
337
+ wandb.log(log, step=epoch_1000x, commit=True if data_iter_step == 0 else False)
338
+
339
+ metric_logger.synchronize_between_processes()
340
+ print("Averaged stats:", metric_logger)
341
+ train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
342
+
343
+ # save train status and model
344
+ if args.output_dir and (epoch % 100 == 0 or epoch + 1 == args.epochs):
345
+ misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
346
+ loss_scaler=loss_scaler, epoch=epoch, suffix=f"pretraining_{epoch}")
347
+
348
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
349
+ 'epoch': epoch, }
350
+
351
+ if args.output_dir and misc.is_main_process():
352
+ if log_writer is not None:
353
+ log_writer.flush()
354
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
355
+ f.write(json.dumps(log_stats) + "\n")
356
+
357
+ total_time = time.time() - start_time
358
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
359
+ print('Training time {}'.format(total_time_str))
360
+ wandb.run.finish()
361
+
362
+
363
+ if __name__ == '__main__':
364
+ args = get_args_parser()
365
+ args = args.parse_args()
366
+
367
+ # load data
368
+ data_path = Path(args.data_path)
369
+ anno_file = data_path / args.anno_file
370
+ data_split_file = data_path / args.data_split_file
371
+ im_dir = data_path / args.im_dir
372
+ gt_dir = data_path / args.gt_dir
373
+ with open(anno_file) as f:
374
+ annotations = json.load(f)
375
+ with open(data_split_file) as f:
376
+ data_split = json.load(f)
377
+
378
+ if args.output_dir:
379
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
380
+ main(args)
FSC_tain.py ADDED
@@ -0,0 +1,532 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import numpy as np
5
+ import os
6
+ import time
7
+ import random
8
+ from pathlib import Path
9
+ import sys
10
+ from PIL import Image
11
+ import torch.nn.functional as F
12
+ import torch
13
+ import torch.backends.cudnn as cudnn
14
+ from torch.utils.data import Dataset
15
+ import torchvision
16
+ import wandb
17
+ import timm
18
+ from tqdm import tqdm
19
+
20
+ assert "0.4.5" <= timm.__version__ <= "0.4.9" # version check
21
+ import timm.optim.optim_factory as optim_factory
22
+
23
+ import util.misc as misc
24
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
25
+ import util.lr_sched as lr_sched
26
+ from util.FSC147 import transform_train, transform_val
27
+ import models_mae_cross
28
+
29
+
30
+ def get_args_parser():
31
+ parser = argparse.ArgumentParser('MAE pre-training', add_help=True)
32
+ parser.add_argument('--batch_size', default=26, type=int,
33
+ help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
34
+ parser.add_argument('--epochs', default=200, type=int)
35
+ parser.add_argument('--accum_iter', default=1, type=int,
36
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
37
+
38
+ # Model parameters
39
+ parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
40
+ help='Name of model to train')
41
+ parser.add_argument('--mask_ratio', default=0.5, type=float,
42
+ help='Masking ratio (percentage of removed patches).')
43
+ parser.add_argument('--norm_pix_loss', action='store_true',
44
+ help='Use (per-patch) normalized pixels as targets for computing loss')
45
+ parser.set_defaults(norm_pix_loss=False)
46
+
47
+ # Optimizer parameters
48
+ parser.add_argument('--weight_decay', type=float, default=0.05,
49
+ help='weight decay (default: 0.05)')
50
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
51
+ help='learning rate (absolute lr)')
52
+ parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
53
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
54
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
55
+ help='lower lr bound for cyclic schedulers that hit 0')
56
+ parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
57
+ help='epochs to warmup LR')
58
+
59
+ # Dataset parameters
60
+ parser.add_argument('--data_path', default='./data/FSC147/', type=str,
61
+ help='dataset path')
62
+ parser.add_argument('--anno_file', default='annotation_FSC147_pos.json', type=str,
63
+ help='annotation json file for positive samples')
64
+ parser.add_argument('--anno_file_negative', default='./data/FSC147/annotation_FSC147_neg.json', type=str,
65
+ help='annotation json file for negative samples')
66
+ parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
67
+ help='data split json file')
68
+ parser.add_argument('--class_file', default='ImageClasses_FSC147.txt', type=str,
69
+ help='class json file')
70
+ parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
71
+ help='images directory')
72
+ parser.add_argument('--output_dir', default='./data/out/fim6_dir',
73
+ help='path where to save, empty for no saving')
74
+ parser.add_argument('--device', default='cuda',
75
+ help='device to use for training / testing')
76
+ parser.add_argument('--seed', default=0, type=int)
77
+ parser.add_argument('--resume', default='./data/checkpoint.pth',
78
+ help='resume from checkpoint')
79
+ parser.add_argument('--do_resume', action='store_true',
80
+ help='Resume training (e.g. if crashed).')
81
+
82
+ # Training parameters
83
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
84
+ help='start epoch')
85
+ parser.add_argument('--num_workers', default=10, type=int)
86
+ parser.add_argument('--pin_mem', action='store_true',
87
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
88
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
89
+ parser.set_defaults(pin_mem=True)
90
+ parser.add_argument('--do_aug', action='store_true',
91
+ help='Perform data augmentation.')
92
+ parser.add_argument('--no_do_aug', action='store_false', dest='do_aug')
93
+ parser.set_defaults(do_aug=True)
94
+
95
+ # Distributed training parameters
96
+ parser.add_argument('--world_size', default=1, type=int,
97
+ help='number of distributed processes')
98
+ parser.add_argument('--local_rank', default=-1, type=int)
99
+ parser.add_argument('--dist_on_itp', action='store_true')
100
+ parser.add_argument('--dist_url', default='env://',
101
+ help='url used to set up distributed training')
102
+
103
+ # Logging parameters
104
+ parser.add_argument("--title", default="count", type=str)
105
+ parser.add_argument("--wandb", default="240227", type=str)
106
+ parser.add_argument("--team", default="wsense", type=str)
107
+ parser.add_argument("--wandb_id", default=None, type=str)
108
+
109
+ return parser
110
+
111
+
112
+ os.environ["CUDA_LAUNCH_BLOCKING"] = '0'
113
+
114
+ class TrainData(Dataset):
115
+ def __init__(self, args, split='train', do_aug=True):
116
+ with open(args.anno_file) as f:
117
+ annotations = json.load(f)
118
+ # Load negative annotations
119
+ with open(args.anno_file_negative) as f:
120
+ neg_annotations = json.load(f)
121
+ with open(args.data_split_file) as f:
122
+ data_split = json.load(f)
123
+
124
+ self.img = data_split[split]
125
+ random.shuffle(self.img)
126
+ self.split = split
127
+ self.img_dir = im_dir
128
+ self.TransformTrain = transform_train(args, do_aug=do_aug)
129
+ self.TransformVal = transform_val(args)
130
+ self.annotations = annotations
131
+ self.neg_annotations = neg_annotations
132
+ self.im_dir = im_dir
133
+
134
+ def __len__(self):
135
+ return len(self.img)
136
+
137
+ def __getitem__(self, idx):
138
+ im_id = self.img[idx]
139
+ anno = self.annotations[im_id]
140
+ bboxes = anno['box_examples_coordinates']
141
+ dots = np.array(anno['points'])
142
+
143
+ # load the negative-sample boxes
144
+ neg_anno = self.neg_annotations[im_id] # assumes every image id has an entry in the negative annotations
145
+ neg_bboxes = neg_anno['box_examples_coordinates']
146
+
147
+ rects = list()
148
+ for bbox in bboxes:
149
+ x1 = bbox[0][0]
150
+ y1 = bbox[0][1]
151
+ x2 = bbox[2][0]
152
+ y2 = bbox[2][1]
153
+ if x1 < 0:
154
+ x1 = 0
155
+ if x2 < 0:
156
+ x2 = 0
157
+ if y1 < 0:
158
+ y1 = 0
159
+ if y2 < 0:
160
+ y2 = 0
161
+
162
+ rects.append([y1, x1, y2, x2])
163
+ neg_rects = list()
164
+ for neg_bbox in neg_bboxes:
165
+ x1 = neg_bbox[0][0]
166
+ y1 = neg_bbox[0][1]
167
+ x2 = neg_bbox[2][0]
168
+ y2 = neg_bbox[2][1]
169
+ if x1 < 0:
170
+ x1 = 0
171
+ if x2 < 0:
172
+ x2 = 0
173
+ if y1 < 0:
174
+ y1 = 0
175
+ if y2 < 0:
176
+ y2 = 0
177
+
178
+ neg_rects.append([y1, x1, y2, x2])
179
+
180
+ image = Image.open('{}/{}'.format(self.im_dir, im_id))
181
+ if image.mode == "RGBA":
182
+ image = image.convert("RGB")
183
+ image.load()
184
+ m_flag = 0
185
+
186
+ sample = {'image': image, 'lines_boxes': rects, 'neg_lines_boxes': neg_rects,'dots': dots, 'id': im_id, 'm_flag': m_flag}
187
+ sample = self.TransformTrain(sample) if self.split == "train" else self.TransformVal(sample)
188
+ return sample['image'], sample['gt_density'], len(dots), sample['boxes'],sample['neg_boxes'], sample['pos'],sample['m_flag'], im_id
189
+
190
+ def main(args):
191
+ wandb_run = None
192
+ try:
193
+ misc.init_distributed_mode(args)
194
+
195
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
196
+ print("{}".format(args).replace(', ', ',\n'))
197
+
198
+ device = torch.device(args.device)
199
+ # if torch.cuda.is_available():
200
+ # device = torch.device("cuda:5")
201
+
202
+ # fix the seed for reproducibility
203
+ seed = args.seed + misc.get_rank()
204
+ torch.manual_seed(seed)
205
+ np.random.seed(seed)
206
+ cudnn.benchmark = True
207
+
208
+ dataset_train = TrainData(args, do_aug=args.do_aug)
209
+ dataset_val = TrainData(args, split='val')
210
+
211
+ num_tasks = misc.get_world_size()
212
+ global_rank = misc.get_rank()
213
+ sampler_train = torch.utils.data.DistributedSampler(
214
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
215
+ )
216
+ sampler_val = torch.utils.data.DistributedSampler(
217
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
218
+ )
219
+
220
+ if global_rank == 0:
221
+ if args.wandb is not None:
222
+ wandb_run = wandb.init(
223
+ config=args,
224
+ resume="allow",
225
+ project=args.wandb,
226
+ name=args.title,
227
+ # entity=args.team,
228
+ tags=["count", "finetuning"],
229
+ id=args.wandb_id,
230
+ )
231
+
232
+ data_loader_train = torch.utils.data.DataLoader(
233
+ dataset_train, sampler=sampler_train,
234
+ batch_size=args.batch_size,
235
+ num_workers=args.num_workers,
236
+ pin_memory=args.pin_mem,
237
+ drop_last=False,
238
+ )
239
+ data_loader_val = torch.utils.data.DataLoader(
240
+ dataset_val, sampler=sampler_val,
241
+ batch_size=args.batch_size,
242
+ num_workers=args.num_workers,
243
+ pin_memory=args.pin_mem,
244
+ drop_last=False,
245
+ )
246
+
247
+ # define the model
248
+ model = models_mae_cross.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
249
+ model.to(device)
250
+ model_without_ddp = model
251
+ # print("Model = %s" % str(model_without_ddp))
252
+
253
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
254
+
255
+ if args.lr is None: # only base_lr is specified
256
+ args.lr = args.blr * eff_batch_size / 256
257
+
258
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
259
+ print("actual lr: %.2e" % args.lr)
260
+
261
+ print("accumulate grad iterations: %d" % args.accum_iter)
262
+ print("effective batch size: %d" % eff_batch_size)
263
+
264
+ if args.distributed:
265
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
266
+ model_without_ddp = model.module
267
+
268
+ # following timm: set wd as 0 for bias and norm layers
269
+ param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
270
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
271
+ print(optimizer)
272
+
273
+ loss_scaler = NativeScaler()
274
+
275
+ min_MAE = 99999
276
+ print_freq = 50
277
+ save_freq = 50
278
+
279
+ misc.load_model_FSC_full(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
280
+
281
+ print(f"Start training for {args.epochs - args.start_epoch} epochs - rank {global_rank}")
282
+ start_time = time.time()
283
+ for epoch in range(args.start_epoch, args.epochs):
284
+ if args.distributed:
285
+ data_loader_train.sampler.set_epoch(epoch)
286
+
287
+ # train one epoch
288
+ model.train(True)
289
+ accum_iter = args.accum_iter
290
+
291
+ # some parameters in training
292
+ train_mae = torch.tensor([0], dtype=torch.float64, device=device)
293
+ train_mse = torch.tensor([0], dtype=torch.float64, device=device)
294
+ val_mae = torch.tensor([0], dtype=torch.float64, device=device)
295
+ val_mse = torch.tensor([0], dtype=torch.float64, device=device)
296
+ val_nae = torch.tensor([0], dtype=torch.float64, device=device)
297
+
298
+ optimizer.zero_grad()
299
+
300
+ for data_iter_step, (samples, gt_density, _, pos_boxes, neg_boxes, pos, m_flag, im_names) in enumerate(
301
+ tqdm(data_loader_train, total=len(data_loader_train), desc=f"Train [e. {epoch} - r. {global_rank}]")):
302
+ idx = data_iter_step + (epoch * len(data_loader_train))
303
+
304
+ if data_iter_step % accum_iter == 0:
305
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)
306
+
307
+ samples = samples.to(device, non_blocking=True, dtype=torch.half)
308
+ gt_density = gt_density.to(device, non_blocking=True, dtype=torch.half)
309
+ pos_boxes = pos_boxes.to(device, non_blocking=True, dtype=torch.half)
310
+ neg_boxes = neg_boxes.to(device, non_blocking=True, dtype=torch.half)
311
+
312
+ # If at least one image in the batch uses Type 2 Mosaic, disallow 0-shot.
313
+ flag = 0
314
+ for i in range(m_flag.shape[0]):
315
+ flag += m_flag[i].item()
316
+ if flag == 0:
317
+ shot_num = random.randint(0, 3)
318
+ else:
319
+ shot_num = random.randint(1, 3)
320
+
321
+ with torch.cuda.amp.autocast():
322
+ pos_output = model(samples, pos_boxes, shot_num) # positive-sample output
323
+
324
+ # compute the positive-sample loss under a random binary pixel mask (each pixel kept with probability 0.8)
325
+ mask = np.random.binomial(n=1, p=0.8, size=[384, 384])
326
+ masks = np.tile(mask, (pos_output.shape[0], 1))
327
+ masks = masks.reshape(pos_output.shape[0], 384, 384)
328
+ masks = torch.from_numpy(masks).to(device)
329
+ pos_loss = ((pos_output - gt_density) ** 2)
330
+ pos_loss = (pos_loss * masks / (384 * 384)).sum() / pos_output.shape[0]
331
+ # negative-sample output
332
+
333
+ with torch.cuda.amp.autocast():
334
+ neg_output = model(samples, neg_boxes, 1) # negative-sample output
335
+
336
+ cnt1 = 1-torch.exp(-(torch.abs(pos_output.sum()/60 - gt_density.sum()/60).mean()))
337
+ if neg_output.shape[0] == 0:
338
+ cnt2 = 0
339
+ else:
340
+ # cnt2 = torch.log(torch.abs((neg_output.sum() / neg_output.shape[0]) - 1).mean()+1)
341
+ cnt2 = 1-torch.exp(-(torch.abs((neg_output.sum() / (neg_output.shape[0]*60)) - 1).mean()))
342
+ cnt = cnt1+cnt2
343
+
344
+ # compute the negative-sample loss with the same masking scheme
345
+ mask = np.random.binomial(n=1, p=0.8, size=[384, 384])
346
+ masks = np.tile(mask, (neg_output.shape[0], 1))
347
+ masks = masks.reshape(neg_output.shape[0], 384, 384)
348
+ masks = torch.from_numpy(masks).to(device)
349
+ neg_loss = ((neg_output - gt_density) ** 2)
350
+ if neg_output.shape[0] == 0:
351
+ neg_loss = 1
352
+ else:
353
+ neg_loss = (neg_loss * masks / (384 * 384)).sum() / neg_output.shape[0]
354
+ margin = 0.5
355
+ contrastive_loss = torch.relu(pos_loss - neg_loss + margin)
356
+ total_loss = contrastive_loss+pos_loss
357
+
358
+
359
+ # update MAE and RMSE
360
+ with torch.no_grad():
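+ # density sums are divided by 60 throughout this repo to convert density maps into object counts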
361
+ pred_cnt = (pos_output.view(len(samples), -1)).sum(1) / 60
362
+ gt_cnt = (gt_density.view(len(samples), -1)).sum(1) / 60
363
+ cnt_err = torch.abs(pred_cnt - gt_cnt).float()
364
+ batch_mae = cnt_err.double().mean()
365
+ batch_mse = (cnt_err ** 2).double().mean()
366
+
367
+ train_mae += batch_mae
368
+ train_mse += batch_mse
369
+
370
+ if not torch.isfinite(total_loss):
371
+ print("Loss is {}, stopping training".format(total_loss))
372
+ sys.exit(1)
373
+
374
+ total_loss /= accum_iter
375
+ loss_scaler(total_loss, optimizer, parameters=model.parameters(),
376
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
377
+ if (data_iter_step + 1) % accum_iter == 0:
378
+ optimizer.zero_grad()
379
+
380
+ lr = optimizer.param_groups[0]["lr"]
381
+ loss_value_reduce = misc.all_reduce_mean(total_loss)
382
+
383
+ if (data_iter_step + 1) % (print_freq * accum_iter) == 0 and (data_iter_step + 1) != len(data_loader_train) and data_iter_step != 0:
384
+ if wandb_run is not None:
385
+ log = {"train/loss": loss_value_reduce,
386
+ "train/lr": lr,
387
+ "train/MAE": batch_mae,
388
+ "train/RMSE": batch_mse ** 0.5}
389
+ wandb.log(log, step=idx)
390
+
391
+ # evaluation on Validation split
392
+ for val_samples, val_gt_density, val_n_ppl, val_boxes,_, val_pos, _, val_im_names in \
393
+ tqdm(data_loader_val, total=len(data_loader_val),
394
+ desc=f"Val [e. {epoch} - r. {global_rank}]"):
395
+
396
+ val_samples = val_samples.to(device, non_blocking=True, dtype=torch.half)
397
+ val_gt_density = val_gt_density.to(device, non_blocking=True, dtype=torch.half)
398
+ val_boxes = val_boxes.to(device, non_blocking=True, dtype=torch.half)
399
+ val_n_ppl = val_n_ppl.to(device, non_blocking=True)
400
+ shot_num = random.randint(0, 3)
401
+
402
+ with torch.no_grad():
403
+ with torch.cuda.amp.autocast():
404
+ val_output = model(val_samples, val_boxes, shot_num)
405
+
406
+ val_pred_cnt = (val_output.view(len(val_samples), -1)).sum(1) / 60
407
+ val_gt_cnt = (val_gt_density.view(len(val_samples), -1)).sum(1) / 60
408
+ # print('val_pred_cnt',val_pred_cnt)
409
+ # print('val_gt_cnt',val_gt_cnt)
410
+ val_cnt_err = torch.abs(val_pred_cnt - val_gt_cnt).float()
411
+ # print('val_cnt_err',val_cnt_err.mean())
412
+ val_cnt_err[val_cnt_err == float('inf')] = 0
413
+ val_mae += val_cnt_err.double().mean()
414
+
415
+ # val_mae += val_cnt_err
416
+ # print('val_mae',val_mae.mean())
417
+ val_cnt_err[val_cnt_err == float('inf')] = 0
418
+ val_mse += (val_cnt_err ** 2).double().mean()
419
+
420
+ # val_mse += (val_cnt_err ** 2)
421
+ _val_nae = val_cnt_err / val_gt_cnt
422
+ _val_nae[_val_nae == float('inf')] = 0
423
+ val_nae += _val_nae.double().mean()
424
+ # val_mae = val_mae/len(data_loader_val)
425
+ # val_mse = val_mse/len(data_loader_val)
426
+ # print('val_mae',val_mae)
427
+ # print('val_mse',val_mse)
428
+ # Output visualisation information to W&B
429
+ if wandb_run is not None:
430
+ train_wandb_densities = []
431
+ train_wandb_bboxes = []
432
+ val_wandb_densities = []
433
+ val_wandb_bboxes = []
434
+ black = torch.zeros([384, 384], device=device)
435
+
436
+ for i in range(pos_output.shape[0]):
437
+ # gt and predicted density
438
+ w_d_map = torch.stack([pos_output[i], black, black])
439
+ gt_map = torch.stack([gt_density[i], black, black])
440
+ box_map = misc.get_box_map(samples[i], pos[i], device)
441
+ w_gt_density = samples[i] / 2 + gt_map + box_map
442
+ w_d_map_overlay = samples[i] / 2 + w_d_map
443
+ w_densities = torch.cat([w_gt_density, w_d_map, w_d_map_overlay], dim=2)
444
+ w_densities = torch.clamp(w_densities, 0, 1)
445
+ train_wandb_densities += [wandb.Image(torchvision.transforms.ToPILImage()(w_densities),
446
+ caption=f"[E#{epoch}] {im_names[i]} ({torch.sum(gt_density[i]).item()}, {torch.sum(pos_output[i]).item()})")]
447
+
448
+ # exemplars
449
+ w_boxes = torch.cat([pos_boxes[i][x, :, :, :] for x in range(pos_boxes[i].shape[0])], 2)
450
+ train_wandb_bboxes += [wandb.Image(torchvision.transforms.ToPILImage()(w_boxes),
451
+ caption=f"[E#{epoch}] {im_names[i]}")]
452
+
453
+ for i in range(val_output.shape[0]):
454
+ # gt and predicted density
455
+ w_d_map = torch.stack([val_output[i], black, black])
456
+ gt_map = torch.stack([val_gt_density[i], black, black])
457
+ box_map = misc.get_box_map(val_samples[i], val_pos[i], device)
458
+ w_gt_density = val_samples[i] / 2 + gt_map + box_map
459
+ w_d_map_overlay = val_samples[i] / 2 + w_d_map
460
+ w_densities = torch.cat([w_gt_density, w_d_map, w_d_map_overlay], dim=2)
461
+ w_densities = torch.clamp(w_densities, 0, 1)
462
+ val_wandb_densities += [wandb.Image(torchvision.transforms.ToPILImage()(w_densities),
463
+ caption=f"[E#{epoch}] {val_im_names[i]} ({torch.sum(val_gt_density[i]).item()}, {torch.sum(val_output[i]).item()})")]
464
+
465
+ # exemplars
466
+ w_boxes = torch.cat([val_boxes[i][x, :, :, :] for x in range(val_boxes[i].shape[0])], 2)
467
+ val_wandb_bboxes += [wandb.Image(torchvision.transforms.ToPILImage()(w_boxes),
468
+ caption=f"[E#{epoch}] {val_im_names[i]}")]
469
+
470
+ log = {"train/loss": loss_value_reduce,
471
+ "train/lr": lr,
472
+ "train/MAE": batch_mae,
473
+ "train/RMSE": batch_mse ** 0.5,
474
+ "val/MAE": val_mae / len(data_loader_val),
475
+ "val/RMSE": (val_mse / len(data_loader_val)) ** 0.5,
476
+ "val/NAE": val_nae / len(data_loader_val),
477
+ "train_densitss": train_wandb_densities,
478
+ "val_densites": val_wandb_densities,
479
+ "train_boxes": train_wandb_bboxes,
480
+ "val_boxes": val_wandb_bboxes}
481
+ wandb.log(log, step=idx)
482
+
483
+ # save train status and model
484
+ if args.output_dir and (epoch % save_freq == 0 or epoch + 1 == args.epochs) and epoch != 0:
485
+ misc.save_model(
486
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
487
+ loss_scaler=loss_scaler, epoch=epoch, suffix=f"finetuning_{epoch}", upload=epoch % 100 == 0)
488
+ elif True:
489
+ misc.save_model(
490
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
491
+ loss_scaler=loss_scaler, epoch=epoch, suffix=f"finetuning_last", upload=False)
492
+ if args.output_dir and val_mae / len(data_loader_val) < min_MAE:
493
+ min_MAE = val_mae / len(data_loader_val)
494
+ misc.save_model(
495
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
496
+ loss_scaler=loss_scaler, epoch=epoch, suffix="finetuning_minMAE")
497
+
498
+ print(f'[Train Epoch #{epoch}] - MAE: {train_mae.item() / len(data_loader_train):5.2f}, RMSE: {(train_mse.item() / len(data_loader_train)) ** 0.5:5.2f}', flush=True)
499
+ print(f'[Val Epoch #{epoch}] - MAE: {val_mae.item() / len(data_loader_val):5.2f}, RMSE: {(val_mse.item() / len(data_loader_val)) ** 0.5:5.2f}, NAE: {val_nae.item() / len(data_loader_val):5.2f}', flush=True)
500
+
501
+ total_time = time.time() - start_time
502
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
503
+ print('Training time {}'.format(total_time_str))
504
+
505
+ finally:
506
+ if wandb_run is not None:
507
+ wandb.run.finish()
508
+
509
+
510
+ if __name__ == '__main__':
511
+ args = get_args_parser()
512
+ args = args.parse_args()
513
+
514
+ data_path = Path(args.data_path)
515
+ anno_file = data_path / args.anno_file
516
+ data_split_file = data_path / args.data_split_file
517
+ im_dir = data_path / args.im_dir
518
+
519
+ if args.do_aug:
520
+ class_file = data_path / args.class_file
521
+ else:
522
+ class_file = None
523
+
524
+ args.anno_file = anno_file
525
+ args.data_split_file = data_split_file
526
+ args.im_dir = im_dir
527
+ args.class_file = class_file
528
+
529
+ if args.output_dir:
530
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
531
+
532
+ main(args)
FSC_test.py ADDED
@@ -0,0 +1,352 @@
1
+ import argparse
2
+ import json
3
+ import numpy as np
4
+ import os
5
+ from pathlib import Path
6
+ from PIL import Image, ImageDraw
7
+ import matplotlib.pyplot as plt
8
+ import scipy.ndimage as ndimage
9
+ import pandas as pd
10
+ import random
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.backends.cudnn as cudnn
14
+ from torch.utils.data import Dataset
15
+ import torchvision
16
+ from torchvision import transforms
17
+ import torchvision.transforms.functional as TF
18
+ import timm
19
+ from util.FSC147 import transform_train, transform_val
20
+ from tqdm import tqdm
21
+ assert "0.4.5" <= timm.__version__ <= "0.4.9" # version check
22
+
23
+ import util.misc as misc
24
+ import models_mae_cross
25
+
26
+
27
+ def get_args_parser():
28
+ parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
29
+
30
+ # Model parameters
31
+ parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
32
+ help='Name of model to train')
33
+ parser.add_argument('--mask_ratio', default=0.5, type=float,
34
+ help='Masking ratio (percentage of removed patches).')
35
+ parser.add_argument('--norm_pix_loss', action='store_true',
36
+ help='Use (per-patch) normalized pixels as targets for computing loss')
37
+ parser.set_defaults(norm_pix_loss=False)
38
+
39
+ # Dataset parameters
40
+ parser.add_argument('--data_path', default='./data/FSC147/', type=str,
41
+ help='dataset path')
42
+ parser.add_argument('--anno_file', default='annotation_FSC147_positive.json', type=str,
43
+ help='annotation json file')
44
+ parser.add_argument('--anno_file_negative', default='./data/FSC147/annotation_FSC147_neg2.json', type=str,
45
+ help='annotation json file')
46
+ parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
47
+ help='data split json file')
48
+ parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
49
+ help='images directory')
50
+ parser.add_argument('--output_dir', default='./Image',
51
+ help='path where to save, empty for no saving')
52
+ parser.add_argument('--device', default='cuda',
53
+ help='device to use for training / testing')
54
+ parser.add_argument('--seed', default=0, type=int)
55
+ parser.add_argument('--resume', default='./output_fim6_dir/checkpoint-0.pth',
56
+ help='resume from checkpoint')
57
+ parser.add_argument('--external', action='store_true',
58
+ help='Set this param for using external exemplars')
59
+ parser.add_argument('--box_bound', default=-1, type=int,
60
+ help='The max number of exemplars to be considered')
61
+ parser.add_argument('--split', default="test", type=str)
62
+
63
+ # Training parameters
64
+ parser.add_argument('--num_workers', default=0, type=int)
65
+ parser.add_argument('--pin_mem', action='store_true',
66
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
67
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
68
+ parser.set_defaults(pin_mem=True)
69
+ parser.add_argument('--normalization', default=True, help='Set to False to disable test-time normalization')
70
+
71
+ # Distributed training parameters
72
+ parser.add_argument('--world_size', default=1, type=int,
73
+ help='number of distributed processes')
74
+ parser.add_argument('--local_rank', default=-1, type=int)
75
+ parser.add_argument('--dist_on_itp', action='store_true')
76
+ parser.add_argument('--dist_url', default='env://',
77
+ help='url used to set up distributed training')
78
+
79
+ return parser
80
+
81
+ os.environ["CUDA_LAUNCH_BLOCKING"] = '5'
82
+
83
+ class TestData(Dataset):
84
+ def __init__(self, args, split='val', do_aug=True):
85
+ with open(data_path/args.anno_file) as f:
86
+ annotations = json.load(f)
87
+ # Load negative annotations
88
+ with open(args.anno_file_negative) as f:
89
+ neg_annotations = json.load(f)
90
+ with open(data_path/args.data_split_file) as f:
91
+ data_split = json.load(f)
92
+
93
+ self.img = data_split[split]
94
+ random.shuffle(self.img)
95
+ self.split = split
96
+ self.img_dir = im_dir
97
+ # self.TransformTrain = transform_train(args, do_aug=do_aug)
98
+ self.TransformVal = transform_val(args)
99
+ self.annotations = annotations
100
+ self.neg_annotations = neg_annotations
101
+ self.im_dir = im_dir
102
+
103
+ def __len__(self):
104
+ return len(self.img)
105
+
106
+ def __getitem__(self, idx):
107
+ im_id = self.img[idx]
108
+ anno = self.annotations[im_id]
109
+ bboxes = anno['box_examples_coordinates']
110
+ dots = np.array(anno['points'])
111
+
112
+ # 加载负样本的框
113
+ neg_anno = self.neg_annotations[im_id] # 假设每个图像ID在负样本注释中都有对应的条目
114
+ neg_bboxes = neg_anno['box_examples_coordinates']
115
+
116
+ rects = list()
117
+ for bbox in bboxes:
118
+ x1 = bbox[0][0]
119
+ y1 = bbox[0][1]
120
+ x2 = bbox[2][0]
121
+ y2 = bbox[2][1]
122
+ if x1 < 0:
123
+ x1 = 0
124
+ if x2 < 0:
125
+ x2 = 0
126
+ if y1 < 0:
127
+ y1 = 0
128
+ if y2 < 0:
129
+ y2 = 0
130
+
131
+ rects.append([y1, x1, y2, x2])
132
+ neg_rects = list()
133
+ for neg_bbox in neg_bboxes:
134
+ x1 = neg_bbox[0][0]
135
+ y1 = neg_bbox[0][1]
136
+ x2 = neg_bbox[2][0]
137
+ y2 = neg_bbox[2][1]
138
+ if x1 < 0:
139
+ x1 = 0
140
+ if x2 < 0:
141
+ x2 = 0
142
+ if y1 < 0:
143
+ y1 = 0
144
+ if y2 < 0:
145
+ y2 = 0
146
+
147
+ neg_rects.append([y1, x1, y2, x2])
148
+
149
+ image = Image.open('{}/{}'.format(self.im_dir, im_id))
150
+ if image.mode == "RGBA":
151
+ image = image.convert("RGB")
152
+ image.load()
153
+ m_flag = 0
154
+
155
+ sample = {'image': image, 'lines_boxes': rects,'neg_lines_boxes': neg_rects, 'dots': dots, 'id': im_id, 'm_flag': m_flag}
156
+ sample = self.TransformTrain(sample) if self.split == "train" else self.TransformVal(sample)
157
+ # if self.split == "train":
158
+ # sample = self.TransformTrain(sample)
159
+ # # print(sample.keys())
160
+ return sample['image'], sample['gt_density'], len(dots), sample['boxes'], sample['neg_boxes'], sample['pos'],sample['m_flag'], im_id
161
+
162
+ def batched_rmse(predictions, targets, batch_size=100):
163
+ """
164
+ Compute RMSE in batches
165
+ :param predictions: predicted values, a PyTorch tensor
166
+ :param targets: ground-truth values, a PyTorch tensor with the same shape as predictions
167
+ :param batch_size: size of each batch
168
+ :return: the RMSE value
169
+ """
170
+ total_mse = 0.0
171
+ total_count = 0
172
+
173
+ # process in batches
174
+ for i in range(0, len(predictions), batch_size):
175
+ batch_predictions = predictions[i:i+batch_size]
176
+ batch_targets = targets[i:i+batch_size]
177
+
178
+ # compute in float64 for better precision
179
+ batch_predictions = batch_predictions.double()
180
+ batch_targets = batch_targets.double()
181
+
182
+ # MSE of this batch
183
+ difference = batch_predictions - batch_targets
184
+ mse = torch.mean(difference ** 2)
185
+
186
+ # accumulate MSE and sample count
187
+ total_mse += mse * len(batch_predictions)
188
+ total_count += len(batch_predictions)
189
+
190
+ # average MSE
191
+ avg_mse = total_mse / total_count
192
+
193
+ # RMSE
194
+ rmse_val = torch.sqrt(avg_mse)
195
+
196
+ return rmse_val
197
+ def batched_mae(predictions, targets, batch_size=100):
198
+ """
199
+ Compute MAE in batches
200
+ :param predictions: predicted values, a PyTorch tensor
201
+ :param targets: ground-truth values, a PyTorch tensor with the same shape as predictions
202
+ :param batch_size: size of each batch
203
+ :return: the MAE value
204
+ """
205
+ total_mae = 0.0
206
+ total_count = 0
207
+
208
+ # process in batches
209
+ for i in range(0, len(predictions), batch_size):
210
+ batch_predictions = predictions[i:i+batch_size]
211
+ batch_targets = targets[i:i+batch_size]
212
+
213
+ # absolute errors of this batch
214
+ absolute_errors = torch.abs(batch_predictions - batch_targets)
215
+
216
+ # accumulate absolute error and sample count
217
+ total_mae += torch.sum(absolute_errors)
218
+ total_count += len(batch_predictions)
219
+
220
+ # average absolute error
221
+ avg_mae = total_mae / total_count
222
+
223
+ return avg_mae
224
+
225
+ def main(args):
226
+ misc.init_distributed_mode(args)
227
+
228
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
229
+ print("{}".format(args).replace(', ', ',\n'))
230
+
231
+ device = torch.device(args.device)
232
+
233
+ # fix the seed for reproducibility
234
+ seed = args.seed + misc.get_rank()
235
+ torch.manual_seed(seed)
236
+ np.random.seed(seed)
237
+
238
+ cudnn.benchmark = True
239
+
240
+ # dataset_test = TestData(external=args.external, box_bound=args.box_bound, split=args.split)
241
+ dataset_test = TestData(args, split='test')
242
+ num_tasks = misc.get_world_size()
243
+ global_rank = misc.get_rank()
244
+ sampler_test = torch.utils.data.DistributedSampler(
245
+ dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
246
+ )
247
+
248
+ data_loader_test = torch.utils.data.DataLoader(
249
+ dataset_test, sampler=sampler_test,
250
+ batch_size=1,
251
+ num_workers=args.num_workers,
252
+ pin_memory=args.pin_mem,
253
+ drop_last=False,
254
+ )
255
+
256
+ # define the model
257
+ model = models_mae_cross.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
258
+ model.to(device)
259
+ model_without_ddp = model
260
+
261
+ if args.distributed:
262
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
263
+ model_without_ddp = model.module
264
+
265
+ misc.load_model_FSC(args=args, model_without_ddp=model_without_ddp)
266
+
267
+ print(f"Start testing.")
268
+
269
+ # test
270
+ model.eval()
271
+
272
+ # some parameters in training
273
+ train_mae = 0
274
+ train_rmse = 0
275
+ train_nae = 0
276
+ tot_load_time = 0
277
+ tot_infer_time = 0
278
+
279
+ loss_array = []
280
+ gt_array = []
281
+ pred_arr = []
282
+ name_arr = []
283
+ empties = []
284
+
285
+ total_mae = 0.0
286
+ total_mse = 0.0
287
+ total_nae = 0.0
288
+ total_count = 0
289
+ sub_batch_size = 50
290
+ for val_samples, val_gt_density, val_n_ppl, val_boxes,neg_val_boxes, val_pos, _, val_im_names in tqdm(data_loader_test, total=len(data_loader_test), desc="Validation"):
291
+ val_samples = val_samples.to(device, non_blocking=True, dtype=torch.float) # use higher precision
292
+ val_gt_density = val_gt_density.to(device, non_blocking=True, dtype=torch.float)
293
+ val_boxes = val_boxes.to(device, non_blocking=True, dtype=torch.float)
294
+ neg_val_boxes = neg_val_boxes.to(device, non_blocking=True, dtype=torch.float)
295
+ num_samples = val_samples.size(0)
296
+ total_count += num_samples
297
+
298
+ for i in range(0, num_samples, sub_batch_size):
299
+ sub_val_samples = val_samples[i:i+sub_batch_size]
300
+ sub_val_gt_density = val_gt_density[i:i+sub_batch_size]
301
+
302
+ with torch.no_grad():
303
+ with torch.cuda.amp.autocast():
304
+ sub_val_output = model(sub_val_samples, val_boxes[i:i+sub_batch_size], 3)
305
+ with torch.no_grad():
306
+ with torch.cuda.amp.autocast():
307
+ neg_sub_val_output = model(sub_val_samples, neg_val_boxes[i:i+sub_batch_size], 3)
308
+ # output = torch.clamp((sub_val_output-neg_sub_val_output),min=0)
309
+ sub_val_pred_cnt = torch.abs(sub_val_output.sum()) / 60
310
+ # sub_val_pred_cnt = torch.abs(output.sum()) / 60
311
+ # neg_sub_val_pred_cnt = torch.abs(neg_sub_val_output.sum()) / 60
312
+ sub_val_gt_cnt = sub_val_gt_density.sum() / 60
313
+
314
+ sub_val_cnt_err = torch.abs(sub_val_pred_cnt - sub_val_gt_cnt)
315
+
316
+ # accumulate per item, skipping inf/nan values
317
+ if not torch.isinf(sub_val_cnt_err) and not torch.isnan(sub_val_cnt_err):
318
+ batch_mae = sub_val_cnt_err.item()
319
+ batch_mse = sub_val_cnt_err.item() ** 2
320
+ batch_nae = sub_val_cnt_err.item() / sub_val_gt_cnt.item() if sub_val_gt_cnt.item() != 0 else 0
321
+
322
+ total_mae += batch_mae * sub_val_samples.size(0)
323
+ total_mse += batch_mse * sub_val_samples.size(0)
324
+ total_nae += batch_nae * sub_val_samples.size(0)
325
+ sub_val_pred_cnt = (sub_val_pred_cnt).int()
326
+ final_mae = total_mae / total_count
327
+ final_rmse = (total_mse / total_count) ** 0.5
328
+ final_nae = total_nae / total_count
329
+
330
+ print(f'MAE: {final_mae}, RMSE: {final_rmse}, NAE: {final_nae}')
331
+
332
+
333
+
334
+ if __name__ == '__main__':
335
+ args = get_args_parser()
336
+ args = args.parse_args()
337
+
338
+ # load data
339
+ data_path = Path(args.data_path)
340
+ anno_file = data_path / args.anno_file
341
+ data_split_file = data_path / args.data_split_file
342
+ im_dir = data_path / args.im_dir
343
+
344
+ with open(anno_file) as f:
345
+ annotations = json.load(f)
346
+
347
+ with open(data_split_file) as f:
348
+ data_split = json.load(f)
349
+
350
+ if args.output_dir:
351
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
352
+ main(args)
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Chang Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,100 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ # VA-Count
2
+ [ECCV 2024] Zero-shot Object Counting with Good Exemplars
3
+ [[paper](https://arxiv.org/abs/2407.04948)]
4
+ ![figure](figure.png)
5
+ # Zero-shot Object Counting with Good Exemplars
6
+ ## News🚀
7
+ * **2024.09.27**: Our code is released.
8
+ * **2024.09.26**: Our inference code has been updated; the exemplar-selection code and the training code are coming soon.
9
+ * **2024.07.02**: VA-Count is accepted by ECCV2024.
10
+ ## Overview
11
+ The proposed method has two main components: the Exemplar Enhancement Module (EEM), which improves exemplar quality through patch selection integrated with Grounding DINO, and the Noise Suppression Module (NSM), which distinguishes positive from negative class samples using density maps. A contrastive loss refines the separation of target-class objects from all other objects in an image.
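+ As a rough illustration of the NSM objective, the fine-tuning script combines a masked MSE on the positive density prediction with a margin-based contrastive term that keeps the positive-exemplar loss below the negative-exemplar loss. The sketch below paraphrases the loss in FSC_tain.py; the function name and the independent per-pixel mask are illustrative simplifications, not the exact implementation.
+ ```
+ import torch
+ import torch.nn.functional as F
+ 
+ def nsm_loss(pos_output, neg_output, gt_density, margin=0.5, keep_p=0.8):
+     # random binary mask over pixels (each pixel kept with probability keep_p)
+     mask = torch.bernoulli(torch.full_like(gt_density, keep_p))
+     hw = gt_density.shape[-2] * gt_density.shape[-1]
+     # masked MSE for the positive- and negative-exemplar density predictions
+     pos_loss = (((pos_output - gt_density) ** 2) * mask / hw).sum() / pos_output.shape[0]
+     neg_loss = (((neg_output - gt_density) ** 2) * mask / hw).sum() / neg_output.shape[0]
+     # contrastive term: encourage the positive loss to stay at least `margin` below the negative loss
+     contrastive = F.relu(pos_loss - neg_loss + margin)
+     return contrastive + pos_loss
+ ```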
12
+ ## Environment
13
+ ```
14
+ pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
15
+ pip install timm==0.3.2
16
+ pip install numpy
17
+ pip install matplotlib tqdm
18
+ pip install tensorboard
19
+ pip install scipy
20
+ pip install imgaug
21
+ pip install opencv-python
22
+ pip3 install hub
23
+ ```
24
+ ### For more information on Grounding DINO, please refer to the following link:
25
+ [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO)
26
+ We are very grateful for the Grounding DINO approach, which has been instrumental in our work!
27
+
28
+ ## Datasets
29
+
30
+ * [FSC147](https://github.com/cvlab-stonybrook/LearningToCountEverything)
31
+
32
+ * [CARPK](https://lafi.github.io/LPN/)
33
+
34
+ Prepare the datasets as follows:
35
+
36
+ ```
37
+ ./data/
38
+ |--FSC147
39
+ | |--images_384_VarV2
40
+ | | |--2.jpg
41
+ | | |--3.jpg
42
+ | |--gt_density_map_adaptive_384_VarV2
43
+ | | |--2.npy
44
+ | | |--3.npy
45
+ | |--annotation_FSC147_384.json
46
+ | |--Train_Test_Val_FSC_147.json
47
+ | |--ImageClasses_FSC147.txt
48
+ | |--train.txt
49
+ | |--test.txt
50
+ | |--val.txt
51
+ |--CARPK/
52
+ | |--Annotations/
53
+ | |--Images/
54
+ | |--ImageSets/
55
+ ```
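+ For reference, the training and test scripts read these files with plain json/NumPy loading; a minimal sketch following the defaults in FSC_pretrain.py:
+ ```
+ import json
+ import numpy as np
+ from pathlib import Path
+ 
+ data_path = Path('./data/FSC147/')
+ with open(data_path / 'annotation_FSC147_384.json') as f:
+     annotations = json.load(f)     # per-image points and exemplar box coordinates
+ with open(data_path / 'Train_Test_Val_FSC_147.json') as f:
+     data_split = json.load(f)      # {'train': [...], 'val': [...], 'test': [...]}
+ 
+ im_dir = data_path / 'images_384_VarV2'
+ gt_dir = data_path / 'gt_density_map_adaptive_384_VarV2'
+ im_id = data_split['train'][0]
+ density = np.load(gt_dir / (im_id.split(".jpg")[0] + ".npy")).astype('float32')
+ ```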
56
+ ## Inference
57
+ + For inference, you can download the model from [Baidu-Disk](https://pan.baidu.com/s/11sbdDYLDfTOIPx5pZvBpmw?pwd=paeh), password: paeh.
58
+ ```
59
+ python FSC_test.py --output_dir ./data/out/results_base --resume ./data/checkpoint_FSC.pth
60
+ ```
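+ Internally, FSC_test.py sums the predicted density map and divides by the scale factor of 60 used throughout this repo to obtain a count; roughly as below, assuming `model`, `samples`, `boxes`, and `gt_density` are set up as in that script:
+ ```
+ import torch
+ 
+ with torch.no_grad():
+     with torch.cuda.amp.autocast():
+         density = model(samples, boxes, 3)         # 3 exemplar boxes
+ pred_cnt = torch.abs(density.sum()) / 60           # predicted object count
+ cnt_err = torch.abs(pred_cnt - gt_density.sum() / 60)
+ ```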
61
+ ## Single and Multiple Object Classifier Training
62
+ ```
63
+ python datasetmake.py
64
+ python biclassify.py
65
+ ```
66
+ + You can also directly download the model from [Baidu-Disk](https://pan.baidu.com/s/1fOF0giI3yQpvGTiNFUI7cQ?pwd=psum), password: psum, and save it in ./data/out/classify/
67
+ ## Generate exemplars
68
+ ```
69
+ python grounding_pos.py --root_path ./data/FSC147/
70
+ python grounding_neg.py --root_path ./data/FSC147/
71
+ ```
72
+
73
+ ## Train
74
+
75
+ ```
76
+ CUDA_VISIBLE_DEVICES=0 python FSC_pretrain.py \
77
+ --epochs 500 \
78
+ --warmup_epochs 10 \
79
+ --blr 1.5e-4 --weight_decay 0.05
80
+ ```
81
+ + You can also directly download the pre-trained model from [Baidu-Disk](https://pan.baidu.com/s/1_-w_9I4bPA66pMZkHTrdrg?pwd=xynw), password: xynw, and save it in ./data/
82
+ ```
83
+ CUDA_VISIBLE_DEVICES=0 python FSC_train.py --epochs 1000 --batch_size 8 --lr 1e-5 --output_dir ./data/out/
84
+ ```
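+ When `--lr` is given explicitly (as in the fine-tuning command above) it is used as-is; otherwise both scripts derive the absolute learning rate from `--blr` with the linear scaling rule. A quick check for the pre-training command above (default batch size 8, a single GPU):
+ ```
+ # absolute_lr = base_lr * effective_batch_size / 256, as computed in FSC_pretrain.py
+ eff_batch_size = 8 * 1 * 1      # batch_size * accum_iter * num_gpus
+ blr = 1.5e-4
+ lr = blr * eff_batch_size / 256
+ print(f"actual lr: {lr:.2e}")   # 4.69e-06
+ ```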
85
+
86
+ ## Citation
87
+
88
+ ```
89
+ @inproceedings{zhu2024zero,
90
+ title={Zero-shot Object Counting with Good Exemplars},
91
+ author={Zhu, Huilin and Yuan, Jingling and Yang, Zhengwei and Guo, Yu and Wang, Zheng and Zhong, Xian and He, Shengfeng},
92
+ booktitle={Proceedings of the European Conference on Computer Vision},
93
+ year={2024}
94
+ }
95
+ ```
96
+
97
+ ## Acknowledgement
98
+ This project is based on the implementation from [CounTR](https://github.com/Verg-Avesta/CounTR); we are very grateful for that work and for [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO).
99
+
100
+ #### If you have any questions, please get in touch with me ([email protected]).
__pycache__/models_crossvit.cpython-38.pyc ADDED
Binary file (6.28 kB)

__pycache__/models_mae_cross.cpython-38.pyc ADDED
Binary file (6.69 kB)

__pycache__/models_mae_noct.cpython-38.pyc ADDED
Binary file (7.03 kB)

__pycache__/models_mae_noct.cpython-39.pyc ADDED
Binary file (6.96 kB)
 
biclassify.py ADDED
@@ -0,0 +1,163 @@
1
+ import pandas as pd
2
+ import os
3
+ import torch
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torchvision.transforms import Compose, Resize, Normalize, ToTensor
6
+ from PIL import Image
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from sklearn.model_selection import train_test_split
10
+ import clip
11
+ import re
12
+ import torchvision.models as models
13
+ # 1. Read the data and preprocess
14
+ def read_label_file(file_path):
15
+ data = []
16
+ with open(file_path, 'r') as f:
17
+ for line in f.readlines():
18
+ image_name, label = line.strip().split(',')
19
+ data.append([image_name, 1 if label == 'one' else 0])
20
+ return pd.DataFrame(data, columns=['image', 'label'])
21
+ # Read the image names listed in train.txt
22
+ with open('./data/FSC147/train.txt', 'r') as file:
23
+ a_txt_images = file.read().splitlines()
24
+
25
+ # Extract the number before .jpg
26
+ a_txt_numbers = set([name.split('.')[0] for name in a_txt_images])
27
+
28
+ # Read image names and labels from labels.txt
29
+ with open('./data/FSC147/one/labels.txt', 'r') as file:
30
+ label_txt_lines = file.read().splitlines()
31
+
32
+ # Keep only the images that appear in train.txt
33
+ filtered_images = []
34
+ for line in label_txt_lines:
35
+ image_name, label = line.strip().split(',')
36
+ # Match the leading digits with a regular expression
37
+ match = re.match(r'(\d+)', image_name)
38
+ if match:
39
+ image_number = match.group(1)
40
+ if image_number in a_txt_numbers:
41
+ # Convert the 'label' value to an integer
42
+ label_value = 1 if label == 'one' else 0
43
+ filtered_images.append([image_name, label_value]) # use a list here to match the output of read_label_file
44
+
45
+ # Convert the filtered images and labels to a DataFrame with the same column names as read_label_file's output
46
+ df_filtered = pd.DataFrame(filtered_images, columns=['image', 'label'])
47
+
48
+ # Custom Dataset class
49
+ class CustomDataset(Dataset):
50
+ def __init__(self, dataframe, root_dir, transform=None):
51
+ self.dataframe = dataframe
52
+ self.root_dir = root_dir
53
+ self.transform = transform
54
+
55
+ def __len__(self):
56
+ return len(self.dataframe)
57
+
58
+ def __getitem__(self, idx):
59
+ img_name = os.path.join(self.root_dir, self.dataframe.iloc[idx, 0])
60
+ image = Image.open(img_name).convert('RGB')
61
+ label = self.dataframe.iloc[idx, 1]
62
+ if self.transform:
63
+ image = self.transform(image)
64
+ return image, label
65
+
66
+ # 2. Dataset split
67
+ data_folder = './data/FSC147/one'
68
+ label_file = os.path.join(data_folder, 'labels.txt')
69
+ # df = read_label_file(label_file)
70
+ df = df_filtered
71
+ train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
72
+
73
+ # 3. Data loading
74
+ transform = Compose([
75
+ Resize((224, 224)),
76
+ ToTensor(),
77
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
78
+ ])
79
+
80
+ train_dataset = CustomDataset(train_df, data_folder, transform=transform)
81
+ test_dataset = CustomDataset(test_df, data_folder, transform=transform)
82
+
83
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
84
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
85
+
86
+ # 4. Model definitions
87
+ class ClipClassifier(nn.Module):
88
+ def __init__(self, clip_model, embed_dim=512):
89
+ super(ClipClassifier, self).__init__()
90
+ self.clip_model = clip_model
91
+ # Freeze the CLIP model parameters
92
+ for param in self.clip_model.parameters():
93
+ param.requires_grad = False
94
+ self.fc = nn.Linear(clip_model.visual.output_dim, embed_dim)
95
+ self.classifier = nn.Linear(embed_dim, 2) # binary classification
96
+
97
+ def forward(self, images):
98
+ with torch.no_grad():
99
+ image_features = self.clip_model.encode_image(images).float()
100
+ x = self.fc(image_features)
101
+ x = F.relu(x)
102
+ logits = self.classifier(x)
103
+ return logits
104
+ class ResNetClassifier(nn.Module):
105
+ def __init__(self, num_classes=2):
106
+ super(ResNetClassifier, self).__init__()
107
+ # Load a pretrained ResNet50 model
108
+ self.resnet50 = models.resnet50(pretrained=True)
109
+ # Freeze all pretrained layers
110
+ for param in self.resnet50.parameters():
111
+ param.requires_grad = False
112
+ # Replace the final fully connected layer for the binary classification task
113
+ num_ftrs = self.resnet50.fc.in_features
114
+ self.resnet50.fc = nn.Linear(num_ftrs, num_classes)
115
+
116
+ def forward(self, images):
117
+ return self.resnet50(images)
118
+
119
+ # 5. Training and testing
120
+ device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
121
+ clip_model, _ = clip.load("ViT-B/32", device=device)
122
+ # model = ClipClassifier(clip_model).to(device)
123
+ model = ResNetClassifier().to(device)
124
+
125
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
126
+ criterion = nn.CrossEntropyLoss()
127
+
128
+ def train(model, device, train_loader, optimizer, epoch):
129
+ model.train()
130
+ for batch_idx, (data, target) in enumerate(train_loader):
131
+ data, target = data.to(device), target.to(device)
132
+ optimizer.zero_grad()
133
+ output = model(data)
134
+ loss = criterion(output, target)
135
+ loss.backward()
136
+ optimizer.step()
137
+ if batch_idx % 10 == 0:
138
+ print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
139
+
140
+ def test(model, device, test_loader):
141
+ model.eval()
142
+ test_loss = 0
143
+ correct = 0
144
+ with torch.no_grad():
145
+ for data, target in test_loader:
146
+ data, target = data.to(device), target.to(device)
147
+ output = model(data)
148
+ test_loss += criterion(output, target).item()
149
+ pred = output.argmax(dim=1, keepdim=True)
150
+ correct += pred.eq(target.view_as(pred)).sum().item()
151
+ test_loss /= len(test_loader.dataset)
152
+ print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')
153
+ return 100. * correct / len(test_loader.dataset)
154
+
155
+ best_accuracy = 0.0
156
+ for epoch in range(1, 11):
157
+ train(model, device, train_loader, optimizer, epoch)
158
+ accuracy = test(model, device, test_loader)
159
+ if accuracy > best_accuracy:
160
+ best_accuracy = accuracy
161
+ torch.save(model.state_dict(), './data/out/classify/best_model.pth')
162
+ print(f'Best model saved with accuracy: {best_accuracy:.2f}%')
163
+
datasetmake.py ADDED
@@ -0,0 +1,53 @@
1
+ from PIL import Image
2
+ import os
3
+ import random
4
+
5
+ def is_image_file(filename):
6
+ """判断文件是否是图像文件"""
7
+ image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif'] # list of supported image file extensions
8
+ return any(filename.lower().endswith(ext) for ext in image_extensions)
9
+
10
+ def random_crop(img, size=(256, 256)):
11
+ """从给定的图片中随机裁剪出指定大小的区域"""
12
+ width, height = img.size
13
+ crop_width, crop_height = size
14
+
15
+ if width < crop_width or height < crop_height:
16
+ return None # return None if the image is smaller than the crop size
17
+
18
+ x_left = random.randint(0, width - crop_width)
19
+ y_upper = random.randint(0, height - crop_height)
20
+
21
+ return img.crop((x_left, y_upper, x_left + crop_width, y_upper + crop_height))
22
+
23
+ # Folder paths (adjust to your setup)
24
+ single_object_folder = './data/FSC147/box'
25
+ multiple_objects_folder = './data/FSC147/images_384_VarV2'
26
+ output_folder = './data/FSC147/one'
27
+
28
+ # Make sure the output folder exists
29
+ if not os.path.exists(output_folder):
30
+ os.makedirs(output_folder)
31
+
32
+ output_txt_path = os.path.join(output_folder, 'labels.txt')
33
+ with open(output_txt_path, 'w') as f:
34
+ for folder, label in [(single_object_folder, 'one'), (multiple_objects_folder, 'more')]:
35
+ for filename in os.listdir(folder):
36
+ if is_image_file(filename): # only process image files
37
+ img_path = os.path.join(folder, filename)
38
+ img = Image.open(img_path)
39
+
40
+ # Save the original image and record it in the txt file
41
+ original_img_output_path = os.path.join(output_folder, filename)
42
+ img.save(original_img_output_path)
43
+ f.write(f"{filename},{label}\n")
44
+
45
+ # Randomly crop the original image and save the crops
46
+ for size in [(256, 384), (256, 256), (384, 384),(128,256),(256,128)]:
47
+ img_cropped = random_crop(img, size=size)
48
+ if img_cropped:
49
+ cropped_img_output_path = os.path.join(output_folder, f"{filename[:-4]}_random_{size[0]}x{size[1]}.jpg")
50
+ img_cropped.save(cropped_img_output_path)
51
+ f.write(f"{filename[:-4]}_random_{size[0]}x{size[1]}.jpg,{label}\n")
52
+
53
+ print("数据集准备完成。")
figure.png ADDED
grounding_neg.py ADDED
@@ -0,0 +1,188 @@
1
+ import torch
2
+ import os
3
+ import inflect
4
+ import argparse
5
+ from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict
6
+ from PIL import Image
7
+ import numpy as np
8
+ from torchvision.ops import box_convert
9
+ import json
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import clip
13
+
14
+ # Global variables
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # Threshold settings
18
+ BOX_THRESHOLD = 0.02
19
+ TEXT_THRESHOLD = 0.02
20
+ BOX_THRESHOLD_class = 0.01
21
+ TEXT_THRESHOLD_class = 0.01
22
+
23
+ # Initialize the inflect engine
24
+ p = inflect.engine()
25
+
26
+ # Convert a word to its singular form
27
+ def to_singular(word):
28
+ singular_word = p.singular_noun(word)
29
+ return singular_word if singular_word else word
30
+
31
+ # ClipClassifier definition
32
+ class ClipClassifier(nn.Module):
33
+ def __init__(self, clip_model, embed_dim=512):
34
+ super(ClipClassifier, self).__init__()
35
+ self.clip_model = clip_model.to(device)
36
+ for param in self.clip_model.parameters():
37
+ param.requires_grad = False
38
+ self.fc = nn.Linear(clip_model.visual.output_dim, embed_dim)
39
+ self.classifier = nn.Linear(embed_dim, 2) # binary classification
40
+
41
+ def forward(self, images):
42
+ with torch.no_grad():
43
+ image_features = self.clip_model.encode_image(images).float().to(device)
44
+ x = self.fc(image_features)
45
+ x = F.relu(x)
46
+ logits = self.classifier(x)
47
+ return logits
48
+
49
+ # Initialize and load the binary classifier
50
+ clip_model, preprocess = clip.load("ViT-B/32", device)
51
+ binary_classifier = ClipClassifier(clip_model).to(device)
52
+
53
+ # Load the saved weights
54
+ model_weights_path = './data/out/classify/best_model.pth'
55
+ binary_classifier.load_state_dict(torch.load(model_weights_path, map_location=device))
56
+
57
+ # Make sure the model is in eval mode
58
+ binary_classifier.eval()
59
+
60
+ # Compute the IoU of two bounding boxes
61
+ def calculate_iou(box1, box2):
62
+ x1, y1, w1, h1 = box1
63
+ x2, y2, w2, h2 = box2
64
+
65
+ intersection_x1 = max(x1, x2)
66
+ intersection_y1 = max(y1, y2)
67
+ intersection_x2 = min(x1 + w1, x2 + w2)
68
+ intersection_y2 = min(y1 + h1, y2 + h2)
69
+
70
+ intersection_area = max(intersection_x2 - intersection_x1, 0) * max(intersection_y2 - intersection_y1, 0)
71
+ box1_area = w1 * h1
72
+ box2_area = w2 * h2
73
+ union_area = box1_area + box2_area - intersection_area
74
+ iou = intersection_area / union_area if union_area > 0 else 0
75
+
76
+ return iou
77
+
78
+ # Check whether a patch is valid
79
+ def is_valid_patch(patch, binary_classifier, preprocess, device):
80
+ if patch.size[0] <= 0 or patch.size[1] <= 0:
81
+ return False
82
+
83
+ patch_tensor = preprocess(patch).unsqueeze(0).to(device)
84
+ with torch.no_grad():
85
+ logits = binary_classifier(patch_tensor)
86
+ probabilities = torch.softmax(logits, dim=1)
87
+ prob_label_1 = probabilities[0, 1]
88
+ return prob_label_1.item() > 0.8
89
+
90
+ # Main image-processing function
91
+ def process_images(text_file_path, dataset_path, model, preprocess, binary_classifier, output_folder, device='cpu'):
92
+ boxes_dict = {}
93
+
94
+ with open(text_file_path, 'r') as f:
95
+ for line in f:
96
+ image_name, class_name = line.strip().split('\t')
97
+ print(f"Processing image: {image_name}")
98
+ text_prompt = class_name + ' .'
99
+ object_prompt = "object ."
100
+ image_path = os.path.join(dataset_path, image_name)
101
+ img = Image.open(image_path).convert("RGB")
102
+ image_source, image = load_image(image_path)
103
+ h, w, _ = image_source.shape
104
+ boxes_object, logits_object, _ = predict(model, image, object_prompt, BOX_THRESHOLD, TEXT_THRESHOLD)
105
+ boxes_class, logits_class, _ = predict(model, image, text_prompt, BOX_THRESHOLD_class, TEXT_THRESHOLD_class)
106
+
107
+ patches_object = box_convert(boxes_object, in_fmt="cxcywh", out_fmt="xyxy")
108
+ patches_class = box_convert(boxes_class, in_fmt="cxcywh", out_fmt="xyxy")
109
+
110
+ top_patches = []
111
+ iou_matrix = np.zeros((len(boxes_object), len(boxes_class)))
112
+
113
+ for j, box_class in enumerate(patches_class):
114
+ box_object_class = box_class.cpu().numpy() * np.array([w, h, w, h], dtype=np.float32)
115
+ x1_, y1_, x2_, y2_ = box_object_class.astype(int)
116
+ x1_, y1_, x2_, y2_ = max(x1_, 0), max(y1_, 0), min(x2_, w), min(y2_, h)
117
+ patch_ = img.crop((x1_, y1_, x2_, y2_))
118
+ if x2_ - x1_ > w / 2 or y2_ - y1_ > h / 2 or not is_valid_patch(patch_, binary_classifier, preprocess, device):
119
+ print(f"Skipping patch at box {box_class}")
120
+ continue
121
+ for i, box_object in enumerate(patches_object):
122
+ iou_matrix[i][j] = calculate_iou(box_object.cpu().numpy(), box_class.cpu().numpy())
123
+
124
+ for i, box_object in enumerate(patches_object):
125
+ max_iou = np.max(iou_matrix[i])
126
+ if max_iou < 0.5:
127
+ box_object = box_object.cpu().numpy() * np.array([w, h, w, h], dtype=np.float32)
128
+ x1, y1, x2, y2 = box_object.astype(int)
129
+ x1, y1, x2, y2 = max(x1, 0), max(y1, 0), min(x2, w), min(y2, h)
130
+ patch = img.crop((x1, y1, x2, y2))
131
+ if patch.size == (0, 0) or not is_valid_patch(patch, binary_classifier, preprocess, device) or x2 - x1 > w / 2 or y2 - y1 > h / 2 or y2 - y1 < 5 or x2 - x1 < 5:
132
+ print(f"Skipping patch at box {box_object}")
133
+ continue
134
+ patch_logits = logits_object[i]
135
+ top_patches.append((i, patch_logits.item()))
136
+
137
+ top_patches.sort(key=lambda x: x[1], reverse=True)
138
+ top_3_indices = [patch[0] for patch in top_patches[:3]]
139
+
140
+ while len(top_3_indices) < 3:
141
+ if len(top_3_indices) > 0:
142
+ top_3_indices.append(top_3_indices[-1])
143
+ else:
144
+ default_box = torch.tensor([0,0,20/w,20/h]).unsqueeze(0)
145
+ patches_object = torch.cat((patches_object, default_box.to(boxes_object.device)), dim=0)
146
+ top_3_indices.append(len(patches_object) - 1)
147
+
148
+ boxes_dict[image_name] = [patches_object[idx].cpu().numpy().tolist() * np.array([w, h, w, h], dtype=np.float32) for idx in top_3_indices]
149
+
150
+ return boxes_dict
151
+
152
+ def main(args):
153
+ # Fixed default paths
154
+ model_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
155
+ model_weights = "GroundingDINO/weights/groundingdino_swint_ogc.pth"
156
+
157
+ # Paths derived from root_path
158
+ text_file_path = os.path.join(args.root_path, "ImageClasses_FSC147.txt")
159
+ dataset_path = os.path.join(args.root_path, "images_384_VarV2")
160
+ input_json_path = os.path.join(args.root_path, "annotation_FSC147_384.json")
161
+ output_json_path = os.path.join(args.root_path, "annotation_FSC147_neg.json")
162
+ output_folder = os.path.join(args.root_path, "annotated_images_n")
163
+
164
+ os.makedirs(output_folder, exist_ok=True)
165
+
166
+ # Load the GroundingDINO model
167
+ model = load_model(model_config, model_weights, device=device)
168
+
169
+ # Process images and generate bounding boxes
170
+ boxes_dict = process_images(text_file_path, dataset_path, model, preprocess, binary_classifier, output_folder, device=device)
171
+
172
+ # Update the JSON file
173
+ with open(input_json_path, 'r') as f:
174
+ data = json.load(f)
175
+
176
+ for image_name, boxes in boxes_dict.items():
177
+ if image_name in data:
178
+ new_boxes = [[[x1, y1], [x1, y2], [x2, y2], [x2, y1]] for x1, y1, x2, y2 in boxes]
179
+ data[image_name]["box_examples_coordinates"] = new_boxes
180
+
181
+ with open(output_json_path, 'w') as f:
182
+ json.dump(data, f, indent=4)
183
+
184
+ if __name__ == "__main__":
185
+ parser = argparse.ArgumentParser(description="Image Processing Script")
186
+ parser.add_argument("--root_path", type=str, required=True, help="Root path to the dataset and output files")
187
+ args = parser.parse_args()
188
+ main(args)
grounding_pos.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import os
3
+ import clip
4
+ import inflect
5
+ import argparse
6
+ from torchvision.ops import box_convert
7
+ from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict
8
+ from PIL import Image
9
+ import numpy as np
10
+ import json
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ # Global variables
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ BOX_THRESHOLD = 0.05
17
+ TEXT_THRESHOLD = 0.05
18
+
19
+ # Initialize the inflect engine
20
+ p = inflect.engine()
21
+
22
+ # ClipClassifier definition
23
+ class ClipClassifier(nn.Module):
24
+ def __init__(self, clip_model, embed_dim=512):
25
+ super(ClipClassifier, self).__init__()
26
+ self.clip_model = clip_model.to(device)
27
+ for param in self.clip_model.parameters():
28
+ param.requires_grad = False
29
+ self.fc = nn.Linear(clip_model.visual.output_dim, embed_dim)
30
+ self.classifier = nn.Linear(embed_dim, 2) # binary classification
31
+
32
+ def forward(self, images):
33
+ with torch.no_grad():
34
+ image_features = self.clip_model.encode_image(images).float().to(device)
35
+ x = self.fc(image_features)
36
+ x = F.relu(x)
37
+ logits = self.classifier(x)
38
+ return logits
39
+
40
+ # Load the CLIP model
41
+ clip_model, preprocess = clip.load("ViT-B/32", device)
42
+ clip_model.eval()
43
+
44
+ # Initialize and load the binary classifier
45
+ binary_classifier = ClipClassifier(clip_model).to(device)
46
+ model_weights_path = './data/out/classify/best_model.pth'
47
+ binary_classifier.load_state_dict(torch.load(model_weights_path, map_location=device))
48
+ binary_classifier.eval()
49
+
50
+ # Check whether a patch is valid
51
+ def is_valid_patch(patch, binary_classifier, preprocess, device):
52
+ if patch.size[0] <= 0 or patch.size[1] <= 0:
53
+ return False
54
+ patch_tensor = preprocess(patch).unsqueeze(0).to(device)
55
+ with torch.no_grad():
56
+ logits = binary_classifier(patch_tensor)
57
+ probabilities = torch.softmax(logits, dim=1)
58
+ prob_label_1 = probabilities[0, 1]
59
+ return prob_label_1.item() > 0.8
60
+
61
+ # Main image-processing function
62
+ def process_images(text_file_path, dataset_path, model, preprocess, clip_model, output_folder, device='cpu'):
63
+ boxes_dict = {}
64
+ with open(text_file_path, 'r') as f:
65
+ for line in f:
66
+ image_name, class_name = line.strip().split('\t')
67
+ print(f"Processing image: {image_name}")
68
+ text_prompt = class_name + ' .'
69
+ image_path = os.path.join(dataset_path, image_name)
70
+ img = Image.open(image_path).convert("RGB")
71
+ image_source, image = load_image(image_path)
72
+ h, w, _ = image_source.shape
73
+ boxes, logits, _ = predict(model, image, text_prompt, BOX_THRESHOLD, TEXT_THRESHOLD)
74
+ patches = box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
75
+
76
+ top_patches = []
77
+ for i, (box, logit) in enumerate(zip(patches, logits)):
78
+ box = box.cpu().numpy() * np.array([w, h, w, h], dtype=np.float32)
79
+ x1, y1, x2, y2 = box.astype(int)
80
+ x1, y1, x2, y2 = max(x1, 0), max(y1, 0), min(x2, w), min(y2, h)
81
+ patch = img.crop((x1, y1, x2, y2))
82
+
83
+ if patch.size == (0, 0) or not is_valid_patch(patch, binary_classifier, preprocess, device) or x2 - x1 > w / 2 or y2 - y1 > h / 2 or y2 - y1 < 5 or x2 - x1 < 5:
84
+ print(f"Skipping patch due to binary classifier at box {box}")
85
+ continue
86
+ top_patches.append((i, logit))
87
+
88
+ top_patches.sort(key=lambda x: x[1], reverse=True)
89
+ top_3_indices = [patch[0] for patch in top_patches[:3]]
90
+
91
+ # Make sure every image has three bounding boxes
92
+ while len(top_3_indices) < 3:
93
+ if len(top_3_indices) > 0:
94
+ top_3_indices.append(top_3_indices[-1])
95
+ else:
96
+ default_box = torch.tensor([0, 0, 20 / w, 20 / h]).unsqueeze(0)
97
+ patches = torch.cat((patches, default_box.to(boxes.device)), dim=0)
98
+ top_3_indices.append(len(patches) - 1)
99
+
100
+ boxes_dict[image_name] = [patches[idx].cpu().numpy().tolist() * np.array([w, h, w, h], dtype=np.float32) for idx in top_3_indices]
101
+
102
+ return boxes_dict
103
+
104
+ # Main entry point
105
+ def main(args):
106
+ # Fixed default paths
107
+ model_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
108
+ model_weights = "GroundingDINO/weights/groundingdino_swint_ogc.pth"
109
+ output_folder = os.path.join(args.root_path, "annotated_images")
110
+
111
+ # Paths derived from root_path
112
+ text_file_path = os.path.join(args.root_path, "ImageClasses_FSC147.txt")
113
+ dataset_path = os.path.join(args.root_path, "images_384_VarV2")
114
+ input_json_path = os.path.join(args.root_path, "annotation_FSC147_384_old.json")
115
+ output_json_path = os.path.join(args.root_path, "annotation_FSC147_pos.json")
116
+
117
+ os.makedirs(output_folder, exist_ok=True)
118
+
119
+ # Load the GroundingDINO model
120
+ model = load_model(model_config, model_weights, device=device)
121
+
122
+ # Process images and generate bounding boxes
123
+ boxes_dict = process_images(text_file_path, dataset_path, model, preprocess, clip_model, output_folder, device=device)
124
+
125
+ # Update the JSON file
126
+ with open(input_json_path, 'r') as f:
127
+ data = json.load(f)
128
+
129
+ for image_name, boxes in boxes_dict.items():
130
+ if image_name in data:
131
+ new_boxes = [[[x1, y1], [x1, y2], [x2, y2], [x2, y1]] for x1, y1, x2, y2 in boxes]
132
+ data[image_name]["box_examples_coordinates"] = new_boxes
133
+
134
+ with open(output_json_path, 'w') as f:
135
+ json.dump(data, f, indent=4)
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(description="Image Processing Script")
139
+ parser.add_argument("--root_path", type=str, required=True, help="Root path to the dataset and output files")
140
+ args = parser.parse_args()
141
+ main(args)
models_crossvit.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.hub
5
+ from itertools import repeat
6
+ import collections.abc
7
+
8
+
9
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
10
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
11
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
12
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
13
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
14
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
15
+ 'survival rate' as the argument.
16
+ """
17
+ if drop_prob == 0. or not training:
18
+ return x
19
+ keep_prob = 1 - drop_prob
20
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
21
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
22
+ if keep_prob > 0.0 and scale_by_keep:
23
+ random_tensor.div_(keep_prob)
24
+ return x * random_tensor
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
28
+ """
29
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+ self.scale_by_keep = scale_by_keep
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
36
+
37
+ def _ntuple(n):
38
+ def parse(x):
39
+ if isinstance(x, collections.abc.Iterable):
40
+ return x
41
+ return tuple(repeat(x, n))
42
+ return parse
43
+
44
+ to_2tuple = _ntuple(2)
45
+
46
+ class Mlp(nn.Module):
47
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
48
+ """
49
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50
+ super().__init__()
51
+ out_features = out_features or in_features
52
+ hidden_features = hidden_features or in_features
53
+ drop_probs = to_2tuple(drop)
54
+
55
+ self.fc1 = nn.Linear(in_features, hidden_features)
56
+ self.act = act_layer()
57
+ self.drop1 = nn.Dropout(drop_probs[0])
58
+ self.fc2 = nn.Linear(hidden_features, out_features)
59
+ self.drop2 = nn.Dropout(drop_probs[1])
60
+
61
+ def forward(self, x):
62
+ x = self.fc1(x)
63
+ x = self.act(x)
64
+ x = self.drop1(x)
65
+ x = self.fc2(x)
66
+ x = self.drop2(x)
67
+ return x
68
+
69
+ class Attention(nn.Module):
70
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
71
+ super().__init__()
72
+ self.num_heads = num_heads
73
+ head_dim = dim // num_heads
74
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
75
+ self.scale = qk_scale or head_dim ** -0.5
76
+
77
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
78
+ self.attn_drop = nn.Dropout(attn_drop)
79
+ self.proj = nn.Linear(dim, dim)
80
+ self.proj_drop = nn.Dropout(proj_drop)
81
+
82
+ def forward(self, x):
83
+ B, N, C = x.shape
84
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
85
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
86
+
87
+ attn = (q @ k.transpose(-2, -1)) * self.scale
88
+ attn = attn.softmax(dim=-1)
89
+ attn = self.attn_drop(attn)
90
+
91
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
92
+ x = self.proj(x)
93
+ x = self.proj_drop(x)
94
+ return x
95
+
96
+ class CrossAttention(nn.Module):
97
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
98
+ super().__init__()
99
+ self.num_heads = num_heads
100
+ head_dim = dim // num_heads
101
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
102
+ self.scale = qk_scale or head_dim ** -0.5
103
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
104
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
105
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x, y):
111
+ B, Nx, C = x.shape
112
+ Ny = y.shape[1]
113
+ # BNxC -> BNxH(C/H) -> BHNx(C/H)
114
+ q = self.wq(x).reshape(B, Nx, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115
+ # BNyC -> BNyH(C/H) -> BHNy(C/H)
116
+ k = self.wk(y).reshape(B, Ny, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
117
+ # BNyC -> BNyH(C/H) -> BHNy(C/H)
118
+ v = self.wv(y).reshape(B, Ny, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
119
+
120
+ attn = (q @ k.transpose(-2, -1)) * self.scale # BHNx(C/H) @ BH(C/H)Ny -> BHNxNy
121
+ attn = attn.softmax(dim=-1)
122
+ attn = self.attn_drop(attn)
123
+
124
+ x = (attn @ v).transpose(1, 2).reshape(B, Nx, C) # (BHNxNy @ BHNy(C/H)) -> BHNx(C/H) -> BNxH(C/H) -> BNxC
125
+ x = self.proj(x)
126
+ x = self.proj_drop(x)
127
+ return x
128
+
129
+ class CrossAttentionBlock(nn.Module):
130
+
131
+ def __init__(
132
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
133
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
134
+ super().__init__()
135
+
136
+ self.norm0 = norm_layer(dim)
137
+ self.selfattn = Attention(
138
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
139
+ self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
140
+
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = CrossAttention(
143
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
144
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
145
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
146
+
147
+ self.norm2 = norm_layer(dim)
148
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
149
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
150
+
151
+ def forward(self, x, y):
152
+ x = x + self.drop_path0(self.selfattn(self.norm0(x)))
153
+ x = x + self.drop_path1(self.attn(self.norm1(x), y))
154
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
155
+ return x
models_mae_cross.py ADDED
@@ -0,0 +1,253 @@
1
+ import time
2
+ from functools import partial
3
+ import math
4
+ import random
5
+
6
+ import numpy as np
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.utils
12
+
13
+ from timm.models.vision_transformer import PatchEmbed, Block
14
+ from models_crossvit import CrossAttentionBlock
15
+
16
+ from util.pos_embed import get_2d_sincos_pos_embed
17
+
18
+ class SupervisedMAE(nn.Module):
19
+ def __init__(self, img_size=384, patch_size=16, in_chans=3,
20
+ embed_dim=1024, depth=24, num_heads=16,
21
+ decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
22
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
23
+ super().__init__()
24
+
25
+ # --------------------------------------------------------------------------
26
+ # MAE encoder specifics
27
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
28
+ num_patches = self.patch_embed.num_patches
29
+
30
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) # fixed sin-cos embedding
31
+
32
+ self.blocks = nn.ModuleList([
33
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
34
+ for i in range(depth)])
35
+ self.norm = norm_layer(embed_dim)
36
+ # --------------------------------------------------------------------------
37
+
38
+ # --------------------------------------------------------------------------
39
+ # MAE decoder specifics
40
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
41
+
42
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
43
+
44
+ self.shot_token = nn.Parameter(torch.zeros(512))
45
+
46
+ # Exemplar encoder with CNN
47
+ self.decoder_proj1 = nn.Sequential(
48
+ nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
49
+ nn.InstanceNorm2d(64),
50
+ nn.ReLU(inplace=True),
51
+ nn.MaxPool2d(2) #[3,64,64]->[64,32,32]
52
+ )
53
+ self.decoder_proj2 = nn.Sequential(
54
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
55
+ nn.InstanceNorm2d(128),
56
+ nn.ReLU(inplace=True),
57
+ nn.MaxPool2d(2) #[64,32,32]->[128,16,16]
58
+ )
59
+ self.decoder_proj3 = nn.Sequential(
60
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
61
+ nn.InstanceNorm2d(256),
62
+ nn.ReLU(inplace=True),
63
+ nn.MaxPool2d(2) # [128,16,16]->[256,8,8]
64
+ )
65
+ self.decoder_proj4 = nn.Sequential(
66
+ nn.Conv2d(256, decoder_embed_dim, kernel_size=3, stride=1, padding=1),
67
+ nn.InstanceNorm2d(512),
68
+ nn.ReLU(inplace=True),
69
+ nn.AdaptiveAvgPool2d((1,1))
70
+ # [256,8,8]->[512,1,1]
71
+ )
72
+
73
+
74
+ self.decoder_blocks = nn.ModuleList([
75
+ CrossAttentionBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
76
+ for i in range(decoder_depth)])
77
+
78
+ self.decoder_norm = norm_layer(decoder_embed_dim)
79
+ # Density map regression module
80
+ self.decode_head0 = nn.Sequential(
81
+ nn.Conv2d(decoder_embed_dim, 256, kernel_size=3, stride=1, padding=1),
82
+ nn.GroupNorm(8, 256),
83
+ nn.ReLU(inplace=True)
84
+ )
85
+ self.decode_head1 = nn.Sequential(
86
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
87
+ nn.GroupNorm(8, 256),
88
+ nn.ReLU(inplace=True)
89
+ )
90
+ self.decode_head2 = nn.Sequential(
91
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
92
+ nn.GroupNorm(8, 256),
93
+ nn.ReLU(inplace=True)
94
+ )
95
+ self.decode_head3 = nn.Sequential(
96
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
97
+ nn.GroupNorm(8, 256),
98
+ nn.ReLU(inplace=True),
99
+ nn.Conv2d(256, 1, kernel_size=1, stride=1)
100
+ )
101
+
102
+ # --------------------------------------------------------------------------
103
+
104
+ self.norm_pix_loss = norm_pix_loss
105
+
106
+ self.initialize_weights()
107
+
108
+ def initialize_weights(self):
109
+ # initialization
110
+ # initialize (and freeze) pos_embed by sin-cos embedding
111
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
112
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
113
+
114
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
115
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
116
+
117
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
118
+ w = self.patch_embed.proj.weight.data
119
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
120
+
121
+ torch.nn.init.normal_(self.shot_token, std=.02)
122
+
123
+ # initialize nn.Linear and nn.LayerNorm
124
+ self.apply(self._init_weights)
125
+
126
+ def _init_weights(self, m):
127
+ if isinstance(m, nn.Linear):
128
+ # we use xavier_uniform following official JAX ViT:
129
+ torch.nn.init.xavier_uniform_(m.weight)
130
+ if isinstance(m, nn.Linear) and m.bias is not None:
131
+ nn.init.constant_(m.bias, 0)
132
+ elif isinstance(m, nn.LayerNorm):
133
+ nn.init.constant_(m.bias, 0)
134
+ nn.init.constant_(m.weight, 1.0)
135
+
136
+ def forward_encoder(self, x):
137
+ # embed patches
138
+ x = self.patch_embed(x)
139
+
140
+ # add pos embed w/o cls token
141
+ x = x + self.pos_embed
142
+
143
+ # apply Transformer blocks
144
+ for blk in self.blocks:
145
+ x = blk(x)
146
+ x = self.norm(x)
147
+
148
+ return x
149
+
150
+ def forward_decoder(self, x, y_, shot_num=3):
151
+ # embed tokens
152
+ x = self.decoder_embed(x)
153
+ # add pos embed
154
+ x = x + self.decoder_pos_embed
155
+
156
+ # Exemplar encoder
157
+ y_ = y_.transpose(0,1) # y_ [N,3,3,64,64]->[3,N,3,64,64]
158
+ y1=[]
159
+ C=0
160
+ N=0
161
+ cnt = 0
162
+ for yi in y_:
163
+ cnt+=1
164
+ if cnt > shot_num:
165
+ break
166
+ yi = self.decoder_proj1(yi)
167
+ yi = self.decoder_proj2(yi)
168
+ yi = self.decoder_proj3(yi)
169
+ yi = self.decoder_proj4(yi)
170
+ N, C,_,_ = yi.shape
171
+ y1.append(yi.squeeze(-1).squeeze(-1)) # yi [N,C,1,1]->[N,C]
172
+
173
+ if shot_num > 0:
174
+ y = torch.cat(y1,dim=0).reshape(shot_num,N,C).to(x.device)
175
+ else:
176
+ y = self.shot_token.repeat(y_.shape[1],1).unsqueeze(0).to(x.device)
177
+ y = y.transpose(0,1) # y [3,N,C]->[N,3,C]
178
+
179
+ # apply Transformer blocks
180
+ for blk in self.decoder_blocks:
181
+ x = blk(x, y)
182
+ x = self.decoder_norm(x)
183
+
184
+ # Density map regression
185
+ n, hw, c = x.shape
186
+ h = w = int(math.sqrt(hw))
187
+ x = x.transpose(1, 2).reshape(n, c, h, w)
188
+
189
+ x = F.interpolate(
190
+ self.decode_head0(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
191
+ x = F.interpolate(
192
+ self.decode_head1(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
193
+ x = F.interpolate(
194
+ self.decode_head2(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
195
+ x = F.interpolate(
196
+ self.decode_head3(x), size=x.shape[-1]*2, mode='bilinear', align_corners=False)
197
+ x = x.squeeze(-3)
198
+
199
+ return x
200
+
201
+ def forward(self, imgs, boxes, shot_num):
202
+ # if boxes.nelement() > 0:
203
+ # torchvision.utils.save_image(boxes[0], f"data/out/crops/box_{time.time()}_{random.randint(0, 99999):>5}.png")
204
+ with torch.no_grad():
205
+ latent = self.forward_encoder(imgs)
206
+ pred = self.forward_decoder(latent, boxes, shot_num) # [N, 384, 384]
207
+ return pred
208
+
209
+
210
+ def mae_vit_base_patch16_dec512d8b(**kwargs):
211
+ model = SupervisedMAE(
212
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
213
+ decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
214
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
215
+ return model
216
+
217
+
218
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
219
+ model = SupervisedMAE(
220
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16,
221
+ decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
222
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
223
+ return model
224
+
225
+
226
+ def mae_vit_huge_patch14_dec512d8b(**kwargs):
227
+ model = SupervisedMAE(
228
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16,
229
+ decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
230
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
231
+ return model
232
+
233
+ def mae_vit_base_patch16_fim4(**kwargs):
234
+ model = SupervisedMAE(
235
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
236
+ decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16,
237
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
238
+ return model
239
+
240
+ def mae_vit_base_patch16_fim6(**kwargs):
241
+ model = SupervisedMAE(
242
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
243
+ decoder_embed_dim=512, decoder_depth=6, decoder_num_heads=16,
244
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
245
+ return model
246
+
247
+
248
+ # set recommended archs
249
+ mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
250
+ mae_vit_base4_patch16 = mae_vit_base_patch16_fim4 # decoder: 4 blocks
251
+ mae_vit_base6_patch16 = mae_vit_base_patch16_fim6 # decoder: 6 blocks
252
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
253
+ mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b
models_mae_noct.py ADDED
@@ -0,0 +1,234 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from timm.models.vision_transformer import PatchEmbed, Block
7
+
8
+ from util.pos_embed import get_2d_sincos_pos_embed
9
+
10
+
11
+ class MaskedAutoencoderViTNoCT(nn.Module):
12
+ """ Masked Autoencoder with VisionTransformer backbone
13
+ """
14
+ def __init__(self, img_size=384, patch_size=16, in_chans=3,
15
+ embed_dim=1024, depth=24, num_heads=16,
16
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
17
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
18
+ super().__init__()
19
+
20
+ # --------------------------------------------------------------------------
21
+ # MAE encoder specifics
22
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
23
+ num_patches = self.patch_embed.num_patches
24
+
25
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) # fixed sin-cos embedding
26
+
27
+ self.blocks = nn.ModuleList([
28
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
29
+ for i in range(depth)])
30
+ self.norm = norm_layer(embed_dim)
31
+ # --------------------------------------------------------------------------
32
+
33
+ # --------------------------------------------------------------------------
34
+ # MAE decoder specifics
35
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
36
+
37
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
38
+
39
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
40
+
41
+ self.decoder_blocks = nn.ModuleList([
42
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
43
+ for i in range(decoder_depth)])
44
+
45
+ self.decoder_norm = norm_layer(decoder_embed_dim)
46
+ self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
47
+ # --------------------------------------------------------------------------
48
+
49
+ self.norm_pix_loss = norm_pix_loss
50
+
51
+ self.initialize_weights()
52
+
53
+ def initialize_weights(self):
54
+ # initialization
55
+ # initialize (and freeze) pos_embed by sin-cos embedding
56
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
57
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
58
+
59
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
60
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
61
+
62
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
63
+ w = self.patch_embed.proj.weight.data
64
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
65
+
66
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
67
+ torch.nn.init.normal_(self.mask_token, std=.02)
68
+
69
+ # initialize nn.Linear and nn.LayerNorm
70
+ self.apply(self._init_weights)
71
+
72
+ def _init_weights(self, m):
73
+ if isinstance(m, nn.Linear):
74
+ # we use xavier_uniform following official JAX ViT:
75
+ torch.nn.init.xavier_uniform_(m.weight)
76
+ if isinstance(m, nn.Linear) and m.bias is not None:
77
+ nn.init.constant_(m.bias, 0)
78
+ elif isinstance(m, nn.LayerNorm):
79
+ nn.init.constant_(m.bias, 0)
80
+ nn.init.constant_(m.weight, 1.0)
81
+
82
+ def patchify(self, imgs):
83
+ """
84
+ imgs: (N, 3, H, W)
85
+ x: (N, L, patch_size**2 *3)
86
+ """
87
+ p = self.patch_embed.patch_size[0]
88
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
89
+
90
+ h = w = imgs.shape[2] // p
91
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
92
+ x = torch.einsum('nchpwq->nhwpqc', x)
93
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
94
+ return x
95
+
96
+ def unpatchify(self, x):
97
+ """
98
+ x: (N, L, patch_size**2 *3)
99
+ imgs: (N, 3, H, W)
100
+ """
101
+ p = self.patch_embed.patch_size[0]
102
+ h = w = int(x.shape[1]**.5)
103
+ assert h * w == x.shape[1]
104
+
105
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
106
+ x = torch.einsum('nhwpqc->nchpwq', x)
107
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
108
+ return imgs
109
+
110
+ def random_masking(self, x, mask_ratio):
111
+ """
112
+ Perform per-sample random masking by per-sample shuffling.
113
+ Per-sample shuffling is done by argsort random noise.
114
+ x: [N, L, D], sequence
115
+ """
116
+ N, L, D = x.shape # batch, length, dim
117
+ len_keep = int(L * (1 - mask_ratio))
118
+
119
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
120
+
121
+ # sort noise for each sample
122
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
123
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
124
+
125
+ # keep the first subset
126
+ ids_keep = ids_shuffle[:, :len_keep]
127
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
128
+
129
+ # generate the binary mask: 0 is keep, 1 is remove
130
+ mask = torch.ones([N, L], device=x.device)
131
+ mask[:, :len_keep] = 0
132
+ # unshuffle to get the binary mask
133
+ mask = torch.gather(mask, dim=1, index=ids_restore)
134
+
135
+ return x_masked, mask, ids_restore
136
+
137
+ def forward_encoder(self, x, mask_ratio):
138
+ # embed patches
139
+ x = self.patch_embed(x)
140
+
141
+ # add pos embed w/o cls token
142
+ x = x + self.pos_embed
143
+
144
+ # masking: length -> length * mask_ratio
145
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
146
+
147
+ # apply Transformer blocks
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+ x = self.norm(x)
151
+
152
+ return x, mask, ids_restore
153
+
154
+ def forward_decoder(self, x, ids_restore):
155
+ # embed tokens
156
+ x = self.decoder_embed(x)
157
+
158
+ # append mask tokens to sequence
159
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
160
+ x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
161
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
162
+ x = x_ # no cls token to prepend in this variant
163
+
164
+ # add pos embed
165
+ x = x + self.decoder_pos_embed
166
+
167
+ # apply Transformer blocks
168
+ for blk in self.decoder_blocks:
169
+ x = blk(x)
170
+ x = self.decoder_norm(x)
171
+
172
+ # predictor projection
173
+ x = self.decoder_pred(x)
174
+
175
+ return x
176
+
177
+ def forward_loss(self, imgs, pred, mask):
178
+ """
179
+ imgs: [N, 3, H, W]
180
+ pred: [N, L, p*p*3]
181
+ mask: [N, L], 0 is keep, 1 is remove,
182
+ """
183
+ target = self.patchify(imgs)
184
+ if self.norm_pix_loss:
185
+ mean = target.mean(dim=-1, keepdim=True)
186
+ var = target.var(dim=-1, keepdim=True)
187
+ target = (target - mean) / (var + 1.e-6)**.5
188
+
189
+ loss = (pred - target) ** 2
190
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
191
+
192
+ # For mean loss on all patches
193
+ N, L = mask.shape
194
+ mask_s = torch.ones([N, L], device=imgs.device)
195
+ loss = (loss * mask_s).sum() / mask_s.sum()
196
+
197
+ #loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
198
+ return loss
199
+
200
+ def forward(self, imgs, mask_ratio=0.75):
201
+ latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
202
+ pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
203
+ loss = self.forward_loss(imgs, pred, mask)
204
+ return loss, pred, mask
205
+
206
+
207
+ def mae_vit_base_patch16_dec512d8b(**kwargs):
208
+ model = MaskedAutoencoderViTNoCT(
209
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
210
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
211
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
212
+ return model
213
+
214
+
215
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
216
+ model = MaskedAutoencoderViTNoCT(
217
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16,
218
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
219
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
220
+ return model
221
+
222
+
223
+ def mae_vit_huge_patch14_dec512d8b(**kwargs):
224
+ model = MaskedAutoencoderViTNoCT(
225
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16,
226
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
227
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
228
+ return model
229
+
230
+
231
+ # set recommended archs
232
+ mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
233
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
234
+ mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu116
2
+
3
+ torch==1.13.1+cu116
4
+ torchvision==0.14.1+cu116
5
+ timm==0.4.9
6
+ numpy==1.23.4
7
+ scipy==1.10.1
8
+ imgaug==0.4.0
9
+ pillow==9.3.0
10
+ matplotlib==3.6.3
11
+ hub==3.0.1
12
+ pandas==1.5.2
13
+ six==1.16.0
14
+ wandb
15
+ tqdm
util/FSC147.py ADDED
@@ -0,0 +1,524 @@
1
+ from argparse import Namespace
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import random
7
+ from torchvision import transforms
8
+ import torch
9
+ import cv2
10
+ import torchvision.transforms.functional as TF
11
+ import scipy.ndimage as ndimage
12
+ from PIL import Image
13
+ import argparse
14
+ import imgaug.augmenters as iaa
15
+ from imgaug.augmentables import Keypoint, KeypointsOnImage
16
+
17
+ MAX_HW = 384
18
+ IM_NORM_MEAN = [0.485, 0.456, 0.406]
19
+ IM_NORM_STD = [0.229, 0.224, 0.225]
20
+
21
+ def get_args_parser():
22
+ parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
23
+ parser.add_argument('--batch_size', default=8, type=int,
24
+ help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
25
+ parser.add_argument('--epochs', default=200, type=int)
26
+ parser.add_argument('--accum_iter', default=1, type=int,
27
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
28
+
29
+ # Model parameters
30
+ parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
31
+ help='Name of model to train')
32
+
33
+ parser.add_argument('--mask_ratio', default=0.5, type=float,
34
+ help='Masking ratio (percentage of removed patches).')
35
+
36
+ parser.add_argument('--norm_pix_loss', action='store_true',
37
+ help='Use (per-patch) normalized pixels as targets for computing loss')
38
+ parser.set_defaults(norm_pix_loss=False)
39
+
40
+ # Optimizer parameters
41
+ parser.add_argument('--weight_decay', type=float, default=0.05,
42
+ help='weight decay (default: 0.05)')
43
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
44
+ help='learning rate (absolute lr)')
45
+ parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
46
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
47
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
48
+ help='lower lr bound for cyclic schedulers that hit 0')
49
+ parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
50
+ help='epochs to warmup LR')
51
+
52
+ # Dataset parameters
53
+ parser.add_argument('--data_path', default='./data/FSC147/', type=str,
54
+ help='dataset path')
55
+ parser.add_argument('--anno_file', default='annotation_FSC147_384.json', type=str,
56
+ help='annotation json file')
57
+ parser.add_argument('--data_split_file', default='Train_Test_Val_FSC_147.json', type=str,
58
+ help='data split json file')
59
+ parser.add_argument('--im_dir', default='images_384_VarV2', type=str,
60
+ help='images directory')
61
+ parser.add_argument('--gt_dir', default='./data/FSC147/gt_density_map_adaptive_384_VarV2', type=str,
62
+ help='ground truth directory')
63
+ parser.add_argument('--output_dir', default='./data/out/pre_4_dir',
64
+ help='path where to save, empty for no saving')
65
+ parser.add_argument('--device', default='cuda',
66
+ help='device to use for training / testing')
67
+ parser.add_argument('--seed', default=0, type=int)
68
+ parser.add_argument('--resume', default='./weights/mae_pretrain_vit_base_full.pth', # mae_visualize_vit_base
69
+ help='resume from checkpoint')
70
+
71
+ # Training parameters
72
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
73
+ help='start epoch')
74
+ parser.add_argument('--num_workers', default=10, type=int)
75
+ parser.add_argument('--pin_mem', action='store_true',
76
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
77
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
78
+ parser.set_defaults(pin_mem=True)
79
+
80
+ # Distributed training parameters
81
+ parser.add_argument('--world_size', default=1, type=int,
82
+ help='number of distributed processes')
83
+ parser.add_argument('--local_rank', default=-1, type=int)
84
+ parser.add_argument('--dist_on_itp', action='store_true')
85
+ parser.add_argument('--dist_url', default='env://',
86
+ help='url used to set up distributed training')
87
+
88
+ # Logging parameters
89
+ parser.add_argument('--log_dir', default='./logs/pre_4_dir',
90
+ help='path where to tensorboard log')
91
+ parser.add_argument("--title", default="CounTR_pretraining", type=str)
92
+ parser.add_argument("--wandb", default="counting", type=str)
93
+ parser.add_argument("--team", default="wsense", type=str)
94
+ parser.add_argument("--wandb_id", default=None, type=str)
95
+ parser.add_argument("--do_aug", default=True, type=bool)
96
+ parser.add_argument('--class_file', default='./data/FSC147/ImageClasses_FSC147.txt', type=str,
97
+ help='class json file')
98
+ return parser
99
+
100
+ args = get_args_parser()
101
+ args = args.parse_args()
102
+
103
+ class ResizeSomeImage(object):
104
+ def __init__(self, args):
105
+ args = get_args_parser()
106
+ args = args.parse_args()
107
+ # print(dir(args.im_dir.as_posix()))
108
+ self.data_path = Path(args.data_path)
109
+ self.im_dir = self.data_path/args.im_dir
110
+ anno_file = self.data_path/args.anno_file
111
+ data_split_file = self.data_path/args.data_split_file
112
+
113
+ with open(anno_file) as f:
114
+ self.annotations = json.load(f)
115
+
116
+ with open(data_split_file) as f:
117
+ data_split = json.load(f)
118
+
119
+ self.train_set = data_split['train']
120
+
121
+ self.class_dict = {}
122
+ if args.do_aug:
123
+ with open(args.class_file) as f:
124
+ for line in f:
125
+ key = line.split()[0]
126
+ val = line.split()[1:]
127
+ self.class_dict[key] = val
128
+
129
+
130
+ class ResizePreTrainImage(ResizeSomeImage):
131
+ """
132
+ Resize the image so that:
133
+ 1. Image is equal to 384 * 384
134
+ 2. The new height and new width are divisible by 16
135
+ 3. The aspect ratio is preserved
136
+ Density and boxes correctness not preserved(crop and horizontal flip)
137
+ """
138
+
139
+ def __init__(self, args, MAX_HW=384):
140
+ super().__init__(args)
141
+ self.max_hw = MAX_HW
142
+
143
+ def __call__(self, sample):
144
+ image, lines_boxes, density = sample['image'], sample['lines_boxes'], sample['gt_density']
145
+
146
+ W, H = image.size
147
+
148
+ new_H = 16 * int(H / 16)
149
+ new_W = 16 * int(W / 16)
150
+ resized_image = transforms.Resize((new_H, new_W))(image)
151
+ resized_density = cv2.resize(density, (new_W, new_H))
152
+ orig_count = np.sum(density)
153
+ new_count = np.sum(resized_density)
154
+
155
+ if new_count > 0:
156
+ resized_density = resized_density * (orig_count / new_count)
157
+
158
+ boxes = list()
159
+ for box in lines_boxes:
160
+ box2 = [int(k) for k in box]
161
+ y1, x1, y2, x2 = box2[0], box2[1], box2[2], box2[3]
162
+ boxes.append([0, y1, x1, y2, x2])
163
+
164
+ boxes = torch.Tensor(boxes).unsqueeze(0)
165
+ resized_image = PreTrainNormalize(resized_image)
166
+ resized_density = torch.from_numpy(resized_density).unsqueeze(0).unsqueeze(0)
167
+ sample = {'image': resized_image, 'boxes': boxes, 'gt_density': resized_density}
168
+ return sample
169
+
170
+
171
+ class ResizeTrainImage(ResizeSomeImage):
172
+ """
173
+ Resize the image so that:
174
+ 1. Image is equal to 384 * 384
175
+ 2. The new height and new width are divisible by 16
176
+ 3. The aspect ratio is possibly preserved
177
+ Density map is cropped to have the same size (and position) as the cropped image
178
+ Exemplar boxes may be outside the cropped area.
179
+ Augmentation including Gaussian noise, Color jitter, Gaussian blur, Random affine, Random horizontal flip and Mosaic (or Random Crop if no Mosaic) is used.
180
+ """
181
+
182
+ def __init__(self, args, MAX_HW=384, do_aug=True):
183
+ super().__init__(args)
184
+ self.max_hw = MAX_HW
185
+ self.do_aug = do_aug
186
+
187
+ def __call__(self, sample):
188
+ image, lines_boxes, neg_lines_boxes, dots, im_id, m_flag = sample['image'], sample['lines_boxes'], sample['neg_lines_boxes'], \
189
+ sample['dots'], sample['id'], sample['m_flag']
190
+
191
+ W, H = image.size
192
+
193
+ new_H = 16 * int(H / 16)
194
+ new_W = 16 * int(W / 16)
195
+ scale_factor_h = float(new_H) / H
196
+ scale_factor_w = float(new_W) / W
197
+ resized_image = transforms.Resize((new_H, new_W))(image)
198
+ resized_image = TTensor(resized_image)
199
+ resized_density = np.zeros((new_H, new_W), dtype='float32')
200
+
201
+ # Augmentation probability
202
+ aug_flag = self.do_aug
203
+ mosaic_flag = random.random() < 0.25
204
+
205
+ if aug_flag:
206
+ # Gaussian noise
207
+ noise = np.random.normal(0, 0.1, resized_image.size())
208
+ noise = torch.from_numpy(noise)
209
+ re_image = resized_image + noise
210
+ re_image = torch.clamp(re_image, 0, 1)
211
+
212
+ # Color jitter and Gaussian blur
213
+ re_image = Augmentation(re_image)
214
+
215
+ # Random affine
216
+ re1_image = re_image.transpose(0, 1).transpose(1, 2).numpy()
217
+ keypoints = []
218
+ for i in range(dots.shape[0]):
219
+ keypoints.append(Keypoint(x=min(new_W - 1, int(dots[i][0] * scale_factor_w)), y=min(new_H - 1, int(dots[i][1] * scale_factor_h))))
220
+ kps = KeypointsOnImage(keypoints, re1_image.shape)
221
+
222
+ seq = iaa.Sequential([
223
+ iaa.Affine(
224
+ rotate=(-15, 15),
225
+ scale=(0.8, 1.2),
226
+ shear=(-10, 10),
227
+ translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}
228
+ )
229
+ ])
230
+ re1_image, kps_aug = seq(image=re1_image, keypoints=kps)
231
+
232
+ # Produce dot annotation map
233
+ resized_density = np.zeros((resized_density.shape[0], resized_density.shape[1]), dtype='float32')
234
+ for i in range(len(kps.keypoints)):
235
+ if (int(kps_aug.keypoints[i].y) <= new_H - 1 and int(kps_aug.keypoints[i].x) <= new_W - 1) and not \
236
+ kps_aug.keypoints[i].is_out_of_image(re1_image):
237
+ resized_density[int(kps_aug.keypoints[i].y)][int(kps_aug.keypoints[i].x)] = 1
238
+ resized_density = torch.from_numpy(resized_density)
239
+
240
+ re_image = TTensor(re1_image)
241
+
242
+ # Random horizontal flip
243
+ flip_p = random.random()
244
+ if flip_p > 0.5:
245
+ re_image = TF.hflip(re_image)
246
+ resized_density = TF.hflip(resized_density)
247
+
248
+ # Random self mosaic
249
+ if mosaic_flag:
250
+ image_array = []
251
+ map_array = []
252
+ blending_l = random.randint(10, 20)
253
+ resize_l = 192 + 2 * blending_l
254
+ if dots.shape[0] >= 70:
255
+ for i in range(4):
256
+ length = random.randint(150, 384)
257
+ start_W = random.randint(0, new_W - length)
258
+ start_H = random.randint(0, new_H - length)
259
+ reresized_image1 = TF.crop(resized_image, start_H, start_W, length, length)
260
+ reresized_image1 = transforms.Resize((resize_l, resize_l))(reresized_image1)
261
+ reresized_density1 = np.zeros((resize_l, resize_l), dtype='float32')
262
+ for i in range(dots.shape[0]):
263
+ if start_H <= min(new_H - 1, int(dots[i][1] * scale_factor_h)) < start_H + length and start_W <= min(new_W - 1, int(dots[i][0] * scale_factor_w)) < start_W + length:
264
+ reresized_density1[min(resize_l-1,int((min(new_H-1,int(dots[i][1] * scale_factor_h))-start_H)*resize_l/length))][min(resize_l-1,int((min(new_W-1,int(dots[i][0] * scale_factor_w))-start_W)*resize_l/length))]=1
265
+ reresized_density1 = torch.from_numpy(reresized_density1)
266
+ image_array.append(reresized_image1)
267
+ map_array.append(reresized_density1)
268
+ else:
269
+ m_flag = 1
270
+ prob = random.random()
271
+ if prob > 0.25:
272
+ gt_pos = random.randint(0, 3)
273
+ else:
274
+ gt_pos = random.randint(0, 4) # 5% 0 objects
275
+ for i in range(4):
276
+ if i == gt_pos:
277
+ Tim_id = im_id
278
+ r_image = resized_image
279
+ Tdots = dots
280
+ new_TH = new_H
281
+ new_TW = new_W
282
+ Tscale_factor_w = scale_factor_w
283
+ Tscale_factor_h = scale_factor_h
284
+ else:
285
+ Tim_id = self.train_set[random.randint(0, len(self.train_set) - 1)]
286
+ Tdots = np.array(self.annotations[Tim_id]['points'])
287
+ Timage = Image.open('{}/{}'.format(self.im_dir, Tim_id))
288
+ Timage.load()
289
+ new_TW = 16 * int(Timage.size[0] / 16)
290
+ new_TH = 16 * int(Timage.size[1] / 16)
291
+ Tscale_factor_w = float(new_TW) / Timage.size[0]
292
+ Tscale_factor_h = float(new_TH) / Timage.size[1]
293
+ r_image = TTensor(transforms.Resize((new_TH, new_TW))(Timage))
294
+
295
+ length = random.randint(250, 384)
296
+ start_W = random.randint(0, new_TW - length)
297
+ start_H = random.randint(0, new_TH - length)
298
+ r_image1 = TF.crop(r_image, start_H, start_W, length, length)
299
+ r_image1 = transforms.Resize((resize_l, resize_l))(r_image1)
300
+ r_density1 = np.zeros((resize_l, resize_l), dtype='float32')
301
+ # only add the donor image's dots to the density map when its class matches
302
+ # the current image's class; otherwise the tile acts as a zero-count distractor
303
+ if self.class_dict[im_id] == self.class_dict[Tim_id]:
317
+ for i in range(Tdots.shape[0]):
318
+ if start_H <= min(new_TH - 1, int(Tdots[i][1] * Tscale_factor_h)) < start_H + length and start_W <= min(new_TW - 1, int(Tdots[i][0] * Tscale_factor_w)) < start_W + length:
319
+ r_density1[min(resize_l-1,int((min(new_TH-1, int(Tdots[i][1] * Tscale_factor_h))-start_H)*resize_l/length))][min(resize_l-1,int((min(new_TW-1,int(Tdots[i][0] * Tscale_factor_w))-start_W)*resize_l/length))]=1
320
+ r_density1 = torch.from_numpy(r_density1)
321
+ image_array.append(r_image1)
322
+ map_array.append(r_density1)
323
+
324
+ reresized_image5 = torch.cat((image_array[0][:, blending_l:resize_l-blending_l], image_array[1][:, blending_l: resize_l-blending_l]), 1)
325
+ reresized_density5 = torch.cat((map_array[0][blending_l:resize_l-blending_l], map_array[1][blending_l: resize_l-blending_l]), 0)
326
+ for i in range(blending_l):
327
+ reresized_image5[:, 192+i] = image_array[0][:, resize_l-1-blending_l+i] * (blending_l-i)/(2 * blending_l) + reresized_image5[:, 192+i] * (i+blending_l)/(2*blending_l)
328
+ reresized_image5[:, 191-i] = image_array[1][:, blending_l-i] * (blending_l-i)/(2*blending_l) + reresized_image5[:, 191-i] * (i+blending_l)/(2*blending_l)
329
+ reresized_image5 = torch.clamp(reresized_image5, 0, 1)
330
+
331
+ reresized_image6 = torch.cat((image_array[2][:, blending_l:resize_l-blending_l], image_array[3][:, blending_l: resize_l-blending_l]), 1)
332
+ reresized_density6 = torch.cat((map_array[2][blending_l:resize_l-blending_l], map_array[3][blending_l:resize_l-blending_l]), 0)
333
+ for i in range(blending_l):
334
+ reresized_image6[:, 192+i] = image_array[2][:, resize_l-1-blending_l+i] * (blending_l-i)/(2*blending_l) + reresized_image6[:, 192+i] * (i+blending_l)/(2*blending_l)
335
+ reresized_image6[:, 191-i] = image_array[3][:, blending_l-i] * (blending_l-i)/(2*blending_l) + reresized_image6[:, 191-i] * (i+blending_l)/(2*blending_l)
336
+ reresized_image6 = torch.clamp(reresized_image6, 0, 1)
337
+
338
+ reresized_image = torch.cat((reresized_image5[:, :, blending_l:resize_l-blending_l], reresized_image6[:, :, blending_l:resize_l-blending_l]), 2)
339
+ reresized_density = torch.cat((reresized_density5[:, blending_l:resize_l-blending_l], reresized_density6[:, blending_l:resize_l-blending_l]), 1)
340
+ for i in range(blending_l):
341
+ reresized_image[:, :, 192+i] = reresized_image5[:, :, resize_l-1-blending_l+i] * (blending_l-i)/(2*blending_l) + reresized_image[:, :, 192+i] * (i+blending_l)/(2*blending_l)
342
+ reresized_image[:, :, 191-i] = reresized_image6[:, :, blending_l-i] * (blending_l-i)/(2*blending_l) + reresized_image[:, :, 191-i] * (i+blending_l)/(2*blending_l)
343
+ reresized_image = torch.clamp(reresized_image, 0, 1)
344
+
345
+ else:
346
+ # Random 384*384 crop in a new_W*384 image and 384*new_W density map
347
+ start = random.randint(0, new_W - 1 - 383)
348
+ reresized_image = TF.crop(re_image, 0, start, 384, 384)
349
+ reresized_density = resized_density[:, start:start + 384]
350
+
351
+ else:
352
+ # Random 384*384 crop in a new_W*384 image and 384*new_W density map
353
+ for i in range(dots.shape[0]):
354
+ resized_density[min(new_H - 1, int(dots[i][1] * scale_factor_h))] \
355
+ [min(new_W - 1, int(dots[i][0] * scale_factor_w))] = 1
356
+ resized_density = torch.from_numpy(resized_density)
357
+ start = random.randint(0, new_W - self.max_hw)
358
+ reresized_image = TF.crop(resized_image, 0, start, self.max_hw, self.max_hw)
359
+ reresized_density = resized_density[0:self.max_hw, start:start + self.max_hw]
360
+
361
+ # Gaussian distribution density map
362
+ reresized_density = ndimage.gaussian_filter(reresized_density.numpy(), sigma=(1, 1), order=0)
363
+
364
+ # Density map scale up
365
+ reresized_density = reresized_density * 60
366
+ reresized_density = torch.from_numpy(reresized_density)
367
+
368
+ # Crop bboxes and resize as 64x64
369
+ boxes = list()
370
+ rects = list()
371
+ cnt = 0
372
+ for box in lines_boxes:
373
+ cnt += 1
374
+ if cnt > 3:
375
+ break
376
+ box2 = [int(k) for k in box]
377
+ y1 = int(box2[0] * scale_factor_h)
378
+ x1 = int(box2[1] * scale_factor_w)
379
+ y2 = int(box2[2] * scale_factor_h)
380
+ x2 = int(box2[3] * scale_factor_w)
381
+ # print(y1,x1,y2,x2)
382
+ if not aug_flag:
383
+ rects.append(torch.tensor([y1, max(0, x1-start), y2, min(self.max_hw, x2-start)]))
384
+ bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
385
+ bbox = transforms.Resize((64, 64))(bbox)
386
+ boxes.append(bbox)
387
+ boxes = torch.stack(boxes)
388
+ neg_boxes = list()
389
+ neg_rects = list()
390
+ cnt = 0
391
+ for box in neg_lines_boxes:
392
+ cnt += 1
393
+ if cnt > 3:
394
+ break
395
+ box2 = [int(k) for k in box]
396
+ y1 = int(box2[0] * scale_factor_h)
397
+ x1 = int(box2[1] * scale_factor_w)
398
+ y2 = int(box2[2] * scale_factor_h)
399
+ x2 = int(box2[3] * scale_factor_w)
400
+ # print(y1,x1,y2,x2)
401
+ if not aug_flag:
402
+ neg_rects.append(torch.tensor([y1, max(0, x1-start), y2, min(self.max_hw, x2-start)]))
403
+ neg_bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
404
+ neg_bbox = transforms.Resize((64, 64))(neg_bbox)
405
+ neg_boxes.append(neg_bbox)
406
+ neg_boxes = torch.stack(neg_boxes)
407
+ # if len(boxes) > 0:
408
+ # boxes = torch.stack(boxes) # if boxes is non-empty, torch.stack works as usual
409
+ # boxes1 = boxes
410
+ # else:
411
+ # boxes = boxes1
412
+ # pass
413
+ # # if boxes is empty, either skip this sample or fall back to a default box,
414
+ # # e.g. a default box that covers the whole image
415
+ # default_box = torch.tensor([[0, 0],[0, 0],0, 0]) # an example default box; the exact values depend on the application
416
+ # boxes = default_box.unsqueeze(0) # add a dimension so it matches what torch.stack expects
417
+ # # pass
418
+ if aug_flag:
419
+ pos = torch.tensor([])
420
+ else:
421
+ pos = torch.stack(rects)
422
+
423
+ # boxes shape [3,3,64,64], image shape [3,384,384], density shape[384,384]
424
+ sample = {'image': reresized_image, 'boxes': boxes, 'neg_boxes': neg_boxes, 'pos': pos, 'gt_density': reresized_density, 'm_flag': m_flag}
425
+
426
+ return sample
427
+
428
+
429
+ class ResizeValImage(ResizeSomeImage):
430
+ def __init__(self, args, MAX_HW=384):
431
+ super().__init__(args)
432
+ self.max_hw = MAX_HW
433
+
434
+ def __call__(self, sample):
435
+ image, dots, m_flag, lines_boxes, neg_lines_boxes = sample['image'], sample['dots'], sample['m_flag'], sample['lines_boxes'], sample['neg_lines_boxes']
436
+
437
+ W, H = image.size
438
+
439
+ new_H = new_W = self.max_hw
440
+ scale_factor_h = float(new_H) / H
441
+ scale_factor_w = float(new_W) / W
442
+ resized_image = transforms.Resize((new_H, new_W))(image)
443
+ resized_image = TTensor(resized_image)
444
+
445
+ # Resize density map
446
+ resized_density = np.zeros((new_H, new_W), dtype='float32')
447
+ for i in range(dots.shape[0]):
448
+ resized_density[min(new_H - 1, int(dots[i][1] * scale_factor_h))] \
449
+ [min(new_W - 1, int(dots[i][0] * scale_factor_w))] = 1
450
+ # resized_density = ndimage.gaussian_filter(resized_density, sigma=4, radius=7, order=0)
451
+ resized_density = ndimage.gaussian_filter(resized_density, sigma=4, order=0)
452
+ resized_density = torch.from_numpy(resized_density) * 60
453
+
454
+ # Crop bboxes and resize as 64x64
455
+ boxes = list()
456
+ rects = list()
457
+ cnt = 0
458
+ for box in lines_boxes:
459
+ cnt += 1
460
+ if cnt > 3:
461
+ break
462
+ box2 = [int(k) for k in box]
463
+ y1 = int(box2[0] * scale_factor_h)
464
+ x1 = int(box2[1] * scale_factor_w)
465
+ y2 = int(box2[2] * scale_factor_h)
466
+ x2 = int(box2[3] * scale_factor_w)
467
+ rects.append(torch.tensor([y1, x1, y2, x2]))
468
+ bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
469
+ bbox = transforms.Resize((64, 64))(bbox)
470
+ boxes.append(bbox)
471
+ boxes = torch.stack(boxes)
472
+ pos = torch.stack(rects)
473
+ neg_boxes = list()
474
+ neg_rects = list()
475
+ cnt = 0
476
+ for box in neg_lines_boxes:
477
+ cnt += 1
478
+ if cnt > 3:
479
+ break
480
+ box2 = [int(k) for k in box]
481
+ y1 = int(box2[0] * scale_factor_h)
482
+ x1 = int(box2[1] * scale_factor_w)
483
+ y2 = int(box2[2] * scale_factor_h)
484
+ x2 = int(box2[3] * scale_factor_w)
485
+ neg_rects.append(torch.tensor([y1, x1, y2, x2]))
486
+ neg_bbox = resized_image[:, y1:y2 + 1, x1:x2 + 1]
487
+ neg_bbox = transforms.Resize((64, 64))(neg_bbox)
488
+ neg_boxes.append(neg_bbox)
489
+ neg_boxes = torch.stack(neg_boxes)
490
+ # boxes shape [3,3,64,64], image shape [3,384,384], density shape[384,384]
491
+ sample = {'image': resized_image, 'boxes': boxes, 'neg_boxes': neg_boxes, 'pos': pos, 'gt_density': resized_density, 'm_flag': m_flag}
492
+ return sample
493
+
494
+
495
+ PreTrainNormalize = transforms.Compose([
496
+ transforms.RandomResizedCrop(MAX_HW, scale=(0.2, 1.0), interpolation=3),
497
+ transforms.RandomHorizontalFlip(),
498
+ transforms.ToTensor(),
499
+ # transforms.Normalize(mean=IM_NORM_MEAN, std=IM_NORM_STD)
500
+ ])
501
+
502
+ TTensor = transforms.Compose([
503
+ transforms.ToTensor(),
504
+ ])
505
+
506
+ Augmentation = transforms.Compose([
507
+ transforms.ColorJitter(brightness=0.25, contrast=0.15, saturation=0.15, hue=0.15),
508
+ transforms.GaussianBlur(kernel_size=(7, 9))
509
+ ])
510
+
511
+ Normalize = transforms.Compose([
512
+ transforms.ToTensor(),
513
+ transforms.Normalize(mean=IM_NORM_MEAN, std=IM_NORM_STD)
514
+ ])
515
+
516
+
517
+ def transform_train(args: Namespace, do_aug=True):
518
+ return transforms.Compose([ResizeTrainImage(args, MAX_HW, do_aug)])
519
+
520
+ def transform_val(args: Namespace):
521
+ return transforms.Compose([ResizeValImage(args, MAX_HW)])
522
+
523
+ def transform_pre_train(args: Namespace):
524
+ return transforms.Compose([ResizePreTrainImage(args, MAX_HW)])
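The transform factories above are consumed by a dataset that builds the sample dict ResizeTrainImage expects ('image', 'lines_boxes', 'neg_lines_boxes', 'dots', 'id', 'm_flag'). A minimal usage sketch follows; the TrainData class, the negative-box annotation key and the exact args fields are illustrative assumptions, not part of this commit.

import json
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class TrainData(Dataset):  # hypothetical wrapper around the FSC147 annotations
    def __init__(self, args, transform):
        self.data_path = Path(args.data_path)
        self.im_dir = self.data_path / args.im_dir
        with open(self.data_path / args.anno_file) as f:
            self.annotations = json.load(f)
        with open(self.data_path / args.data_split_file) as f:
            self.ids = json.load(f)['train']
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        im_id = self.ids[idx]
        anno = self.annotations[im_id]
        image = Image.open(self.im_dir / im_id)
        image.load()
        # convert the 4-corner exemplar boxes into the [y1, x1, y2, x2] lists
        # that ResizeTrainImage slices above
        boxes = [[b[0][1], b[0][0], b[2][1], b[2][0]]
                 for b in anno['box_examples_coordinates']]
        neg_boxes = [[b[0][1], b[0][0], b[2][1], b[2][0]]
                     for b in anno.get('negative_box_examples_coordinates', [])]  # assumed key name
        sample = {'image': image, 'lines_boxes': boxes, 'neg_lines_boxes': neg_boxes,
                  'dots': np.array(anno['points']), 'id': im_id, 'm_flag': 0}
        return self.transform(sample)

# e.g. dataset_train = TrainData(args, transform_train(args, do_aug=args.do_aug))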
util/__pycache__/FSC147.cpython-38.pyc ADDED
Binary file (15.3 kB). View file
 
util/__pycache__/FSC147.cpython-39.pyc ADDED
Binary file (14.4 kB). View file
 
util/__pycache__/FSC147_test.cpython-38.pyc ADDED
Binary file (16.6 kB). View file
 
util/__pycache__/lr_sched.cpython-38.pyc ADDED
Binary file (628 Bytes). View file
 
util/__pycache__/lr_sched.cpython-39.pyc ADDED
Binary file (628 Bytes). View file
 
util/__pycache__/misc.cpython-38.pyc ADDED
Binary file (19.5 kB). View file
 
util/__pycache__/misc.cpython-39.pyc ADDED
Binary file (19.4 kB). View file
 
util/__pycache__/pos_embed.cpython-38.pyc ADDED
Binary file (2.41 kB). View file
 
util/__pycache__/pos_embed.cpython-39.pyc ADDED
Binary file (2.39 kB). View file
 
util/crop.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms import functional as F
13
+
14
+
15
+ class RandomResizedCrop(transforms.RandomResizedCrop):
16
+ """
17
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18
+ This may lead to results different with torchvision's version.
19
+ Following BYOL's TF code:
20
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21
+ """
22
+ @staticmethod
23
+ def get_params(img, scale, ratio):
24
+ width, height = F._get_image_size(img)
25
+ area = height * width
26
+
27
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
28
+ log_ratio = torch.log(torch.tensor(ratio))
29
+ aspect_ratio = torch.exp(
30
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
31
+ ).item()
32
+
33
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
34
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
35
+
36
+ w = min(w, width)
37
+ h = min(h, height)
38
+
39
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
40
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
41
+
42
+ return i, j, h, w
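Unlike torchvision's RandomResizedCrop, this variant samples the crop area and aspect ratio once (no retry loop), matching the TF/TPU behaviour referenced in the docstring. A small sketch of dropping it into a pre-training pipeline, with illustrative sizes:

from torchvision import transforms
from util.crop import RandomResizedCrop

pretrain_aug = transforms.Compose([
    RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=3),  # 3 = bicubic
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])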
util/datasets.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # --------------------------------------------------------
10
+
11
+ import os
12
+ import PIL
13
+
14
+ from torchvision import datasets, transforms
15
+
16
+ from timm.data import create_transform
17
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18
+
19
+
20
+ def build_dataset(is_train, args):
21
+ transform = build_transform(is_train, args)
22
+
23
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
24
+ dataset = datasets.ImageFolder(root, transform=transform)
25
+
26
+ print(dataset)
27
+
28
+ return dataset
29
+
30
+
31
+ def build_transform(is_train, args):
32
+ mean = IMAGENET_DEFAULT_MEAN
33
+ std = IMAGENET_DEFAULT_STD
34
+ # train transform
35
+ if is_train:
36
+ # this should always dispatch to transforms_imagenet_train
37
+ transform = create_transform(
38
+ input_size=args.input_size,
39
+ is_training=True,
40
+ color_jitter=args.color_jitter,
41
+ auto_augment=args.aa,
42
+ interpolation='bicubic',
43
+ re_prob=args.reprob,
44
+ re_mode=args.remode,
45
+ re_count=args.recount,
46
+ mean=mean,
47
+ std=std,
48
+ )
49
+ return transform
50
+
51
+ # eval transform
52
+ t = []
53
+ if args.input_size <= 224:
54
+ crop_pct = 224 / 256
55
+ else:
56
+ crop_pct = 1.0
57
+ size = int(args.input_size / crop_pct)
58
+ t.append(
59
+ transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
60
+ )
61
+ t.append(transforms.CenterCrop(args.input_size))
62
+
63
+ t.append(transforms.ToTensor())
64
+ t.append(transforms.Normalize(mean, std))
65
+ return transforms.Compose(t)
util/lars.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # LARS optimizer, implementation from MoCo v3:
8
+ # https://github.com/facebookresearch/moco-v3
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
+ class LARS(torch.optim.Optimizer):
15
+ """
16
+ LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17
+ """
18
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
19
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
20
+ super().__init__(params, defaults)
21
+
22
+ @torch.no_grad()
23
+ def step(self):
24
+ for g in self.param_groups:
25
+ for p in g['params']:
26
+ dp = p.grad
27
+
28
+ if dp is None:
29
+ continue
30
+
31
+ if p.ndim > 1: # if not normalization gamma/beta or bias
32
+ dp = dp.add(p, alpha=g['weight_decay'])
33
+ param_norm = torch.norm(p)
34
+ update_norm = torch.norm(dp)
35
+ one = torch.ones_like(param_norm)
36
+ q = torch.where(param_norm > 0.,
37
+ torch.where(update_norm > 0,
38
+ (g['trust_coefficient'] * param_norm / update_norm), one),
39
+ one)
40
+ dp = dp.mul(q)
41
+
42
+ param_state = self.state[p]
43
+ if 'mu' not in param_state:
44
+ param_state['mu'] = torch.zeros_like(p)
45
+ mu = param_state['mu']
46
+ mu.mul_(g['momentum']).add_(dp)
47
+ p.add_(mu, alpha=-g['lr'])
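LARS rescales each tensor's update by the trust ratio ||w|| / ||g|| and skips 1-D parameters (biases, norm weights). A minimal sketch of how it might be used, e.g. for large-batch linear probing; the model and hyper-parameters are illustrative:

import torch
import torch.nn.functional as F
from util.lars import LARS

head = torch.nn.Linear(768, 1000)
optimizer = LARS(head.parameters(), lr=0.1, weight_decay=0.0, momentum=0.9)

features = torch.randn(8, 768)             # stand-in for frozen backbone features
labels = torch.randint(0, 1000, (8,))
loss = F.cross_entropy(head(features), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()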
util/lr_decay.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # ELECTRA https://github.com/google-research/electra
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import json
13
+
14
+
15
+ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16
+ """
17
+ Parameter groups for layer-wise lr decay
18
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19
+ """
20
+ param_group_names = {}
21
+ param_groups = {}
22
+
23
+ num_layers = len(model.blocks) + 1
24
+
25
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26
+
27
+ for n, p in model.named_parameters():
28
+ if not p.requires_grad:
29
+ continue
30
+
31
+ # no decay: all 1D parameters and model specific ones
32
+ if p.ndim == 1 or n in no_weight_decay_list:
33
+ g_decay = "no_decay"
34
+ this_decay = 0.
35
+ else:
36
+ g_decay = "decay"
37
+ this_decay = weight_decay
38
+
39
+ layer_id = get_layer_id_for_vit(n, num_layers)
40
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
41
+
42
+ if group_name not in param_group_names:
43
+ this_scale = layer_scales[layer_id]
44
+
45
+ param_group_names[group_name] = {
46
+ "lr_scale": this_scale,
47
+ "weight_decay": this_decay,
48
+ "params": [],
49
+ }
50
+ param_groups[group_name] = {
51
+ "lr_scale": this_scale,
52
+ "weight_decay": this_decay,
53
+ "params": [],
54
+ }
55
+
56
+ param_group_names[group_name]["params"].append(n)
57
+ param_groups[group_name]["params"].append(p)
58
+
59
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60
+
61
+ return list(param_groups.values())
62
+
63
+
64
+ def get_layer_id_for_vit(name, num_layers):
65
+ """
66
+ Assign a parameter with its layer id
67
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68
+ """
69
+ if name in ['cls_token', 'pos_embed']:
70
+ return 0
71
+ elif name.startswith('patch_embed'):
72
+ return 0
73
+ elif name.startswith('blocks'):
74
+ return int(name.split('.')[1]) + 1
75
+ else:
76
+ return num_layers
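param_groups_lrd groups parameters so that deeper ViT blocks receive larger learning rates than earlier ones (scale = layer_decay ** (num_layers - layer_id)); the per-group 'lr_scale' is later applied by adjust_learning_rate in util/lr_sched.py. A sketch of wiring it into AdamW for a timm ViT; the model name and hyper-parameters are illustrative:

import timm
import torch
from util.lr_decay import param_groups_lrd

model = timm.create_model('vit_base_patch16_224', pretrained=False)
param_groups = param_groups_lrd(model, weight_decay=0.05,
                                no_weight_decay_list=model.no_weight_decay(),
                                layer_decay=0.75)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)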
util/lr_sched.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ def adjust_learning_rate(optimizer, epoch, args):
10
+ """Decay the learning rate with half-cycle cosine after warmup"""
11
+ if epoch < args.warmup_epochs:
12
+ lr = args.lr * epoch / args.warmup_epochs
13
+ else:
14
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16
+ for param_group in optimizer.param_groups:
17
+ if "lr_scale" in param_group:
18
+ param_group["lr"] = lr * param_group["lr_scale"]
19
+ else:
20
+ param_group["lr"] = lr
21
+ return lr
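The schedule ramps the learning rate linearly for warmup_epochs, then follows a half-cycle cosine down towards min_lr; epoch may be fractional, so it can be called every iteration. A quick sketch of the values it produces (the lr, min_lr and epoch settings are illustrative):

import torch
from types import SimpleNamespace
from util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-3, min_lr=0.0, warmup_epochs=10, epochs=200)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for epoch in (0, 5, 10, 105, 199):
    print(epoch, adjust_learning_rate(opt, epoch, args))
# 0 -> 0.0 and 5 -> 5e-4 during warmup, 10 -> 1e-3 at the cosine peak,
# 105 -> 5e-4 halfway through, 199 -> close to min_lr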
util/misc.py ADDED
@@ -0,0 +1,624 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import builtins
13
+ import datetime
14
+ import os
15
+ import time
16
+ import json
17
+ from collections import defaultdict, deque
18
+ from pathlib import Path
19
+ # from typing import Union
20
+
21
+ import pandas as pd
22
+ import torch
23
+ import torch.distributed as dist
24
+ import wandb
25
+ # from torch._six import inf
26
+ from torch import inf
27
+ import matplotlib.pyplot as plt
28
+ from torchvision import transforms
29
+ import cv2
30
+ from tqdm import tqdm
31
+ from typing import Union, List
32
+
33
+ class SmoothedValue(object):
34
+ """Track a series of values and provide access to smoothed values over a
35
+ window or the global series average.
36
+ """
37
+
38
+ def __init__(self, window_size=20, fmt=None):
39
+ if fmt is None:
40
+ fmt = "{median:.4f} ({global_avg:.4f})"
41
+ self.deque = deque(maxlen=window_size)
42
+ self.total = 0.0
43
+ self.count = 0
44
+ self.fmt = fmt
45
+
46
+ def update(self, value, n=1):
47
+ self.deque.append(value)
48
+ self.count += n
49
+ self.total += value * n
50
+
51
+ def synchronize_between_processes(self):
52
+ """
53
+ Warning: does not synchronize the deque!
54
+ """
55
+ if not is_dist_avail_and_initialized():
56
+ return
57
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
58
+ dist.barrier()
59
+ dist.all_reduce(t)
60
+ t = t.tolist()
61
+ self.count = int(t[0])
62
+ self.total = t[1]
63
+
64
+ @property
65
+ def median(self):
66
+ d = torch.tensor(list(self.deque))
67
+ return d.median().item()
68
+
69
+ @property
70
+ def avg(self):
71
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
72
+ return d.mean().item()
73
+
74
+ @property
75
+ def global_avg(self):
76
+ if self.count == 0:
77
+ # Return a default value or handle the zero count scenario
78
+ return 0 # Or any other default value or handling mechanism
79
+ else:
80
+ return self.total / self.count
81
+ # return self.total / self.count
82
+
83
+ @property
84
+ def max(self):
85
+ return max(self.deque)
86
+
87
+ @property
88
+ def value(self):
89
+ return self.deque[-1]
90
+
91
+ def __str__(self):
92
+ return self.fmt.format(
93
+ median=self.median,
94
+ avg=self.avg,
95
+ global_avg=self.global_avg,
96
+ max=self.max,
97
+ value=self.value)
98
+
99
+
100
+ class MetricLogger(object):
101
+ def __init__(self, delimiter="\t"):
102
+ self.meters = defaultdict(SmoothedValue)
103
+ self.delimiter = delimiter
104
+
105
+ def update(self, **kwargs):
106
+ for k, v in kwargs.items():
107
+ if v is None:
108
+ continue
109
+ if isinstance(v, torch.Tensor):
110
+ v = v.item()
111
+ assert isinstance(v, (float, int))
112
+ self.meters[k].update(v)
113
+
114
+ def __getattr__(self, attr):
115
+ if attr in self.meters:
116
+ return self.meters[attr]
117
+ if attr in self.__dict__:
118
+ return self.__dict__[attr]
119
+ raise AttributeError("'{}' object has no attribute '{}'".format(
120
+ type(self).__name__, attr))
121
+
122
+ def __str__(self):
123
+ loss_str = []
124
+ for name, meter in self.meters.items():
125
+ loss_str.append(
126
+ "{}: {}".format(name, str(meter))
127
+ )
128
+ return self.delimiter.join(loss_str)
129
+
130
+ def synchronize_between_processes(self):
131
+ for meter in self.meters.values():
132
+ meter.synchronize_between_processes()
133
+
134
+ def add_meter(self, name, meter):
135
+ self.meters[name] = meter
136
+
137
+ def log_every(self, iterable, print_freq, header=None):
138
+ i = 0
139
+ if not header:
140
+ header = ''
141
+ start_time = time.time()
142
+ end = time.time()
143
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
144
+ data_time = SmoothedValue(fmt='{avg:.4f}')
145
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
146
+ log_msg = [
147
+ header,
148
+ '[{0' + space_fmt + '}/{1}]',
149
+ 'eta: {eta}',
150
+ '{meters}',
151
+ 'time: {time}',
152
+ 'data: {data}'
153
+ ]
154
+ if torch.cuda.is_available():
155
+ log_msg.append('max mem: {memory:.0f}')
156
+ log_msg = self.delimiter.join(log_msg)
157
+ MB = 1024.0 * 1024.0
158
+ for obj in iterable:
159
+ data_time.update(time.time() - end)
160
+ yield obj
161
+ iter_time.update(time.time() - end)
162
+ if i % print_freq == 0 or i == len(iterable) - 1:
163
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
164
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
165
+ if torch.cuda.is_available():
166
+ print(log_msg.format(
167
+ i, len(iterable), eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time), data=str(data_time),
170
+ memory=torch.cuda.max_memory_allocated() / MB))
171
+ else:
172
+ print(log_msg.format(
173
+ i, len(iterable), eta=eta_string,
174
+ meters=str(self),
175
+ time=str(iter_time), data=str(data_time)))
176
+ i += 1
177
+ end = time.time()
178
+ total_time = time.time() - start_time
179
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
180
+ print('{} Total time: {} ({:.4f} s / it)'.format(
181
+ header, total_time_str, total_time / len(iterable)))
182
+
183
+
184
+ def setup_for_distributed(is_master):
185
+ """
186
+ This function disables printing when not in master process
187
+ """
188
+ builtin_print = builtins.print
189
+
190
+ def print(*args, **kwargs):
191
+ force = kwargs.pop('force', False)
192
+ force = force or (get_world_size() > 8)
193
+ if is_master or force:
194
+ now = datetime.datetime.now().time()
195
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
196
+ builtin_print(*args, **kwargs)
197
+
198
+ builtins.print = print
199
+
200
+
201
+ def is_dist_avail_and_initialized():
202
+ if not dist.is_available():
203
+ return False
204
+ if not dist.is_initialized():
205
+ return False
206
+ return True
207
+
208
+
209
+ def get_world_size():
210
+ if not is_dist_avail_and_initialized():
211
+ return 1
212
+ return dist.get_world_size()
213
+
214
+
215
+ def get_rank():
216
+ if not is_dist_avail_and_initialized():
217
+ return 0
218
+ return dist.get_rank()
219
+
220
+
221
+ def is_main_process():
222
+ return get_rank() == 0
223
+
224
+
225
+ def save_on_master(*args, **kwargs):
226
+ if is_main_process():
227
+ torch.save(*args, **kwargs)
228
+
229
+
230
+ def init_distributed_mode(args):
231
+ if args.dist_on_itp:
232
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
233
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
234
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
235
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
236
+ os.environ['LOCAL_RANK'] = str(args.gpu)
237
+ os.environ['RANK'] = str(args.rank)
238
+ os.environ['WORLD_SIZE'] = str(args.world_size)
239
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
240
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
241
+ args.rank = int(os.environ["RANK"])
242
+ args.world_size = int(os.environ['WORLD_SIZE'])
243
+ args.gpu = int(os.environ['LOCAL_RANK'])
244
+ elif 'SLURM_PROCID' in os.environ:
245
+ args.rank = int(os.environ['SLURM_PROCID'])
246
+ args.gpu = args.rank % torch.cuda.device_count()
247
+ else:
248
+ print('Not using distributed mode')
249
+ setup_for_distributed(is_master=True) # hack
250
+ args.distributed = False
251
+ return
252
+
253
+ args.distributed = True
254
+
255
+ torch.cuda.set_device(args.gpu)
256
+ args.dist_backend = 'nccl'
257
+ print('| distributed init (rank {}): {}, gpu {}'.format(
258
+ args.rank, args.dist_url, args.gpu), flush=True)
259
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
260
+ world_size=args.world_size, rank=args.rank)
261
+ torch.distributed.barrier()
262
+ setup_for_distributed(args.rank == 0)
263
+
264
+
265
+ class NativeScalerWithGradNormCount:
266
+ state_dict_key = "amp_scaler"
267
+
268
+ def __init__(self):
269
+ self._scaler = torch.cuda.amp.GradScaler()
270
+
271
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
272
+ self._scaler.scale(loss).backward(create_graph=create_graph)
273
+ if update_grad:
274
+ if clip_grad is not None:
275
+ assert parameters is not None
276
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
277
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
278
+ else:
279
+ self._scaler.unscale_(optimizer)
280
+ norm = get_grad_norm_(parameters)
281
+ self._scaler.step(optimizer)
282
+ self._scaler.update()
283
+ else:
284
+ norm = None
285
+ return norm
286
+
287
+ def state_dict(self):
288
+ return self._scaler.state_dict()
289
+
290
+ def load_state_dict(self, state_dict):
291
+ self._scaler.load_state_dict(state_dict)
292
+
293
+
294
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
295
+ if isinstance(parameters, torch.Tensor):
296
+ parameters = [parameters]
297
+ parameters = [p for p in parameters if p.grad is not None]
298
+ norm_type = float(norm_type)
299
+ if len(parameters) == 0:
300
+ return torch.tensor(0.)
301
+ device = parameters[0].grad.device
302
+ if norm_type == inf:
303
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
304
+ else:
305
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
306
+ return total_norm
307
+
308
+
309
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, suffix="", upload=True):
310
+ if suffix:
311
+ suffix = f"__{suffix}"
312
+ output_dir = Path(args.output_dir)
313
+ ckpt_name = f"checkpoint{suffix}.pth"
314
+ if loss_scaler is not None:
315
+ checkpoint_paths = [output_dir / ckpt_name]
316
+ for checkpoint_path in checkpoint_paths:
317
+ to_save = {
318
+ 'model': model_without_ddp.state_dict(),
319
+ 'optimizer': optimizer.state_dict(),
320
+ 'epoch': epoch,
321
+ 'scaler': loss_scaler.state_dict(),
322
+ 'args': args,
323
+ }
324
+ save_on_master(to_save, checkpoint_path)
325
+ if upload and is_main_process():
326
+ log_wandb_model(f"checkpoint{suffix}", checkpoint_path, epoch)
327
+ print("checkpoint sent to W&B (if)")
328
+ else:
329
+ client_state = {'epoch': epoch}
330
+ model.save_checkpoint(save_dir=args.output_dir, tag=ckpt_name, client_state=client_state)
331
+ if upload and is_main_process():
332
+ log_wandb_model(f"checkpoint{suffix}", output_dir / ckpt_name, epoch)
333
+ print("checkpoint sent to W&B (else)")
334
+
335
+
336
+ def log_wandb_model(title, path, epoch):
337
+ artifact = wandb.Artifact(title, type="model")
338
+ artifact.add_file(path)
339
+ artifact.metadata["epoch"] = epoch
340
+ wandb.log_artifact(artifact_or_path=artifact, name=title)
341
+
342
+
343
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
344
+ if args.resume:
345
+ if args.resume.startswith('https'):
346
+ checkpoint = torch.hub.load_state_dict_from_url(
347
+ args.resume, map_location='cpu', check_hash=True)
348
+ else:
349
+ checkpoint = torch.load(args.resume, map_location='cpu')
350
+
351
+ if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != model_without_ddp.state_dict()['pos_embed'].shape:
352
+ print(f"Removing key pos_embed from pretrained checkpoint")
353
+ del checkpoint['model']['pos_embed']
354
+
355
+ if 'decoder_pos_embed' in checkpoint['model'] and checkpoint['model']['decoder_pos_embed'].shape != model_without_ddp.state_dict()['decoder_pos_embed'].shape:
356
+ print(f"Removing key decoder_pos_embed from pretrained checkpoint")
357
+ del checkpoint['model']['decoder_pos_embed']
358
+
359
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
360
+ print("Resume checkpoint %s" % args.resume)
361
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
362
+ optimizer.load_state_dict(checkpoint['optimizer'])
363
+ args.start_epoch = checkpoint['epoch'] + 1
364
+ if 'scaler' in checkpoint:
365
+ loss_scaler.load_state_dict(checkpoint['scaler'])
366
+ print("With optim & sched!")
367
+
368
+ def load_model_FSC(args, model_without_ddp):
369
+ if args.resume:
370
+ if args.resume.startswith('https'):
371
+ checkpoint = torch.hub.load_state_dict_from_url(
372
+ args.resume, map_location='cpu', check_hash=True)
373
+ else:
374
+ checkpoint = torch.load(args.resume, map_location='cpu')
375
+
376
+ if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != model_without_ddp.state_dict()['pos_embed'].shape:
377
+ print(f"Removing key pos_embed from pretrained checkpoint")
378
+ del checkpoint['model']['pos_embed']
379
+
380
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
381
+ print(f"Resume checkpoint {args.resume} ({checkpoint['epoch']})")
382
+
383
+ def load_model_FSC1(args, model_without_ddp):
384
+ if args.resume:
385
+ if args.resume.startswith('https'):
386
+ checkpoint = torch.hub.load_state_dict_from_url(
387
+ args.resume, map_location='cpu', check_hash=True)
388
+ else:
389
+ checkpoint = torch.load(args.resume, map_location='cpu')
390
+ #model = timm.create_model('vit_base_patch16_224', pretrained=True)
391
+ #torch.save(model.state_dict(), './output_abnopre_dir/checkpoint-6657.pth')
392
+ checkpoint1 = torch.load('./output_abnopre_dir/checkpoint-6657.pth', map_location='cpu')
393
+
394
+ if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != model_without_ddp.state_dict()['pos_embed'].shape:
395
+ print(f"Removing key pos_embed from pretrained checkpoint")
396
+ del checkpoint['model']['pos_embed']
397
+
398
+ del checkpoint1['cls_token'],checkpoint1['pos_embed']
399
+
400
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
401
+ model_without_ddp.load_state_dict(checkpoint1, strict=False)
402
+ print("Resume checkpoint %s" % args.resume)
403
+
404
+
405
+ def load_model_FSC_full(args, model_without_ddp, optimizer, loss_scaler):
406
+ if args.resume:
407
+ if args.resume.startswith('https'):
408
+ checkpoint = torch.hub.load_state_dict_from_url(
409
+ args.resume, map_location='cpu', check_hash=True)
410
+ else:
411
+ checkpoint = torch.load(args.resume, map_location='cpu')
412
+
413
+ if 'pos_embed' in checkpoint['model'] and checkpoint['model']['pos_embed'].shape != \
414
+ model_without_ddp.state_dict()['pos_embed'].shape:
415
+ print(f"Removing key pos_embed from pretrained checkpoint")
416
+ del checkpoint['model']['pos_embed']
417
+
418
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
419
+ print("Resume checkpoint %s" % args.resume)
420
+
421
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and args.do_resume:
422
+ optimizer.load_state_dict(checkpoint['optimizer'])
423
+ args.start_epoch = checkpoint['epoch'] + 1
424
+ if 'scaler' in checkpoint:
425
+ loss_scaler.load_state_dict(checkpoint['scaler'])
426
+ print("With optim & scheduler!")
427
+
428
+
429
+ def all_reduce_mean(x):
430
+ world_size = get_world_size()
431
+ if world_size > 1:
432
+ x_reduce = torch.tensor(x).cuda()
433
+ dist.all_reduce(x_reduce)
434
+ x_reduce /= world_size
435
+ return x_reduce.item()
436
+ else:
437
+ return x
438
+
439
+
440
+ def plot_counts(res_csv: Union[str, List[str]], output_dir: str, suffix: str = "", smooth: bool = False):
441
+ if suffix:
442
+ suffix = f"_{suffix}"
443
+ if smooth:
444
+ suffix = f"_smooth{suffix}"
445
+ if type(res_csv) == str:
446
+ res_csv = [res_csv]
447
+
448
+ plt.figure(figsize=(15, 5))
449
+
450
+ for res in res_csv:
451
+ name = Path(res).parent.name
452
+ df = pd.read_csv(res)
453
+ print(df)
454
+
455
+ df.sort_values(by="name", inplace=True)
456
+ df.reset_index(drop=True, inplace=True)
457
+ df.index += 1
458
+ print(df)
459
+
460
+ if smooth:
461
+ time_arr = df.index[5:-5]
462
+ smooth_pred_mean = df['prediction'].iloc[5:-5].rolling(25).mean()
463
+ smooth_pred_std = df['prediction'].iloc[5:-5].rolling(25).std()
464
+ plt.plot(time_arr, smooth_pred_mean, label=name)
465
+ plt.fill_between(time_arr, smooth_pred_mean + smooth_pred_std, smooth_pred_mean - smooth_pred_std, alpha=.2)
466
+ plt.xlabel('Frame')
467
+ plt.ylabel('Count')
468
+ else:
469
+ plt.plot(df.index, df['prediction'], label=name)
470
+
471
+ plt.legend()
472
+ plt.savefig(os.path.join(output_dir, f'counts{suffix}.png'), dpi=300)
473
+
474
+
475
+ def write_zeroshot_annotations(p: Path):
476
+ with open(p / 'annotations.json', 'a') as split:
477
+ split.write('{\n')
478
+ for img in p.iterdir():
479
+ if img.is_file():
480
+ split.write(f' "{img.name}": {{\n' \
481
+ ' "H": 960,\n' \
482
+ ' "W": 1280,\n' \
483
+ ' "box_examples_coordinates": [],\n' \
484
+ ' "points": []\n' \
485
+ ' },\n')
486
+ split.write("}")
487
+
488
+ with open(p / 'split.json', 'a') as split:
489
+ split.write('{\n "test":\n [\n')
490
+ for img in p.iterdir():
491
+ if img.is_file():
492
+ split.write(f' "{img.name}",\n')
493
+ split.write(" ]\n}")
494
+
495
+
496
+ def make_grid(imgs, h, w):
497
+ assert len(imgs) == 9
498
+ rows = []
499
+ for i in range(0, 9, 3):
500
+ row = torch.cat((imgs[i], imgs[i + 1], imgs[i + 2]), -1)
501
+ rows += [row]
502
+ grid = torch.cat((rows[0], rows[1], rows[2]), 0)
503
+ grid = transforms.Resize((h, w))(grid.unsqueeze(0))
504
+ return grid.squeeze(0)
505
+
506
+
507
+ def min_max(t):
508
+ t_shape = t.shape
509
+ t = t.view(t_shape[0], -1)
510
+ t -= t.min(1, keepdim=True)[0]
511
+ t /= t.max(1, keepdim=True)[0]
512
+ t = t.view(*t_shape)
513
+ return t
514
+
515
+
516
+ def min_max_np(v, new_min=0, new_max=1):
517
+ v_min, v_max = v.min(), v.max()
518
+ return (v - v_min) / (v_max - v_min) * (new_max - new_min) + new_min
519
+
520
+
521
+ def get_box_map(sample, pos, device, external=False):
522
+ box_map = torch.zeros([sample.shape[1], sample.shape[2]], device=device)
523
+ if external is False:
524
+ for rect in pos:
525
+ for i in range(rect[2] - rect[0]):
526
+ box_map[min(rect[0] + i, sample.shape[1] - 1), min(rect[1], sample.shape[2] - 1)] = 10
527
+ box_map[min(rect[0] + i, sample.shape[1] - 1), min(rect[3], sample.shape[2] - 1)] = 10
528
+ for i in range(rect[3] - rect[1]):
529
+ box_map[min(rect[0], sample.shape[1] - 1), min(rect[1] + i, sample.shape[2] - 1)] = 10
530
+ box_map[min(rect[2], sample.shape[1] - 1), min(rect[1] + i, sample.shape[2] - 1)] = 10
531
+ box_map = box_map.unsqueeze(0).repeat(3, 1, 1)
532
+ return box_map
533
+
534
+
535
+ timerfunc = time.perf_counter
536
+
537
+ class measure_time(object):
538
+ def __enter__(self):
539
+ self.start = timerfunc()
540
+ return self
541
+
542
+ def __exit__(self, typ, value, traceback):
543
+ self.duration = timerfunc() - self.start
544
+
545
+ def __add__(self, other):
546
+ return self.duration + other.duration
547
+
548
+ def __sub__(self, other):
549
+ return self.duration - other.duration
550
+
551
+ def __str__(self):
552
+ return str(self.duration)
553
+
554
+
555
+ def log_test_results(test_dir):
556
+ test_dir = Path(test_dir)
557
+ logs = []
558
+ for d in test_dir.iterdir():
559
+ if d.is_dir() and (d / "log.txt").exists():
560
+ print(d.name)
561
+ with open(d / "log.txt") as f:
562
+ last = f.readlines()[-1]
563
+ j = json.loads(last)
564
+ j['name'] = d.name
565
+ logs.append(j)
566
+ df = pd.DataFrame(logs)
567
+
568
+ df.sort_values('name', inplace=True, ignore_index=True)
569
+ cols = list(df.columns)
570
+ cols = cols[-1:] + cols[:-1]
571
+ df = df[cols]
572
+
573
+ df.to_csv(test_dir / "logs.csv", index=False)
574
+
575
+
576
+ COLORS = {
577
+ 'muted blue': '#1f77b4',
578
+ 'safety orange': '#ff7f0e',
579
+ 'cooked asparagus green': '#2ca02c',
580
+ 'brick red': '#d62728',
581
+ 'muted purple': '#9467bd',
582
+ 'chestnut brown': '#8c564b',
583
+ 'raspberry yogurt pink': '#e377c2',
584
+ 'middle gray': '#7f7f7f',
585
+ 'curry yellow-green': '#bcbd22',
586
+ 'blue-teal': '#17becf',
587
+ 'muted blue light': '#419ede',
588
+ 'safety orange light': '#ffa85b',
589
+ 'cooked asparagus green light': '#4bce4b',
590
+ 'brick red light': '#e36667'
591
+ }
592
+
593
+
594
+ def plot_test_results(test_dir):
595
+ import plotly.graph_objects as go
596
+
597
+ test_dir = Path(test_dir)
598
+ df = pd.read_csv(test_dir / "logs.csv")
599
+ df.sort_values('name', inplace=True)
600
+
601
+ fig = go.Figure()
602
+ fig.add_trace(go.Scatter(x=df['name'], y=df['MAE'], line_color=COLORS['muted blue'],
603
+ mode='lines', name='MAE'))
604
+ fig.add_trace(go.Scatter(x=df['name'], y=df['RMSE'], line_color=COLORS['safety orange'],
605
+ mode='lines', name='RMSE'))
606
+ fig.add_trace(go.Scatter(x=df['name'], y=df['NAE'], line_color=COLORS['cooked asparagus green'],
607
+ mode='lines', name='NAE'))
608
+
609
+ fig.update_yaxes(type="log")
610
+ fig.write_image(test_dir / "plot.jpeg", scale=4)
611
+ fig.write_html(test_dir / "plot.html", auto_open=False)
612
+
613
+
614
+ def frames2vid(input_dir: str, output_file: str, pattern: str, fps: int, h=720, w=1280):
615
+ input_dir = Path(input_dir)
616
+ video_file = None
617
+ files = sorted(input_dir.glob(pattern))
618
+ video_file = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
619
+ for img in tqdm(files, total=len(files)):
620
+ frame = cv2.imread(str(img))
621
+ frame = cv2.resize(frame, (w, h))
622
+ video_file.write(frame)
623
+
624
+ video_file.release()
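A short usage sketch for the visualisation helpers at the end of this file; the paths, glob pattern, fps and CSV location are illustrative:

from util.misc import frames2vid, plot_counts

frames2vid('./data/out/vis', './data/out/vis/demo.mp4', pattern='*.jpg', fps=25)
plot_counts('./data/out/vis/results.csv', './data/out/vis', smooth=True)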
util/pos_embed.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39
+ assert embed_dim % 2 == 0
40
+
41
+ # use half of dimensions to encode grid_h
42
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44
+
45
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46
+ return emb
47
+
48
+
49
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50
+ """
51
+ embed_dim: output dimension for each position
52
+ pos: a list of positions to be encoded: size (M,)
53
+ out: (M, D)
54
+ """
55
+ assert embed_dim % 2 == 0
56
+ # omega = np.arange(embed_dim // 2, dtype=np.float)
57
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
58
+ omega /= embed_dim / 2.
59
+ omega = 1. / 10000**omega # (D/2,)
60
+
61
+ pos = pos.reshape(-1) # (M,)
62
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
63
+
64
+ emb_sin = np.sin(out) # (M, D/2)
65
+ emb_cos = np.cos(out) # (M, D/2)
66
+
67
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
68
+ return emb
69
+
70
+
71
+ # --------------------------------------------------------
72
+ # Interpolate position embeddings for high-resolution
73
+ # References:
74
+ # DeiT: https://github.com/facebookresearch/deit
75
+ # --------------------------------------------------------
76
+ def interpolate_pos_embed(model, checkpoint_model):
77
+ if 'pos_embed' in checkpoint_model:
78
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
79
+ embedding_size = pos_embed_checkpoint.shape[-1]
80
+ num_patches = model.patch_embed.num_patches
81
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
82
+ # height (== width) for the checkpoint position embedding
83
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
84
+ # height (== width) for the new position embedding
85
+ new_size = int(num_patches ** 0.5)
86
+ # class_token and dist_token are kept unchanged
87
+ if orig_size != new_size:
88
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
89
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
90
+ # only the position tokens are interpolated
91
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
92
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
93
+ pos_tokens = torch.nn.functional.interpolate(
94
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
95
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
96
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
97
+ checkpoint_model['pos_embed'] = new_pos_embed
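The 2-D sin-cos table above is typically used to initialise a frozen (non-learned) position embedding in the MAE encoder and decoder. A sketch with illustrative sizes (ViT-B, 384x384 input, 16x16 patches):

import torch
from util.pos_embed import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 24                      # 384 / 16 = 24 patches per side
pos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
print(pos.shape)                                    # (1 + 24*24, 768)
pos_embed = torch.nn.Parameter(torch.from_numpy(pos).float().unsqueeze(0),
                               requires_grad=False)  # kept fixed during training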