---
tags:
- image-classification
- timm
- transformers
- animetimm
- dghs-imgutils
library_name: timm
license: gpl-3.0
datasets:
- animetimm/danbooru-wdtagger-v4-w640-ws-full
base_model:
- timm/mobilenetv4_conv_small.e2400_r224_in1k
---

# Anime Tagger mobilenetv4_conv_small.dbv4-full

## Model Details

- **Model Type:** Multilabel image classification / feature backbone
- **Model Stats:**
  - Params: 18.5M
  - FLOPs / MACs: 1.1G / 556.5M
  - Image size: train = 384 x 384, test = 384 x 384
- **Dataset:** [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full)
  - Tags Count: 12476
    - General (#0) Tags Count: 9225
    - Character (#4) Tags Count: 3247
    - Rating (#9) Tags Count: 4

## Results

| #          | Macro@0.40 (F1/MCC/P/R)       | Micro@0.40 (F1/MCC/P/R)       | Macro@Best (F1/P/R)   |
|:----------:|:-----------------------------:|:-----------------------------:|:---------------------:|
| Validation | 0.311 / 0.337 / 0.509 / 0.252 | 0.528 / 0.538 / 0.660 / 0.440 | ---                   |
| Test       | 0.312 / 0.338 / 0.509 / 0.253 | 0.529 / 0.538 / 0.660 / 0.441 | 0.379 / 0.431 / 0.365 |

* `Macro/Micro@0.40` means the metrics computed at a fixed threshold of 0.40.
* `Macro@Best` means the mean of the metrics computed with tag-level thresholds, where each tag's threshold is chosen to maximize that tag's F1 score.

## Thresholds

| Category | Name      | Alpha | Threshold | Micro@Thr (F1/P/R)    | Macro@0.40 (F1/P/R)   | Macro@Best (F1/P/R)   |
|:--------:|:---------:|:-----:|:---------:|:---------------------:|:---------------------:|:---------------------:|
| 0        | general   | 1     | 0.28      | 0.530 / 0.556 / 0.505 | 0.183 / 0.391 / 0.137 | 0.261 / 0.293 / 0.267 |
| 4        | character | 1     | 0.32      | 0.721 / 0.822 / 0.642 | 0.680 / 0.845 / 0.582 | 0.713 / 0.821 / 0.642 |
| 9        | rating    | 1     | 0.37      | 0.754 / 0.691 / 0.831 | 0.749 / 0.723 / 0.784 | 0.753 / 0.711 / 0.806 |

* `Micro@Thr` means the metrics computed at the category-level suggested thresholds listed in the table above.
* `Macro@0.40` means the metrics computed at a fixed threshold of 0.40.
* `Macro@Best` means the metrics computed with tag-level thresholds, where each tag's threshold is chosen to maximize that tag's F1 score.

The tag-level thresholds can be found in [selected_tags.csv](https://huggingface.co/animetimm/mobilenetv4_conv_small.dbv4-full/resolve/main/selected_tags.csv).
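If you want to inspect these thresholds directly, a minimal sketch like the following downloads and previews the table. The `name` and `best_threshold` columns are the ones used by the inference code later in this card; check any other columns against the actual file before relying on them.

```python
import pandas as pd
from huggingface_hub import hf_hub_download

# Download the per-tag threshold table shipped with this model repo.
csv_path = hf_hub_download(
    repo_id='animetimm/mobilenetv4_conv_small.dbv4-full',
    repo_type='model',
    filename='selected_tags.csv',
)
df_tags = pd.read_csv(csv_path, keep_default_na=False)

# `name` and `best_threshold` are the columns used by the inference code
# below; inspect the remaining columns to see what else is available.
print(df_tags.columns.tolist())
print(df_tags[['name', 'best_threshold']].head())
```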
## How to Use

We provide a sample image for the code samples below; you can find it [here](https://huggingface.co/animetimm/mobilenetv4_conv_small.dbv4-full/blob/main/sample.webp).

### Use TIMM And Torch

Install [dghs-imgutils](https://github.com/deepghs/imgutils), [timm](https://github.com/huggingface/pytorch-image-models) and the other necessary requirements with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas
```

After that, you can load this model with the timm library and use it for training, validation, and testing with the following code:

```python
import json

import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model

repo_id = 'animetimm/mobilenetv4_conv_small.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()

with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
#     PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
#     Resize(size=384, interpolation=bicubic, max_size=None, antialias=True)
#     CenterCrop(size=[384, 384])
#     MaybeToTensor()
#     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
# )

image = load_image('https://huggingface.co/animetimm/mobilenetv4_conv_small.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 384, 384]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32

df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False,
)
tags = df_tags['name']
# Keep only the tags whose scores reach their per-tag best thresholds.
mask = prediction.numpy() >= df_tags['best_threshold']
print(dict(zip(tags[mask].tolist(), prediction[mask].tolist())))
# {'general': 0.5250747203826904,
#  'sensitive': 0.4123866856098175,
#  '1girl': 0.9827392101287842,
#  'solo': 0.9448713064193726,
#  'looking_at_viewer': 0.7789917588233948,
#  'blush': 0.8255025148391724,
#  'smile': 0.8907790780067444,
#  'short_hair': 0.7803121209144592,
#  'shirt': 0.48012447357177734,
#  'long_sleeves': 0.6374803781509399,
#  'brown_hair': 0.4931342899799347,
#  'holding': 0.6302770376205444,
#  'dress': 0.6055158972740173,
#  'closed_mouth': 0.4692787826061249,
#  'jewelry': 0.2611806094646454,
#  'purple_eyes': 0.5491976737976074,
#  'upper_body': 0.2779935300350189,
#  'flower': 0.6814424991607666,
#  'outdoors': 0.5236565470695496,
#  'hand_up': 0.19452522695064545,
#  'blunt_bangs': 0.2594663202762604,
#  'necklace': 0.6076473593711853,
#  'head_tilt': 0.09924726188182831,
#  'sunlight': 0.0953480675816536,
#  'light_smile': 0.07040124386548996,
#  'hand_on_own_face': 0.10853870958089828,
#  'blue_flower': 0.3045899271965027,
#  'tareme': 0.09898919612169266,
#  'backlighting': 0.2655039429664612,
#  'bouquet': 0.4030725359916687,
#  'brown_dress': 0.23282478749752045,
#  'shade': 0.047940079122781754,
#  'pavement': 0.024188760668039322}
```
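If you need the kept tags split by category, similar to what the ONNX helper in the next section returns, a minimal sketch continuing from the variables above (`df_tags`, `tags`, `mask`, `prediction`) might look like this. It assumes `selected_tags.csv` carries a `category` column with the dataset's numeric codes (0 = general, 4 = character, 9 = rating); verify this against the actual file.

```python
# A sketch continuing from the previous example. It assumes a `category`
# column in selected_tags.csv using the dataset's numeric codes
# (0 = general, 4 = character, 9 = rating); check the file to confirm.
category_names = {0: 'general', 4: 'character', 9: 'rating'}

scores = dict(zip(tags[mask].tolist(), prediction[mask].tolist()))
grouped = {
    cat_name: {name: scores[name]
               for name in df_tags.loc[(df_tags['category'] == cat_id) & mask, 'name']}
    for cat_id, cat_name in category_names.items()
}
print(grouped['rating'])  # rating scores only, e.g. {'general': ..., 'sensitive': ...}
```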
### Use ONNX Model For Inference

Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0'
```

Use the `multilabel_timm_predict` function with the following code:

```python
from imgutils.generic import multilabel_timm_predict

general, character, rating = multilabel_timm_predict(
    'https://huggingface.co/animetimm/mobilenetv4_conv_small.dbv4-full/resolve/main/sample.webp',
    repo_id='animetimm/mobilenetv4_conv_small.dbv4-full',
    fmt=('general', 'character', 'rating'),
)
print(general)
# {'1girl': 0.982739269733429,
#  'solo': 0.9448712468147278,
#  'smile': 0.8907787799835205,
#  'blush': 0.8255019187927246,
#  'short_hair': 0.7803121209144592,
#  'looking_at_viewer': 0.7789915800094604,
#  'flower': 0.6814424395561218,
#  'long_sleeves': 0.6374804973602295,
#  'holding': 0.6302772164344788,
#  'necklace': 0.6076476573944092,
#  'dress': 0.6055158376693726,
#  'purple_eyes': 0.5491974353790283,
#  'outdoors': 0.5236567258834839,
#  'brown_hair': 0.49313288927078247,
#  'shirt': 0.48012393712997437,
#  'closed_mouth': 0.46927782893180847,
#  'bouquet': 0.40307432413101196,
#  'blue_flower': 0.30458980798721313,
#  'upper_body': 0.2779929041862488,
#  'backlighting': 0.2655022144317627,
#  'jewelry': 0.26118117570877075,
#  'blunt_bangs': 0.259465754032135,
#  'brown_dress': 0.23282349109649658,
#  'hand_up': 0.1945251226425171,
#  'hand_on_own_face': 0.10853880643844604,
#  'head_tilt': 0.09924730658531189,
#  'tareme': 0.0989888608455658,
#  'sunlight': 0.09534767270088196,
#  'light_smile': 0.07040110230445862,
#  'shade': 0.04794028401374817,
#  'pavement': 0.024188846349716187}
print(character)
# {}
print(rating)
# {'general': 0.5250744819641113, 'sensitive': 0.4123871922492981}
```

For further information, see the [documentation of function multilabel_timm_predict](https://dghs-imgutils.deepghs.org/main/api_doc/generic/multilabel_timm.html#multilabel-timm-predict).
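For tagging many images, the same function can simply be called in a loop. Below is a minimal sketch assuming a hypothetical local `./images` directory of PNG files; the model files are downloaded on the first call, so subsequent iterations should only pay for inference.

```python
from pathlib import Path

from imgutils.generic import multilabel_timm_predict

# A sketch for tagging a local folder of images; './images' is a
# hypothetical path. Model files are fetched once and reused afterwards.
repo_id = 'animetimm/mobilenetv4_conv_small.dbv4-full'
for image_file in sorted(Path('./images').glob('*.png')):
    general, character, rating = multilabel_timm_predict(
        str(image_file),
        repo_id=repo_id,
        fmt=('general', 'character', 'rating'),
    )
    # Report the highest-scoring rating tag and the top five general tags.
    top_rating = max(rating, key=rating.get, default='unknown')
    top_general = sorted(general, key=general.get, reverse=True)[:5]
    print(image_file.name, top_rating, top_general)
```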