# import argparse
# import os
# import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
# import optuna
# import pandas as pd
import timm
import torch
# import wandb
# from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
# from pytorch_lightning import loggers as pl_loggers
# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# from sklearn.model_selection import StratifiedKFold
# from torch.utils.data import ConcatDataset, DataLoader

from .config import Config, load_config
# from .dataset import WhaleDataset, load_df
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
from .utils import WarmupCosineLambda, map_dict, topk_average_precision


# def parse():
#     parser = argparse.ArgumentParser(description="Training for HappyWhale")
#     parser.add_argument("--out_base_dir", default="result")
#     parser.add_argument("--in_base_dir", default="input")
#     parser.add_argument("--exp_name", default="tmp")
#     parser.add_argument("--load_snapshot", action="store_true")
#     parser.add_argument("--save_checkpoint", action="store_true")
#     parser.add_argument("--wandb_logger", action="store_true")
#     parser.add_argument("--config_path", default="config/debug.yaml")
#     return parser.parse_args()


# class WhaleDataModule(LightningDataModule):
#     def __init__(
#         self,
#         df: pd.DataFrame,
#         cfg: Config,
#         image_dir: str,
#         val_bbox_name: str,
#         fold: int,
#         additional_dataset: WhaleDataset = None,
#     ):
#         super().__init__()
#         self.cfg = cfg
#         self.image_dir = image_dir
#         self.val_bbox_name = val_bbox_name
#         self.additional_dataset = additional_dataset
#         if cfg.n_data != -1:
#             df = df.iloc[: cfg.n_data]
#         self.all_df = df
#         if fold == -1:
#             self.train_df = df
#         else:
#             skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
#             train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
#             self.train_df = df.iloc[train_idx].copy()
#             self.val_df = df.iloc[val_idx].copy()
#             # relabel ids not included in training data as "new individual"
#             new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
#             self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
#             print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")

#     def get_dataset(self, df, data_aug):
#         return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)

#     def train_dataloader(self):
#         dataset = self.get_dataset(self.train_df, True)
#         if self.additional_dataset is not None:
#             dataset = ConcatDataset([dataset, self.additional_dataset])
#         return DataLoader(
#             dataset,
#             batch_size=self.cfg.batch_size,
#             shuffle=True,
#             num_workers=2,
#             pin_memory=True,
#             drop_last=True,
#         )

#     def val_dataloader(self):
#         if self.cfg.n_splits == -1:
#             return None
#         return DataLoader(
#             self.get_dataset(self.val_df, False),
#             batch_size=self.cfg.batch_size,
#             shuffle=False,
#             num_workers=2,
#             pin_memory=True,
#         )

#     def all_dataloader(self):
#         return DataLoader(
#             self.get_dataset(self.all_df, False),
#             batch_size=self.cfg.batch_size,
#             shuffle=False,
#             num_workers=2,
#             pin_memory=True,
#         )
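

# SphereClassifier: a timm backbone with multi-scale GeM pooling, a BatchNorm/LayerNorm
# "neck", and two sub-center ArcFace heads (individual id and species), trained with
# class-frequency-dependent margins.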
class SphereClassifier(LightningModule):
    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        self.test_results_fp = None
        print(cfg.model_name)

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        self.mid_features = np.sum(feature_dims)
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            raise ValueError(f"Unsupported normalization: {cfg.normalization}")
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)

        # Losses: per-class ArcFace margins are derived from class frequencies as
        # margin = margin_coef * count ** margin_power + margin_cons.
        if id_class_nums is not None and species_class_nums is not None:
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()
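
    # Multi-scale embedding: each selected backbone stage is GeM-pooled, the pooled
    # vectors are concatenated, and the result is normalized by the neck.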
    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)
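
    # The training loss is a weighted sum of two ArcFace-margin cross-entropy terms
    # (individual id and species), mixed by loss_id_ratio.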
    def training_step(self, batch, batch_idx):
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        with torch.no_grad():
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)
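
    # Validation applies horizontal-flip TTA: logits for the original and the flipped
    # image are averaged before the retrieval and species metrics are computed.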
    def validation_step(self, batch, batch_idx):
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )
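
    # Backbone/pooling parameters and neck/head parameters use separate learning rates;
    # the schedule is a warmup followed by a cosine cycle (see WarmupCosineLambda).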
    def configure_optimizers(self):
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)
        else:
            raise ValueError(f"Unsupported optimizer: {self.hparams.optimizer}")
        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]
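
    # Test-time inference also uses horizontal-flip TTA and returns the embeddings from
    # both views plus the top-1000 id logits per sample.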
    def test_step(self, batch, batch_idx):
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }
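
    # Step outputs are gathered across processes; rank 0 concatenates them (flattening
    # the extra device dimension added by all_gather when using multiple GPUs) and
    # writes a compressed .npz to test_results_fp.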
    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank == 0:
            epoch_results: Dict[str, np.ndarray] = {}
            for key in outputs[0].keys():
                if torch.cuda.device_count() > 1:
                    result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
                else:
                    result = torch.cat([x[key] for x in outputs], dim=0)
                epoch_results[key] = result.detach().cpu().numpy()
            np.savez_compressed(self.test_results_fp, **epoch_results)
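

# A minimal inference sketch (hypothetical paths and an assumed `images` tensor batch;
# the config/checkpoint layout follows the commented-out pipeline below):
#
#     cfg = load_config("config/debug.yaml", "config/default.yaml")
#     model = SphereClassifier(cfg)
#     state = torch.load("result/tmp/0/last.ckpt", map_location="cpu")["state_dict"]
#     model.load_state_dict(state)  # use strict=False if the checkpoint has training-only keys
#     model.eval()
#     with torch.no_grad():
#         embeddings = model.get_feat(images)      # (batch, mid_features)
#         id_logits, species_logits = model(images)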


# def train(
#     df: pd.DataFrame,
#     args: argparse.Namespace,
#     cfg: Config,
#     fold: int,
#     do_inference: bool = False,
#     additional_dataset: WhaleDataset = None,
#     optuna_trial: Optional[optuna.Trial] = None,
# ) -> Optional[float]:
#     out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
#     id_class_nums = df.individual_id.value_counts().sort_index().values
#     species_class_nums = df.species.value_counts().sort_index().values
#     model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
#     data_module = WhaleDataModule(
#         df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
#     )
#     loggers = [pl_loggers.CSVLogger(out_dir)]
#     if args.wandb_logger:
#         loggers.append(
#             pl_loggers.WandbLogger(
#                 project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
#             )
#         )
#     callbacks = [LearningRateMonitor("epoch")]
#     if optuna_trial is not None:
#         callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
#     if args.save_checkpoint:
#         callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
#     trainer = Trainer(
#         gpus=torch.cuda.device_count(),
#         max_epochs=cfg["max_epochs"],
#         logger=loggers,
#         callbacks=callbacks,
#         checkpoint_callback=args.save_checkpoint,
#         precision=16,
#         sync_batchnorm=True,
#     )
#     ckpt_path = f"{out_dir}/last.ckpt"
#     if not os.path.exists(ckpt_path) or not args.load_snapshot:
#         ckpt_path = None
#     trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
#     if do_inference:
#         for test_bbox in cfg.test_bboxes:
#             # all train data
#             model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
#             trainer.test(model, data_module.all_dataloader())
#             # test data
#             model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
#             df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
#             test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
#             trainer.test(model, test_data_module.all_dataloader())
#     if args.wandb_logger:
#         wandb.finish()
#     if optuna_trial is not None:
#         return trainer.callback_metrics["val/mapNone"].item()
#     else:
#         return None


# def main():
#     args = parse()
#     warnings.filterwarnings("ignore", ".*does not have many workers.*")
#     cfg = load_config(args.config_path, "config/default.yaml")
#     print(cfg)
#     df = load_df(args.in_base_dir, cfg, "train.csv", True)
#     pseudo_dataset = None
#     if cfg.pseudo_label is not None:
#         pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
#         pseudo_dataset = WhaleDataset(
#             pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
#         )
#     if cfg["n_splits"] == -1:
#         train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
#     else:
#         train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)


# if __name__ == "__main__":
#     main()