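# Hugging Face Hub page metadata captured together with this file (not code):
# Tags: Image Classification, Transformers, Safetensors, cetaceanet, biology,
#       biodiversity, custom_code
# Uploaded by MalloryWittwerEPFL ("Upload model"), commit 6257083 (verified)
# File size: 14 kB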
# import argparse
# import os
# import warnings
from typing import Dict, List, Optional, Tuple
import numpy as np
# import optuna
# import pandas as pd
import timm
import torch
# import wandb
# from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
# from pytorch_lightning import loggers as pl_loggers
# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# from sklearn.model_selection import StratifiedKFold
# from torch.utils.data import ConcatDataset, DataLoader
from .config import Config, load_config
# from .dataset import WhaleDataset, load_df
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
from .utils import WarmupCosineLambda, map_dict, topk_average_precision
# def parse():
# parser = argparse.ArgumentParser(description="Training for HappyWhale")
# parser.add_argument("--out_base_dir", default="result")
# parser.add_argument("--in_base_dir", default="input")
# parser.add_argument("--exp_name", default="tmp")
# parser.add_argument("--load_snapshot", action="store_true")
# parser.add_argument("--save_checkpoint", action="store_true")
# parser.add_argument("--wandb_logger", action="store_true")
# parser.add_argument("--config_path", default="config/debug.yaml")
# return parser.parse_args()
# class WhaleDataModule(LightningDataModule):
# def __init__(
# self,
# df: pd.DataFrame,
# cfg: Config,
# image_dir: str,
# val_bbox_name: str,
# fold: int,
# additional_dataset: WhaleDataset = None,
# ):
# super().__init__()
# self.cfg = cfg
# self.image_dir = image_dir
# self.val_bbox_name = val_bbox_name
# self.additional_dataset = additional_dataset
# if cfg.n_data != -1:
# df = df.iloc[: cfg.n_data]
# self.all_df = df
# if fold == -1:
# self.train_df = df
# else:
# skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
# train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
# self.train_df = df.iloc[train_idx].copy()
# self.val_df = df.iloc[val_idx].copy()
# # relabel ids not included in training data as "new individual"
# new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
# self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
# print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")
# def get_dataset(self, df, data_aug):
# return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)
# def train_dataloader(self):
# dataset = self.get_dataset(self.train_df, True)
# if self.additional_dataset is not None:
# dataset = ConcatDataset([dataset, self.additional_dataset])
# return DataLoader(
# dataset,
# batch_size=self.cfg.batch_size,
# shuffle=True,
# num_workers=2,
# pin_memory=True,
# drop_last=True,
# )
# def val_dataloader(self):
# if self.cfg.n_splits == -1:
# return None
# return DataLoader(
# self.get_dataset(self.val_df, False),
# batch_size=self.cfg.batch_size,
# shuffle=False,
# num_workers=2,
# pin_memory=True,
# )
# def all_dataloader(self):
# return DataLoader(
# self.get_dataset(self.all_df, False),
# batch_size=self.cfg.batch_size,
# shuffle=False,
# num_workers=2,
# pin_memory=True,
# )
class SphereClassifier(LightningModule):
    """Metric-learning classifier with an individual-id head and a species head.

    Architecture built in ``__init__``:

    * a ``timm`` backbone created with ``features_only=True``, returning the
      feature maps selected by ``cfg.out_indices``;
    * one ``GeM`` pooling module per selected feature map; the pooled outputs
      are concatenated along dim 1;
    * a normalization "neck" (``BatchNorm1d`` or ``LayerNorm``) over the
      concatenated features;
    * two ``ArcMarginProductSubcenter`` heads, over ``cfg.num_classes`` ids and
      ``cfg.num_species_classes`` species respectively.

    Args:
        cfg: A ``Config`` instance, or a value that ``Config(...)`` accepts.
        id_class_nums: Per-class counts for the id head. Together with
            ``species_class_nums`` it builds the adaptive margins used by
            ``training_step``. May be omitted when the model is used only for
            inference or testing.
        species_class_nums: Per-class counts for the species head.

    Raises:
        ValueError: If ``cfg.normalization`` is neither ``"batchnorm"`` nor
            ``"layernorm"``.
    """

    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        # Destination .npz path for test_epoch_end; must be set by the caller before testing.
        self.test_results_fp = None
        print(cfg.model_name)

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        # Width of the concatenated pooled features. Cast np.sum's NumPy scalar to a plain int.
        self.mid_features = int(np.sum(feature_dims))
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            # Fail here rather than leaving self.neck undefined; otherwise the error
            # would only surface later in get_feat / configure_optimizers.
            raise ValueError(
                f"Unsupported normalization {cfg.normalization!r}; expected 'batchnorm' or 'layernorm'"
            )
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)

        if id_class_nums is not None and species_class_nums is not None:
            # margin = count ** power * coef + const. Counts are cast to float because
            # NumPy refuses integer arrays raised to negative integer powers.
            id_counts = np.asarray(id_class_nums, dtype=np.float64)
            species_counts = np.asarray(species_class_nums, dtype=np.float64)
            margins_id = np.power(id_counts, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_counts, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized embedding for a batch of images.

        Each backbone feature map goes through its own GeM pool. The pooled
        results are concatenated along dim 1 and passed through the neck.
        """
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(id_logits, species_logits)`` for a batch of images."""
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        """Compute the weighted id + species margin loss and log training metrics.

        ``batch`` must provide the keys ``"image"``, ``"label"`` and
        ``"label_species"``. This step requires ``margin_fn_*`` / ``loss_fn_*``,
        which exist only when the model was built with both
        ``id_class_nums`` and ``species_class_nums``.

        Returns:
            ``loss_ids * loss_id_ratio + loss_species * (1 - loss_id_ratio)``.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        with torch.no_grad():
            # Metrics use the raw logits, without the training margin applied.
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        """Log validation metrics using the average of plain and flipped logits.

        ``x.flip(3)`` flips the last axis. Assuming NCHW input, that axis is
        image width (a horizontal-flip test-time augmentation); this is not
        verified here.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        """Build the optimizer (separate backbone / head learning rates) and warmup-cosine scheduler.

        Raises:
            ValueError: If ``hparams.optimizer`` is not ``"Adam"``, ``"AdamW"``
                or ``"RAdam"``.
        """
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        optimizer_name = self.hparams.optimizer
        if optimizer_name not in ("Adam", "AdamW", "RAdam"):
            # The original if/elif chain left `optimizer` unbound (UnboundLocalError) here.
            raise ValueError(f"Unsupported optimizer {optimizer_name!r}; expected 'Adam', 'AdamW' or 'RAdam'")
        # Resolved lazily so only the requested optimizer class needs to exist in torch.optim.
        optimizer = getattr(torch.optim, optimizer_name)(params)
        # Both counts are expressed in epochs (derived from max_epochs); the LambdaLR is
        # stepped with Lightning's default scheduler interval.
        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Return CPU predictions and embeddings for one test batch.

        Id logits are averaged over the plain and flipped (``flip(3)``) inputs,
        sorted in descending order, and truncated to the top 1000 entries.
        Species logits are averaged the same way. Both embeddings are kept
        as-is.
        """
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        """Gather all ``test_step`` outputs and, on rank 0, save them to ``self.test_results_fp``.

        Each output key becomes one concatenated array in a compressed ``.npz``
        file.
        """
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank != 0:
            return
        # all_gather prepends a process dimension only when running with more than one
        # process. Key the unpacking on the trainer's world size, not on the number of
        # GPUs visible to torch (a single-device run on a multi-GPU host has no such dim).
        gathered_across_processes = self.trainer.world_size > 1
        epoch_results: Dict[str, np.ndarray] = {}
        for key in outputs[0].keys():
            if gathered_across_processes:
                # Each entry is (world_size, batch, ...): join batches, then merge the two leading dims.
                result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
            else:
                result = torch.cat([x[key] for x in outputs], dim=0)
            epoch_results[key] = result.detach().cpu().numpy()
        np.savez_compressed(self.test_results_fp, **epoch_results)
# def train(
# df: pd.DataFrame,
# args: argparse.Namespace,
# cfg: Config,
# fold: int,
# do_inference: bool = False,
# additional_dataset: WhaleDataset = None,
# optuna_trial: Optional[optuna.Trial] = None,
# ) -> Optional[float]:
# out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
# id_class_nums = df.individual_id.value_counts().sort_index().values
# species_class_nums = df.species.value_counts().sort_index().values
# model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
# data_module = WhaleDataModule(
# df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
# )
# loggers = [pl_loggers.CSVLogger(out_dir)]
# if args.wandb_logger:
# loggers.append(
# pl_loggers.WandbLogger(
# project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
# )
# )
# callbacks = [LearningRateMonitor("epoch")]
# if optuna_trial is not None:
# callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
# if args.save_checkpoint:
# callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
# trainer = Trainer(
# gpus=torch.cuda.device_count(),
# max_epochs=cfg["max_epochs"],
# logger=loggers,
# callbacks=callbacks,
# checkpoint_callback=args.save_checkpoint,
# precision=16,
# sync_batchnorm=True,
# )
# ckpt_path = f"{out_dir}/last.ckpt"
# if not os.path.exists(ckpt_path) or not args.load_snapshot:
# ckpt_path = None
# trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
# if do_inference:
# for test_bbox in cfg.test_bboxes:
# # all train data
# model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
# trainer.test(model, data_module.all_dataloader())
# # test data
# model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
# df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
# test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
# trainer.test(model, test_data_module.all_dataloader())
# if args.wandb_logger:
# wandb.finish()
# if optuna_trial is not None:
# return trainer.callback_metrics["val/mapNone"].item()
# else:
# return None
# def main():
# args = parse()
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
# cfg = load_config(args.config_path, "config/default.yaml")
# print(cfg)
# df = load_df(args.in_base_dir, cfg, "train.csv", True)
# pseudo_dataset = None
# if cfg.pseudo_label is not None:
# pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
# pseudo_dataset = WhaleDataset(
# pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
# )
# if cfg["n_splits"] == -1:
# train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
# else:
# train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)
# if __name__ == "__main__":
# main()