---
tags:
- image-classification
- timm
- transformers
- animetimm
- dghs-imgutils
library_name: timm
license: gpl-3.0
datasets:
- animetimm/danbooru-wdtagger-v4-w640-ws-full
base_model:
- timm/eva02_large_patch14_448.mim_m38m_ft_in22k_in1k
---

# Anime Tagger eva02_large_patch14_448.dbv4-full

## Model Details

- **Model Type:** Multilabel image classification / feature backbone
- **Model Stats:**
  - Params: 316.8M
  - FLOPs / MACs: 620.9G / 310.1G
  - Image size: train = 448 x 448, test = 448 x 448
- **Dataset:** [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full)
  - Tags Count: 12476
    - General (#0) Tags Count: 9225
    - Character (#4) Tags Count: 3247
    - Rating (#9) Tags Count: 4

## Results

| # | Macro@0.40 (F1/MCC/P/R) | Micro@0.40 (F1/MCC/P/R) | Macro@Best (F1/P/R) |
|:----------:|:-----------------------------:|:-----------------------------:|:---------------------:|
| Validation | 0.528 / 0.537 / 0.600 / 0.503 | 0.678 / 0.678 / 0.693 / 0.664 | --- |
| Test | 0.529 / 0.538 / 0.601 / 0.503 | 0.679 / 0.678 / 0.694 / 0.665 | 0.574 / 0.580 / 0.591 |

* `Macro/Micro@0.40` are the metrics at a fixed threshold of 0.40.
* `Macro@Best` is the mean of the metrics computed at the per-tag thresholds that maximize each tag's F1 score.

## Thresholds

| Category | Name | Alpha | Threshold | Micro@Thr (F1/P/R) | Macro@0.40 (F1/P/R) | Macro@Best (F1/P/R) |
|:----------:|:---------:|:-------:|:-----------:|:---------------------:|:---------------------:|:---------------------:|
| 0 | general | 1 | 0.37 | 0.667 / 0.666 / 0.668 | 0.399 / 0.493 / 0.367 | 0.453 / 0.452 / 0.483 |
| 4 | character | 1 | 0.57 | 0.922 / 0.951 / 0.895 | 0.896 / 0.909 / 0.889 | 0.917 / 0.943 / 0.895 |
| 9 | rating | 1 | 0.4 | 0.824 / 0.784 / 0.868 | 0.829 / 0.800 / 0.862 | 0.831 / 0.806 / 0.859 |

* `Micro@Thr` are the metrics at the category-level suggested thresholds listed in the table above.
* `Macro@0.40` are the metrics at a fixed threshold of 0.40.
* `Macro@Best` are the metrics at the per-tag thresholds that maximize each tag's F1 score. The tag-level thresholds are listed in [selected_tags.csv](https://huggingface.co/animetimm/eva02_large_patch14_448.dbv4-full/resolve/main/selected_tags.csv).

## How to Use

We provide a sample image for the code samples below; you can find it [here](https://huggingface.co/animetimm/eva02_large_patch14_448.dbv4-full/blob/main/sample.webp).

### Use TIMM And Torch

Install [dghs-imgutils](https://github.com/deepghs/imgutils), [timm](https://github.com/huggingface/pytorch-image-models) and the other necessary requirements with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas
```

After that you can load this model with the timm library and use it for training, validation and testing with the following code:

```python
import json

import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model

repo_id = 'animetimm/eva02_large_patch14_448.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()

with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
#     PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
#     Resize(size=(448, 448), interpolation=bicubic, max_size=None, antialias=True)
#     CenterCrop(size=[448, 448])
#     MaybeToTensor()
#     Normalize(mean=tensor([0.4815, 0.4578, 0.4082]), std=tensor([0.2686, 0.2613, 0.2758]))
# )

image = load_image('https://huggingface.co/animetimm/eva02_large_patch14_448.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 448, 448]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32

df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False,
)
tags = df_tags['name']
mask = prediction.numpy() >= df_tags['best_threshold']
print(dict(zip(tags[mask].tolist(), prediction[mask].tolist())))
# {'sensitive': 0.6976025700569153,
#  '1girl': 0.9952899217605591,
#  'solo': 0.9671481847763062,
#  'looking_at_viewer': 0.7711699604988098,
#  'blush': 0.7974982261657715,
#  'smile': 0.8849270939826965,
#  'short_hair': 0.817248523235321,
#  'long_sleeves': 0.5171797275543213,
#  'brown_hair': 0.6675055623054504,
#  'dress': 0.6894800662994385,
#  'closed_mouth': 0.35917922854423523,
#  'sitting': 0.7595945000648499,
#  'purple_eyes': 0.8275928497314453,
#  'flower': 0.8742285966873169,
#  'braid': 0.8496974110603333,
#  'blunt_bangs': 0.39164724946022034,
#  'tears': 0.8591281771659851,
#  'floral_print': 0.44396182894706726,
#  'crying': 0.4951671063899994,
#  'plant': 0.758698046207428,
#  'blue_flower': 0.5387876629829407,
#  'tearing_up': 0.11903537809848785,
#  'crying_with_eyes_open': 0.3073916733264923,
#  'crown_braid': 0.7725721001625061,
#  'potted_plant': 0.8286207318305969,
#  'flower_pot': 0.6531336307525635,
#  'happy_tears': 0.3884831964969635,
#  'pavement': 0.2094476968050003,
#  'wiping_tears': 0.6769278645515442,
#  'holding_flower_pot': 0.12655559182167053}
```

### Use ONNX Model For Inference

Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0'
```

Use the `multilabel_timm_predict` function with the following code:

```python
from imgutils.generic import multilabel_timm_predict

general, character, rating = multilabel_timm_predict(
    'https://huggingface.co/animetimm/eva02_large_patch14_448.dbv4-full/resolve/main/sample.webp',
    repo_id='animetimm/eva02_large_patch14_448.dbv4-full',
    fmt=('general', 'character', 'rating'),
)
print(general)
# {'1girl': 0.9952900409698486,
#  'solo': 0.9671480655670166,
#  'smile': 0.8849270343780518,
#  'flower': 0.8742280602455139,
#  'tears': 0.8591268062591553,
#  'braid': 0.8496923446655273,
#  'potted_plant': 0.8286197185516357,
#  'purple_eyes': 0.8275918364524841,
#  'short_hair': 0.8172485828399658,
#  'blush': 0.7974982857704163,
#  'crown_braid': 0.772567629814148,
#  'looking_at_viewer': 0.7711694240570068,
#  'sitting': 0.759594738483429,
#  'plant': 0.7586977481842041,
#  'dress': 0.6894786357879639,
#  'wiping_tears': 0.6769236326217651,
#  'brown_hair': 0.6675049662590027,
#  'flower_pot': 0.6531318426132202,
#  'blue_flower': 0.5387848615646362,
#  'long_sleeves': 0.5171791315078735,
#  'crying': 0.4951639473438263,
#  'floral_print': 0.44396066665649414,
#  'blunt_bangs': 0.39164483547210693,
#  'happy_tears': 0.3884800672531128,
#  'closed_mouth': 0.3591785430908203,
#  'crying_with_eyes_open': 0.30738943815231323,
#  'pavement': 0.20944759249687195,
#  'holding_flower_pot': 0.12655416131019592,
#  'tearing_up': 0.11903449892997742}
print(character)
# {}
print(rating)
# {'sensitive': 0.6976030468940735}
```

For further information, see the [documentation of function multilabel_timm_predict](https://dghs-imgutils.deepghs.org/main/api_doc/generic/multilabel_timm.html#multilabel-timm-predict).