File size: 5,266 Bytes
4d679c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
061bac4
 
 
4d679c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
061bac4
4d679c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os

import onnxruntime
import torch
import torch.utils.data
import torchvision
from torch import nn
from torchvision.transforms.functional import InterpolationMode

import utils


def evaluate(
    criterion,
    data_loader,
    device,
    model=None,
    model_onnx_path=None,
    print_freq=100,
    log_suffix="",
):
    """Evaluate a classifier over ``data_loader`` and return its top-1 accuracy.

    One backend is used per call: when ``model_onnx_path`` is set, the ONNX
    file is executed with onnxruntime on the CPU execution provider and
    ``model`` is ignored; otherwise the PyTorch ``model`` is called directly.

    Args:
        criterion: Loss callable taking ``(output, target)``; its value is
            logged as ``loss``.
        data_loader: Iterable yielding ``(image, target)`` tensor batches.
        device: Device the targets (and, for the PyTorch backend, the images)
            are moved to.
        model: PyTorch module to evaluate; used only when ``model_onnx_path``
            is falsy.
        model_onnx_path: Path to an ``.onnx`` file; takes precedence over
            ``model``.
        print_freq: Log progress every ``print_freq`` batches.
        log_suffix: Text appended to the ``"Test: "`` log header.

    Returns:
        The global average of the ``acc1`` meter, as computed by
        ``utils.accuracy``.

    Raises:
        ValueError: If neither ``model`` nor ``model_onnx_path`` is provided.
    """
    # Fail fast instead of hitting an unbound `output` on the first batch.
    if not model_onnx_path and model is None:
        raise ValueError("evaluate() needs either `model` or `model_onnx_path`")

    session = None
    input_name = None
    if model_onnx_path:
        session = onnxruntime.InferenceSession(
            model_onnx_path, providers=["CPUExecutionProvider"]
        )
        # The exported graph is fed through its first (and presumably only) input.
        input_name = session.get_inputs()[0].name

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            target = target.to(device, non_blocking=True)

            if session is not None:
                # onnxruntime consumes host numpy arrays, so feed the CPU batch
                # directly rather than round-tripping it through `device`.
                output_data = session.run(None, {input_name: image.cpu().numpy()})[0]
                # Back to torch, on the same device as the targets.
                output = torch.from_numpy(output_data).to(device)
            else:
                output = model(image.to(device))

            loss = criterion(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            # Accuracy meters are weighted by the real batch size (the final
            # batch may be smaller); the loss meter is updated once per batch.
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # Gather the stats from all processes (presumably a no-op when not
    # running distributed — confirm against utils.MetricLogger).
    metric_logger.synchronize_between_processes()

    print(
        f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg


def load_data(valdir, resize_size=256, crop_size=224, interpolation="bilinear"):
    """Build the validation dataset and a sequential sampler over it.

    Images are read with ``torchvision.datasets.ImageFolder`` from ``valdir``.
    Each image is resized to ``resize_size``, center-cropped to ``crop_size``,
    converted to a float tensor and normalized with the mean/std constants
    below (the usual ImageNet statistics).

    Args:
        valdir: Root directory of the validation images.
        resize_size: Size passed to ``Resize`` (default 256).
        crop_size: Size passed to ``CenterCrop`` (default 224).
        interpolation: ``InterpolationMode`` member or its string value used
            by ``Resize`` (default ``"bilinear"``).

    Returns:
        Tuple ``(dataset, sampler)`` where ``sampler`` iterates the dataset
        in order.
    """
    print("Loading data")
    # Accepts both a string value and an existing InterpolationMode member.
    interpolation = InterpolationMode(interpolation)

    preprocessing = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(resize_size, interpolation=interpolation),
            torchvision.transforms.CenterCrop(crop_size),
            torchvision.transforms.PILToTensor(),
            torchvision.transforms.ConvertImageDtype(torch.float),
            torchvision.transforms.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            ),
        ]
    )

    dataset_test = torchvision.datasets.ImageFolder(valdir, preprocessing)

    print("Creating data loaders")
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset_test, test_sampler


def main(args):
    """Load a model (PyTorch checkpoint and/or ONNX file) and report its accuracy.

    ``args`` is the namespace produced by ``get_args_parser()``. Images are
    read from ``<data_path>/val``. When ``--model-onnx`` is given it is the
    model that gets evaluated; the PyTorch checkpoint is only used otherwise.

    Raises:
        ValueError: If neither ``--model-ckpt`` nor ``--model-onnx`` is set.
    """
    print(args)

    # Validate before building the dataset so a bad invocation fails fast.
    if not args.model_ckpt and not args.model_onnx:
        raise ValueError("Provide --model-ckpt and/or --model-onnx to evaluate a model")

    # Deterministic cuDNN kernels for reproducible accuracy numbers.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    val_dir = os.path.join(args.data_path, "val")
    dataset_test, test_sampler = load_data(val_dir)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=device.type == "cuda",
    )

    print("Creating model")

    criterion = nn.CrossEntropyLoss()

    model = None
    if args.model_ckpt:
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted checkpoints. This checkpoint stores a whole module under
        # "model_ckpt", which torch releases defaulting to weights_only=True
        # may refuse to load; confirm against the torch version in use.
        checkpoint = torch.load(args.model_ckpt, map_location="cpu")
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # EMA weights are keyed with a leading "module." wrapper prefix;
            # strip exactly that prefix and skip keys that lack it.
            prefix = "module."
            state_dict = {
                key[len(prefix):]: value
                for key, value in checkpoint["model_ema"].items()
                if key.startswith(prefix)
            }
            model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

    accuracy = evaluate(
        model=model,
        model_onnx_path=args.model_onnx,
        criterion=criterion,
        data_loader=data_loader_test,
        device=device,
        print_freq=args.print_freq,
    )
    print(f"Model accuracy is: {accuracy}")


def get_args_parser(add_help=True):
    """Return the command-line parser for this evaluation script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser`` to enable ``-h``.

    Returns:
        An ``argparse.ArgumentParser`` defining data, loader and model options.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Evaluation", add_help=add_help
    )

    parser.add_argument(
        "--data-path", default="datasets/imagenet", type=str, help="dataset path"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=32,
        type=int,
        help="images per evaluation batch",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: %(default)s)",
    )
    parser.add_argument(
        "--print-freq", default=10, type=int, help="print frequency (in batches)"
    )
    parser.add_argument(
        "--model-onnx", default="", type=str, help="path of .onnx checkpoint"
    )
    parser.add_argument(
        "--model-ckpt", default="", type=str, help="path of .pth checkpoint"
    )

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(get_args_parser().parse_args())