Upload model
Browse files

- config.json +4 -0
- configuration_cetacean_classifier.py +0 -23
- model.safetensors +1 -1
- modeling_cetacean_classifier.py +6 -1
- train.py +5 -173
config.json
CHANGED
@@ -21,6 +21,10 @@
     "shear": 3,
     "translate": 0.25
   },
+  "auto_map": {
+    "AutoConfig": "configuration_cetacean_classifier.CetaceanClassifierConfig",
+    "AutoModelForImageClassification": "modeling_cetacean_classifier.CetaceanClassifierModelForImageClassification"
+  },
   "batch_size": 8,
   "bbox_conf_threshold": 0.01,
   "bboxes": {
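Note: the added auto_map entries are what let the Transformers Auto classes resolve the custom classes shipped inside this repo. A minimal loading sketch follows; the repo id is a placeholder, not confirmed by this commit:

from transformers import AutoConfig, AutoModelForImageClassification

# trust_remote_code is required because auto_map points at Python files
# stored in the repo rather than classes built into transformers.
config = AutoConfig.from_pretrained("<user>/cetacean-classifier", trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(
    "<user>/cetacean-classifier", trust_remote_code=True
)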
configuration_cetacean_classifier.py
CHANGED
@@ -7,29 +7,6 @@ class CetaceanClassifierConfig(PretrainedConfig):
 
     def __init__(
         self,
-        # block_type="bottleneck",
-        # layers: List[int] = [3, 4, 6, 3],
-        # num_classes: int = 1000,
-        # input_channels: int = 3,
-        # cardinality: int = 1,
-        # base_width: int = 64,
-        # stem_width: int = 64,
-        # stem_type: str = "",
-        # avg_down: bool = False,
        **kwargs,
     ):
-        # if block_type not in ["basic", "bottleneck"]:
-        #     raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
-        # if stem_type not in ["", "deep", "deep-tiered"]:
-        #     raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
-
-        # self.block_type = block_type
-        # self.layers = layers
-        # self.num_classes = num_classes
-        # self.input_channels = input_channels
-        # self.cardinality = cardinality
-        # self.base_width = base_width
-        # self.stem_width = stem_width
-        # self.stem_type = stem_type
-        # self.avg_down = avg_down
         super().__init__(**kwargs)
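Note: with the commented-out ResNet-style arguments deleted, the config class simply forwards **kwargs to PretrainedConfig, so every key in config.json (e.g. batch_size) surfaces as an attribute on the loaded config. A sketch, again with a placeholder repo id:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("<user>/cetacean-classifier", trust_remote_code=True)
print(cfg.batch_size)           # 8, straight from config.json
print(cfg.bbox_conf_threshold)  # 0.01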
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:7c9afc61a269bf406f5b23389c57e4efe365eb4b67aa62730b731916fb62b6f0
 size 296028464
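Note: the oid in a Git LFS pointer is the SHA-256 digest of the actual file, so a downloaded copy can be verified with the standard library alone. A sketch assuming model.safetensors sits in the current directory:

import hashlib

sha = hashlib.sha256()
with open("model.safetensors", "rb") as f:
    # Hash in 1 MiB chunks to avoid loading the ~296 MB file into memory.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

# Must match the oid recorded in the pointer above.
assert sha.hexdigest() == "7c9afc61a269bf406f5b23389c57e4efe365eb4b67aa62730b731916fb62b6f0"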
modeling_cetacean_classifier.py
CHANGED
@@ -44,7 +44,12 @@ class CetaceanClassifierModelForImageClassification(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-
+
+        self.model = SphereClassifier(cfg=config.to_dict())
+
+        # load_from_checkpoint("cetacean_classifier/last.ckpt")
+        # self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")
+
         self.model.eval()
 
     def preprocess_image(self, img: Image) -> torch.Tensor:
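Note: __init__ now builds the Lightning SphereClassifier from the serialized config and puts it in eval mode, so inference reduces to preprocessing a PIL image and running a forward pass. A hedged sketch; the repo id, image path, and output handling are assumptions:

from PIL import Image
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "<user>/cetacean-classifier", trust_remote_code=True  # placeholder repo id
)

img = Image.open("whale.jpg")         # any cetacean photo
tensor = model.preprocess_image(img)  # helper defined in this file
outputs = model(tensor)               # forward through the wrapped SphereClassifier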
train.py
CHANGED
@@ -1,20 +1,9 @@
-# import argparse
-# import os
-# import warnings
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
-# import optuna
-# import pandas as pd
 import timm
 import torch
-# import wandb
-# from optuna.integration import PyTorchLightningPruningCallback
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-# from pytorch_lightning import loggers as pl_loggers
-# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-# from sklearn.model_selection import StratifiedKFold
-# from torch.utils.data import ConcatDataset, DataLoader
 
 from .config import Config, load_config
 # from .dataset import WhaleDataset, load_df
@@ -22,94 +11,19 @@ from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcente
 from .utils import WarmupCosineLambda, map_dict, topk_average_precision
 
 
-# def parse():
-#     parser = argparse.ArgumentParser(description="Training for HappyWhale")
-#     parser.add_argument("--out_base_dir", default="result")
-#     parser.add_argument("--in_base_dir", default="input")
-#     parser.add_argument("--exp_name", default="tmp")
-#     parser.add_argument("--load_snapshot", action="store_true")
-#     parser.add_argument("--save_checkpoint", action="store_true")
-#     parser.add_argument("--wandb_logger", action="store_true")
-#     parser.add_argument("--config_path", default="config/debug.yaml")
-#     return parser.parse_args()
-
-
-# class WhaleDataModule(LightningDataModule):
-#     def __init__(
-#         self,
-#         df: pd.DataFrame,
-#         cfg: Config,
-#         image_dir: str,
-#         val_bbox_name: str,
-#         fold: int,
-#         additional_dataset: WhaleDataset = None,
-#     ):
-#         super().__init__()
-#         self.cfg = cfg
-#         self.image_dir = image_dir
-#         self.val_bbox_name = val_bbox_name
-#         self.additional_dataset = additional_dataset
-#         if cfg.n_data != -1:
-#             df = df.iloc[: cfg.n_data]
-#         self.all_df = df
-#         if fold == -1:
-#             self.train_df = df
-#         else:
-#             skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
-#             train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
-#             self.train_df = df.iloc[train_idx].copy()
-#             self.val_df = df.iloc[val_idx].copy()
-#             # relabel ids not included in training data as "new individual"
-#             new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
-#             self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
-#             print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")
-
-#     def get_dataset(self, df, data_aug):
-#         return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)
-
-#     def train_dataloader(self):
-#         dataset = self.get_dataset(self.train_df, True)
-#         if self.additional_dataset is not None:
-#             dataset = ConcatDataset([dataset, self.additional_dataset])
-#         return DataLoader(
-#             dataset,
-#             batch_size=self.cfg.batch_size,
-#             shuffle=True,
-#             num_workers=2,
-#             pin_memory=True,
-#             drop_last=True,
-#         )
-
-#     def val_dataloader(self):
-#         if self.cfg.n_splits == -1:
-#             return None
-#         return DataLoader(
-#             self.get_dataset(self.val_df, False),
-#             batch_size=self.cfg.batch_size,
-#             shuffle=False,
-#             num_workers=2,
-#             pin_memory=True,
-#         )
-
-#     def all_dataloader(self):
-#         return DataLoader(
-#             self.get_dataset(self.all_df, False),
-#             batch_size=self.cfg.batch_size,
-#             shuffle=False,
-#             num_workers=2,
-#             pin_memory=True,
-#         )
-
-
 class SphereClassifier(LightningModule):
     def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
         super().__init__()
+        # import pdb; pdb.set_trace()
         if not isinstance(cfg, Config):
             cfg = Config(cfg)
         self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
         self.test_results_fp = None
 
-
+        # import json
+        # cfg_json = json.dumps(cfg)
+        # with open("config_extracted.json", "w") as file:
+        #     file.write(cfg_json)
 
         # NN architecture
         self.backbone = timm.create_model(
@@ -234,85 +148,3 @@ class SphereClassifier(LightningModule):
             result = torch.cat([x[key] for x in outputs], dim=0)
             epoch_results[key] = result.detach().cpu().numpy()
         np.savez_compressed(self.test_results_fp, **epoch_results)
-
-
-# def train(
-#     df: pd.DataFrame,
-#     args: argparse.Namespace,
-#     cfg: Config,
-#     fold: int,
-#     do_inference: bool = False,
-#     additional_dataset: WhaleDataset = None,
-#     optuna_trial: Optional[optuna.Trial] = None,
-# ) -> Optional[float]:
-#     out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
-#     id_class_nums = df.individual_id.value_counts().sort_index().values
-#     species_class_nums = df.species.value_counts().sort_index().values
-#     model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
-#     data_module = WhaleDataModule(
-#         df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
-#     )
-#     loggers = [pl_loggers.CSVLogger(out_dir)]
-#     if args.wandb_logger:
-#         loggers.append(
-#             pl_loggers.WandbLogger(
-#                 project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
-#             )
-#         )
-#     callbacks = [LearningRateMonitor("epoch")]
-#     if optuna_trial is not None:
-#         callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
-#     if args.save_checkpoint:
-#         callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
-#     trainer = Trainer(
-#         gpus=torch.cuda.device_count(),
-#         max_epochs=cfg["max_epochs"],
-#         logger=loggers,
-#         callbacks=callbacks,
-#         checkpoint_callback=args.save_checkpoint,
-#         precision=16,
-#         sync_batchnorm=True,
-#     )
-#     ckpt_path = f"{out_dir}/last.ckpt"
-#     if not os.path.exists(ckpt_path) or not args.load_snapshot:
-#         ckpt_path = None
-#     trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
-#     if do_inference:
-#         for test_bbox in cfg.test_bboxes:
-#             # all train data
-#             model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
-#             trainer.test(model, data_module.all_dataloader())
-#             # test data
-#             model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
-#             df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
-#             test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
-#             trainer.test(model, test_data_module.all_dataloader())
-
-#     if args.wandb_logger:
-#         wandb.finish()
-#     if optuna_trial is not None:
-#         return trainer.callback_metrics["val/mapNone"].item()
-#     else:
-#         return None
-
-
-# def main():
-#     args = parse()
-#     warnings.filterwarnings("ignore", ".*does not have many workers.*")
-#     cfg = load_config(args.config_path, "config/default.yaml")
-#     print(cfg)
-#     df = load_df(args.in_base_dir, cfg, "train.csv", True)
-#     pseudo_dataset = None
-#     if cfg.pseudo_label is not None:
-#         pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
-#         pseudo_dataset = WhaleDataset(
-#             pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
-#         )
-#     if cfg["n_splits"] == -1:
-#         train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
-#     else:
-#         train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)
-
-
-# if __name__ == "__main__":
-#     main()
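Note: the new isinstance guard means SphereClassifier accepts either a Config or a plain dict, which is exactly what the modeling wrapper passes via config.to_dict(). A standalone sketch; the import path depends on how the repo is laid out and is an assumption:

import json

from train import SphereClassifier  # assumed module path

with open("config.json") as f:
    cfg = json.load(f)

model = SphereClassifier(cfg)  # a plain dict is wrapped in Config internally
model.eval()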