OneFormer Swin Tiny fine-tuned for clothes segmentation

OneFormer model fine-tuned on the ATR dataset for clothes segmentation; it can also be used for human segmentation.

The dataset is available on the Hugging Face Hub as "mattmdjaga/human_parsing_dataset".

Inference

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from matplotlib import cm
from transformers import AutoProcessor, OneFormerForUniversalSegmentation

# Hub repository holding both the processor configuration and the fine-tuned weights.
checkpoint = "pooya-mohammadi/oneformer_ade20k_swin_tiny_clothes"
processor = AutoProcessor.from_pretrained(checkpoint)
model = OneFormerForUniversalSegmentation.from_pretrained(checkpoint, is_training=False)
# NOTE(review): the label names are copied from the backbone config; presumably
# that is where this checkpoint stores the clothes classes — confirm in config.json.
model.config.id2label = model.config.backbone_config.id2label
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# load image
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"

# A timeout avoids hanging forever on a stalled connection, and
# raise_for_status() turns an HTTP error into a clear exception instead of an
# opaque "cannot identify image file" error from PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Let urllib3 undo any gzip/deflate transfer encoding before PIL reads the raw stream.
response.raw.decode_content = True
# Force 3-channel RGB: JPEG cannot store an alpha channel, so saving an
# RGBA/palette image as "input.jpg" would otherwise fail.
image = Image.open(response.raw).convert("RGB")
image.save("input.jpg")

# prepare image for the model ("semantic" selects OneFormer's semantic task)
inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# forward pass (no need for gradients at inference time)
with torch.no_grad():
    outputs = model(**inputs)

# postprocessing: target_sizes expects (height, width), while PIL's image.size
# is (width, height), hence the reversal. [0] takes the map for the single image.
semantic_segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]


def draw_semantic_segmentation(segmentation, output_path="output.png", id2label=None):
    """Plot a semantic segmentation map with a per-class legend and save it to disk.

    Args:
        segmentation: CPU tensor of integer class ids, as produced by
            ``processor.post_process_semantic_segmentation`` (assumed to be
            (height, width) — TODO confirm for other callers).
        output_path: File the figure is written to. Defaults to ``"output.png"``.
        id2label: Mapping from class id to label name. Defaults to the global
            ``model.config.id2label``.
    """
    if id2label is None:
        id2label = model.config.id2label

    fig, ax = plt.subplots()
    # Keep the image artist: its colormap and norm are exactly what coloured
    # the pixels, so reusing them guarantees the legend matches the image.
    # (A separate cm.get_cmap('viridis', max_id) lookup used a different
    # normalisation, broke when max_id == 0, and cm.get_cmap was removed in
    # Matplotlib 3.9.) "nearest" avoids blending label colours at boundaries.
    image_artist = ax.imshow(segmentation, cmap="viridis", interpolation="nearest")

    handles = []
    # Only classes actually present in the map get a legend entry.
    for label_id in torch.unique(segmentation).tolist():
        color = image_artist.cmap(image_artist.norm(label_id))
        handles.append(mpatches.Patch(color=color, label=id2label[label_id]))
    ax.legend(handles=handles)

    fig.savefig(output_path)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


# Matplotlib cannot read GPU tensors, so move the map to host memory first
# (.cpu() returns the same tensor when it is already on the CPU).
segmentation_on_cpu = semantic_segmentation.cpu()
draw_semantic_segmentation(segmentation_on_cpu)

Downloads last month: 47

Safetensors
Model size: 50.8M params
Tensor type: I64 · F32
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support