Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
MalloryWittwerEPFL committed on
Commit 3be2146 · verified · 1 Parent(s): 6a41652

Upload model

config.json CHANGED
@@ -21,6 +21,10 @@
     "shear": 3,
     "translate": 0.25
   },
+  "auto_map": {
+    "AutoConfig": "configuration_cetacean_classifier.CetaceanClassifierConfig",
+    "AutoModelForImageClassification": "modeling_cetacean_classifier.CetaceanClassifierModelForImageClassification"
+  },
   "batch_size": 8,
   "bbox_conf_threshold": 0.01,
   "bboxes": {
configuration_cetacean_classifier.py CHANGED
@@ -7,29 +7,6 @@ class CetaceanClassifierConfig(PretrainedConfig):
 
     def __init__(
         self,
-        # block_type="bottleneck",
-        # layers: List[int] = [3, 4, 6, 3],
-        # num_classes: int = 1000,
-        # input_channels: int = 3,
-        # cardinality: int = 1,
-        # base_width: int = 64,
-        # stem_width: int = 64,
-        # stem_type: str = "",
-        # avg_down: bool = False,
         **kwargs,
     ):
-        # if block_type not in ["basic", "bottleneck"]:
-        #     raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
-        # if stem_type not in ["", "deep", "deep-tiered"]:
-        #     raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
-
-        # self.block_type = block_type
-        # self.layers = layers
-        # self.num_classes = num_classes
-        # self.input_channels = input_channels
-        # self.cardinality = cardinality
-        # self.base_width = base_width
-        # self.stem_width = stem_width
-        # self.stem_type = stem_type
-        # self.avg_down = avg_down
         super().__init__(**kwargs)
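
With the commented-out ResNet-style parameters gone, the config class is a pure pass-through: every key in config.json (batch_size, bboxes, the augmentation block, and so on) reaches PretrainedConfig.__init__ via **kwargs, and leftover kwargs are stored as attributes. A sketch of the resulting class, assuming a model_type value that this hunk does not show:

from transformers import PretrainedConfig

class CetaceanClassifierConfig(PretrainedConfig):
    model_type = "cetacean-classifier"  # assumed; not visible in this diff

    def __init__(self, **kwargs):
        # PretrainedConfig keeps unrecognized kwargs as attributes, so the
        # deleted parameters above no longer need explicit handling here.
        super().__init__(**kwargs)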
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7513de376ac126563e7785aabedcee668ce9c9b3d20663f49e66645f480a416c
+oid sha256:7c9afc61a269bf406f5b23389c57e4efe365eb4b67aa62730b731916fb62b6f0
 size 296028464
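
Only the LFS pointer changes here: the byte size is identical, but the weights are new. To confirm which revision of the weights a local download matches, recompute the sha256 recorded in the pointer; a plain-Python sketch:

import hashlib

# Stream the file in 1 MiB chunks to avoid loading ~296 MB at once.
h = hashlib.sha256()
with open("model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
print(h.hexdigest())  # expected after this commit: 7c9afc61a269...fb62b6f0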
modeling_cetacean_classifier.py CHANGED
@@ -44,7 +44,12 @@ class CetaceanClassifierModelForImageClassification(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")
+
+        self.model = SphereClassifier(cfg=config.to_dict())
+
+        # load_from_checkpoint("cetacean_classifier/last.ckpt")
+        # self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")
+
         self.model.eval()
 
     def preprocess_image(self, img: Image) -> torch.Tensor:
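
This is the substantive change of the commit: the wrapper no longer depends on a Lightning checkpoint (last.ckpt) shipped alongside the code. __init__ now builds the architecture from the Hugging Face config alone, and from_pretrained() fills in the parameters from model.safetensors afterwards, which is why the weights blob changes in the same commit. A sketch of the resulting pattern, using names from the diff; the config_class wiring is an assumption, since it is not visible in this hunk:

from transformers import PreTrainedModel

class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    config_class = CetaceanClassifierConfig  # assumed; not shown in this hunk

    def __init__(self, config):
        super().__init__(config)
        # Builds the architecture only: parameters are random here and are
        # overwritten by from_pretrained() with the model.safetensors state dict.
        self.model = SphereClassifier(cfg=config.to_dict())
        self.model.eval()

Instantiating the class directly therefore yields untrained weights; only from_pretrained() (or an explicit load_state_dict) reproduces the published model.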
train.py CHANGED
@@ -1,20 +1,9 @@
-# import argparse
-# import os
-# import warnings
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
-# import optuna
-# import pandas as pd
 import timm
 import torch
-# import wandb
-# from optuna.integration import PyTorchLightningPruningCallback
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-# from pytorch_lightning import loggers as pl_loggers
-# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-# from sklearn.model_selection import StratifiedKFold
-# from torch.utils.data import ConcatDataset, DataLoader
 
 from .config import Config, load_config
 # from .dataset import WhaleDataset, load_df
@@ -22,94 +11,19 @@ from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter
 from .utils import WarmupCosineLambda, map_dict, topk_average_precision
 
 
-# def parse():
-#     parser = argparse.ArgumentParser(description="Training for HappyWhale")
-#     parser.add_argument("--out_base_dir", default="result")
-#     parser.add_argument("--in_base_dir", default="input")
-#     parser.add_argument("--exp_name", default="tmp")
-#     parser.add_argument("--load_snapshot", action="store_true")
-#     parser.add_argument("--save_checkpoint", action="store_true")
-#     parser.add_argument("--wandb_logger", action="store_true")
-#     parser.add_argument("--config_path", default="config/debug.yaml")
-#     return parser.parse_args()
-
-
-# class WhaleDataModule(LightningDataModule):
-#     def __init__(
-#         self,
-#         df: pd.DataFrame,
-#         cfg: Config,
-#         image_dir: str,
-#         val_bbox_name: str,
-#         fold: int,
-#         additional_dataset: WhaleDataset = None,
-#     ):
-#         super().__init__()
-#         self.cfg = cfg
-#         self.image_dir = image_dir
-#         self.val_bbox_name = val_bbox_name
-#         self.additional_dataset = additional_dataset
-#         if cfg.n_data != -1:
-#             df = df.iloc[: cfg.n_data]
-#         self.all_df = df
-#         if fold == -1:
-#             self.train_df = df
-#         else:
-#             skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
-#             train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
-#             self.train_df = df.iloc[train_idx].copy()
-#             self.val_df = df.iloc[val_idx].copy()
-#             # relabel ids not included in training data as "new individual"
-#             new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
-#             self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
-#             print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")
-
-#     def get_dataset(self, df, data_aug):
-#         return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)
-
-#     def train_dataloader(self):
-#         dataset = self.get_dataset(self.train_df, True)
-#         if self.additional_dataset is not None:
-#             dataset = ConcatDataset([dataset, self.additional_dataset])
-#         return DataLoader(
-#             dataset,
-#             batch_size=self.cfg.batch_size,
-#             shuffle=True,
-#             num_workers=2,
-#             pin_memory=True,
-#             drop_last=True,
-#         )
-
-#     def val_dataloader(self):
-#         if self.cfg.n_splits == -1:
-#             return None
-#         return DataLoader(
-#             self.get_dataset(self.val_df, False),
-#             batch_size=self.cfg.batch_size,
-#             shuffle=False,
-#             num_workers=2,
-#             pin_memory=True,
-#         )
-
-#     def all_dataloader(self):
-#         return DataLoader(
-#             self.get_dataset(self.all_df, False),
-#             batch_size=self.cfg.batch_size,
-#             shuffle=False,
-#             num_workers=2,
-#             pin_memory=True,
-#         )
-
-
 class SphereClassifier(LightningModule):
     def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
         super().__init__()
+        # import pdb; pdb.set_trace()
         if not isinstance(cfg, Config):
             cfg = Config(cfg)
         self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
         self.test_results_fp = None
 
-        print(cfg.model_name)
+        # import json
+        # cfg_json = json.dumps(cfg)
+        # with open("config_extracted.json", "w") as file:
+        #     file.write(cfg_json)
 
         # NN architecture
         self.backbone = timm.create_model(
@@ -234,85 +148,3 @@ class SphereClassifier(LightningModule):
         result = torch.cat([x[key] for x in outputs], dim=0)
         epoch_results[key] = result.detach().cpu().numpy()
         np.savez_compressed(self.test_results_fp, **epoch_results)
-
-
-# def train(
-#     df: pd.DataFrame,
-#     args: argparse.Namespace,
-#     cfg: Config,
-#     fold: int,
-#     do_inference: bool = False,
-#     additional_dataset: WhaleDataset = None,
-#     optuna_trial: Optional[optuna.Trial] = None,
-# ) -> Optional[float]:
-#     out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
-#     id_class_nums = df.individual_id.value_counts().sort_index().values
-#     species_class_nums = df.species.value_counts().sort_index().values
-#     model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
-#     data_module = WhaleDataModule(
-#         df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
-#     )
-#     loggers = [pl_loggers.CSVLogger(out_dir)]
-#     if args.wandb_logger:
-#         loggers.append(
-#             pl_loggers.WandbLogger(
-#                 project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
-#             )
-#         )
-#     callbacks = [LearningRateMonitor("epoch")]
-#     if optuna_trial is not None:
-#         callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
-#     if args.save_checkpoint:
-#         callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
-#     trainer = Trainer(
-#         gpus=torch.cuda.device_count(),
-#         max_epochs=cfg["max_epochs"],
-#         logger=loggers,
-#         callbacks=callbacks,
-#         checkpoint_callback=args.save_checkpoint,
-#         precision=16,
-#         sync_batchnorm=True,
-#     )
-#     ckpt_path = f"{out_dir}/last.ckpt"
-#     if not os.path.exists(ckpt_path) or not args.load_snapshot:
-#         ckpt_path = None
-#     trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
-#     if do_inference:
-#         for test_bbox in cfg.test_bboxes:
-#             # all train data
-#             model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
-#             trainer.test(model, data_module.all_dataloader())
-#             # test data
-#             model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
-#             df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
-#             test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
-#             trainer.test(model, test_data_module.all_dataloader())
-
-#     if args.wandb_logger:
-#         wandb.finish()
-#     if optuna_trial is not None:
-#         return trainer.callback_metrics["val/mapNone"].item()
-#     else:
-#         return None
-
-
-# def main():
-#     args = parse()
-#     warnings.filterwarnings("ignore", ".*does not have many workers.*")
-#     cfg = load_config(args.config_path, "config/default.yaml")
-#     print(cfg)
-#     df = load_df(args.in_base_dir, cfg, "train.csv", True)
-#     pseudo_dataset = None
-#     if cfg.pseudo_label is not None:
-#         pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
-#         pseudo_dataset = WhaleDataset(
-#             pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
-#         )
-#     if cfg["n_splits"] == -1:
-#         train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
-#     else:
-#         train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)
-
-
-# if __name__ == "__main__":
-#     main()
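
With the argparse/Optuna/W&B/StratifiedKFold scaffolding stripped out, train.py now effectively exports just the SphereClassifier LightningModule that the Hugging Face wrapper instantiates, and the commented-out json.dumps block records how config.json was presumably seeded from the training Config. A standalone usage sketch; the package name and the assumption that config.json round-trips into a valid Config are both hypothetical:

import json

from cetacean_classifier.train import SphereClassifier  # assumed package layout

with open("config.json") as f:
    cfg = json.load(f)  # __init__ wraps a plain dict in Config itself

model = SphereClassifier(cfg)  # architecture only; weights are loaded
model.eval()                   # separately by the wrapper's from_pretrained()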