---
tags:
- image-classification
- timm
- transformers
- animetimm
- dghs-imgutils
library_name: timm
license: gpl-3.0
datasets:
- animetimm/danbooru-wdtagger-v4-w640-ws-full
base_model:
- timm/convnext_base.fb_in22k_ft_in1k
---

# Anime Tagger convnext_base.dbv4-full

## Model Details

- **Model Type:** Multilabel image classification / feature backbone
- **Model Stats:**
  - Params: 100.4M
  - FLOPs / MACs: 123.1G / 61.4G
  - Image size: train = 448 x 448, test = 448 x 448
- **Dataset:** [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full)
  - Tags Count: 12476
    - General (#0) Tags Count: 9225
    - Character (#4) Tags Count: 3247
    - Rating (#9) Tags Count: 4

## Results

|   Split    |    Macro@0.40 (F1/MCC/P/R)    |    Micro@0.40 (F1/MCC/P/R)    |  Macro@Best (F1/P/R)  |
|:----------:|:-----------------------------:|:-----------------------------:|:---------------------:|
| Validation | 0.483 / 0.489 / 0.510 / 0.488 | 0.648 / 0.647 / 0.643 / 0.653 |          ---          |
|    Test    | 0.459 / 0.467 / 0.513 / 0.448 | 0.637 / 0.637 / 0.653 / 0.622 | 0.501 / 0.523 / 0.507 |

* `Macro/Micro@0.40` means the metrics at a threshold of 0.40.
* `Macro@Best` means the mean metrics at the tag-level thresholds, each chosen to maximize that tag's F1 score.

## Thresholds

| Category |   Name    | Alpha | Threshold |  Micro@Thr (F1/P/R)   |  Macro@0.40 (F1/P/R)  |  Macro@Best (F1/P/R)  |
|:--------:|:---------:|:-----:|:---------:|:---------------------:|:---------------------:|:---------------------:|
|    0     |  general  |   1   |   0.37    | 0.625 / 0.626 / 0.624 | 0.333 / 0.407 / 0.315 | 0.378 / 0.388 / 0.401 |
|    4     | character |   1   |   0.65    | 0.859 / 0.917 / 0.808 | 0.815 / 0.813 / 0.826 | 0.851 / 0.904 / 0.810 |
|    9     |  rating   |   1   |   0.39    | 0.807 / 0.760 / 0.860 | 0.813 / 0.779 / 0.853 | 0.814 / 0.782 / 0.853 |

* `Micro@Thr` means the metrics at the suggested category-level thresholds listed in the table above.
* `Macro@0.40` means the metrics at a threshold of 0.40.
* `Macro@Best` means the metrics at the tag-level thresholds, each chosen to maximize that tag's F1 score. The tag-level thresholds are listed in [selected_tags.csv](https://huggingface.co/animetimm/convnext_base.dbv4-full/resolve/main/selected_tags.csv).

## How to Use

We provide a sample image for the code examples; you can find it [here](https://huggingface.co/animetimm/convnext_base.dbv4-full/blob/main/sample.webp).

### Use TIMM And Torch

Install [dghs-imgutils](https://github.com/deepghs/imgutils), [timm](https://github.com/huggingface/pytorch-image-models) and the other necessary requirements with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas
```

After that, you can load this model with the timm library and use it for training, validation and testing with the following code:

```python
import json

import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model

repo_id = 'animetimm/convnext_base.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()

with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
#     PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
#     Resize(size=448, interpolation=bicubic, max_size=None, antialias=True)
#     CenterCrop(size=[448, 448])
#     MaybeToTensor()
#     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
# )

image = load_image('https://huggingface.co/animetimm/convnext_base.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 448, 448]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32

df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False
)
tags = df_tags['name']
# Keep every tag whose probability reaches its own tag-level best threshold
scores = prediction.numpy()
mask = scores >= df_tags['best_threshold'].to_numpy()
print(dict(zip(tags[mask].tolist(), scores[mask].tolist())))
# {'sensitive': 0.7467835545539856,
#  '1girl': 0.9981738328933716,
#  'solo': 0.9809709191322327,
#  'looking_at_viewer': 0.7293463945388794,
#  'blush': 0.909403920173645,
#  'smile': 0.9033298492431641,
#  'short_hair': 0.8594529628753662,
#  'shirt': 0.5168469548225403,
#  'long_sleeves': 0.7876078486442566,
#  'brown_hair': 0.4587213099002838,
#  'holding': 0.5555751323699951,
#  'dress': 0.6512914896011353,
#  'closed_mouth': 0.2889435291290283,
#  'sitting': 0.4474199116230011,
#  'purple_eyes': 0.8057997226715088,
#  'flower': 0.9538715481758118,
#  'braid': 0.906929075717926,
#  'blunt_bangs': 0.3067224323749542,
#  'tears': 0.8093919157981873,
#  'crying': 0.37240323424339294,
#  'plant': 0.7373866438865662,
#  'blue_flower': 0.6387472748756409,
#  'tearing_up': 0.13282062113285065,
#  'brown_dress': 0.5204353332519531,
#  'crown_braid': 0.7680639028549194,
#  'potted_plant': 0.7755435109138489,
#  'flower_pot': 0.6366523504257202,
#  'happy_tears': 0.16147702932357788,
#  'pavement': 0.15952551364898682,
#  'wiping_tears': 0.8405019044876099,
#  'stone_floor': 0.051104169338941574,
#  'cobblestone': 0.02183498628437519,
#  'scratching_cheek': 0.16645653545856476}
```

### Use ONNX Model For Inference

Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0'
```

Use the `multilabel_timm_predict` function with the following code:

```python
from imgutils.generic import multilabel_timm_predict

general, character, rating = multilabel_timm_predict(
    'https://huggingface.co/animetimm/convnext_base.dbv4-full/resolve/main/sample.webp',
    repo_id='animetimm/convnext_base.dbv4-full',
    fmt=('general', 'character', 'rating'),
)

print(general)
# {'1girl': 0.9981737732887268,
#  'solo': 0.980971097946167,
#  'flower': 0.9538715481758118,
#  'blush': 0.9094040393829346,
#  'braid': 0.9069291353225708,
#  'smile': 0.9033299088478088,
#  'short_hair': 0.8594533801078796,
#  'wiping_tears': 0.840501070022583,
#  'tears': 0.8093917369842529,
#  'purple_eyes': 0.8057996034622192,
#  'long_sleeves': 0.7876076698303223,
#  'potted_plant': 0.7755429744720459,
#  'crown_braid': 0.7680644989013672,
#  'plant': 0.7373863458633423,
#  'looking_at_viewer': 0.729346513748169,
#  'dress': 0.6512922644615173,
#  'blue_flower': 0.6387467384338379,
#  'flower_pot': 0.6366517543792725,
#  'holding': 0.5555749535560608,
#  'brown_dress': 0.5204361081123352,
#  'shirt': 0.5168467164039612,
#  'brown_hair': 0.4587220847606659,
#  'sitting': 0.4474193751811981,
#  'crying': 0.3724023699760437,
#  'blunt_bangs': 0.30672210454940796,
#  'closed_mouth': 0.28894317150115967,
#  'scratching_cheek': 0.1664559245109558,
#  'happy_tears': 0.1614767611026764,
#  'pavement': 0.1595260202884674,
#  'tearing_up': 0.1328202188014984,
#  'stone_floor': 0.05110400915145874,
#  'cobblestone': 0.021834969520568848}
print(character)
# {}
print(rating)
# {'sensitive': 0.7467842102050781}
```

For further information, see the [documentation of the `multilabel_timm_predict` function](https://dghs-imgutils.deepghs.org/main/api_doc/generic/multilabel_timm.html#multilabel-timm-predict).
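The TIMM example filters tags with their tag-level `best_threshold` values. If you would rather use the suggested category-level thresholds from the Thresholds table (0.37 for general, 0.65 for character, 0.39 for rating), a minimal sketch like the one below can group and filter the scores per category. The helper name `select_by_category` is hypothetical, and feeding it `df_tags['category']` assumes that `selected_tags.csv` has a `category` column holding the category IDs (0, 4, 9); check the CSV before relying on this.

```python
from typing import Dict, Sequence

import numpy as np

# Suggested category-level thresholds from the Thresholds table,
# keyed by category ID (0 = general, 4 = character, 9 = rating).
CATEGORY_THRESHOLDS: Dict[int, float] = {0: 0.37, 4: 0.65, 9: 0.39}


def select_by_category(
    names: Sequence[str],
    categories: Sequence[int],
    scores: Sequence[float],
    thresholds: Dict[int, float] = CATEGORY_THRESHOLDS,
) -> Dict[int, Dict[str, float]]:
    """Hypothetical helper: per category, keep tags whose score reaches that category's threshold."""
    names_arr = np.asarray(names)
    cats = np.asarray(categories)
    probs = np.asarray(scores, dtype=np.float64)
    result: Dict[int, Dict[str, float]] = {}
    for cat, thr in thresholds.items():
        mask = (cats == cat) & (probs >= thr)
        # Sort the kept tags by descending score
        order = np.argsort(-probs[mask], kind="stable")
        result[cat] = {
            str(n): float(p)
            for n, p in zip(names_arr[mask][order], probs[mask][order])
        }
    return result
```

With the TIMM example's variables, a call could look like `select_by_category(df_tags['name'], df_tags['category'], prediction.numpy())`, again assuming the `category` column exists. One threshold per category is simpler to maintain, but it gives up the per-tag tuning behind the `Macro@Best` numbers.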
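The Results and Thresholds tables report both micro- and macro-averaged metrics. Micro averaging pools the true positives, false positives and false negatives of all tags before computing a single F1, so frequent tags dominate; macro averaging computes F1 per tag and then takes the unweighted mean, so every tag counts equally. The sketch below illustrates the difference at the shared 0.40 threshold on toy data. It is not the evaluation code used for these tables, and scoring a tag with no positives and no predictions as F1 = 0 is an assumption of this sketch.

```python
import numpy as np


def f1_micro_macro(y_true: np.ndarray, y_score: np.ndarray, threshold: float = 0.40):
    """Return (micro_f1, macro_f1) for a multilabel problem at one shared threshold.

    y_true: (n_samples, n_tags) array of 0/1 labels.
    y_score: (n_samples, n_tags) array of predicted probabilities.
    """
    y_pred = y_score >= threshold
    y_true = y_true.astype(bool)
    tp = (y_pred & y_true).sum(axis=0).astype(np.float64)
    fp = (y_pred & ~y_true).sum(axis=0).astype(np.float64)
    fn = (~y_pred & y_true).sum(axis=0).astype(np.float64)

    # Micro: pool the counts over all tags, then compute one F1
    micro = 2 * tp.sum() / max(2 * tp.sum() + fp.sum() + fn.sum(), 1.0)

    # Macro: F1 per tag (0 when a tag has no positives and no predictions), then the plain mean
    denom = 2 * tp + fp + fn
    per_tag = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return float(micro), float(per_tag.mean())


# Toy data: tag 0 is predicted perfectly, tag 1 is always wrong
y_true = np.array([[1, 0], [1, 1], [0, 0]])
y_score = np.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.6]])
micro, macro = f1_micro_macro(y_true, y_score)
print(round(micro, 4), macro)  # 0.6667 0.5
```

Because the macro mean weights the perfect tag and the failing tag equally, it lands at 0.5, while the pooled micro F1 is 2/3.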