Clothes Segmentation using U2NET
This repo contains training code, inference code and a pre-trained model for clothes parsing from human portraits.
Here clothes are parsed into 3 categories: Upper body (red), Lower body (green) and Full body (yellow).
This model works well with any background and almost all poses. For more samples visit samples.md
Technical details
U2NET : This project uses the amazing U2NET as its deep learning model. Instead of the single-channel output that u2net produces for a typical salient object detection task, it outputs 4 channels, representing upper body cloth, lower body cloth, full body cloth and background. Only categorical cross-entropy loss is used for the given version of the checkpoint.
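To make the 4-channel output and its loss concrete, here is a minimal sketch of per-pixel categorical cross-entropy written in plain Python so it runs without PyTorch. It is not the repo's training code: the channel order in `CLASSES` and the helper names `softmax`, `pixel_cross_entropy` and `mean_cross_entropy` are assumptions made for illustration.

```python
import math

# Channel order is an assumption made for this sketch only.
CLASSES = ("background", "upper body", "lower body", "full body")


def softmax(logits):
    """Numerically stable softmax over one pixel's channel scores."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pixel_cross_entropy(logits, target):
    """Negative log-probability the model assigns to the labelled class."""
    return -math.log(softmax(logits)[target])


def mean_cross_entropy(logit_map, label_map):
    """Average the per-pixel loss over an H x W grid of 4-channel scores."""
    losses = [
        pixel_cross_entropy(logits, target)
        for logit_row, label_row in zip(logit_map, label_map)
        for logits, target in zip(logit_row, label_row)
    ]
    return sum(losses) / len(losses)
```

A pixel whose four scores are all equal gives a loss of log(4), the value expected from an untrained model that treats every class as equally likely.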
Dataset : U2net is trained on 45k images from the iMaterialist (Fashion) 2019 at FGVC6 dataset. To reduce complexity, I have merged the original 42 categories from the dataset labels into 3 categories (upper body, lower body and full body). All images are resized to a square ¯\_(ツ)_/¯ 768 x 768 px for training. (This experiment was conducted with 768 px, but around 384 px should work fine too if one is retraining on another dataset.)
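The category merging described above amounts to a lookup from fine-grained labels to coarse groups. The sketch below shows the idea only: the category names in `GROUP_OF` are hypothetical examples rather than the dataset's actual label table, and both the class indices in `GROUP_ID` and the choice to treat unlisted categories as background are assumptions.

```python
# Hypothetical fine-grained names; the real dataset uses its own label table.
GROUP_OF = {
    "shirt": "upper body",
    "jacket": "upper body",
    "pants": "lower body",
    "skirt": "lower body",
    "dress": "full body",
    "jumpsuit": "full body",
}

# Assumed class indices for the 4 output channels.
GROUP_ID = {"background": 0, "upper body": 1, "lower body": 2, "full body": 3}


def to_group_id(category):
    """Map a fine-grained category name to a coarse class index.

    Categories missing from GROUP_OF fall back to background (an assumption).
    """
    return GROUP_ID[GROUP_OF.get(category, "background")]
```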
Training
For training, this project requires:
- PyTorch > 1.3.0
- tensorboardX
- gdown
- Download the dataset from this link and extract all items.
- Set the path of the train folder, which contains the training images, and of train.csv, the label CSV file, in options/base_options.py
- To port the original u2net weights for all layers except the last layer, run
python setup_model_weights.py
It will generate the weights after model surgery in the prev_checkpoints folder.
- You can explore various options in options/base_options.py, such as the checkpoint saving folder, logs folder, etc.
- For single gpu set distributed = False in options/base_options.py; for multi gpu set it to True.
- For single gpu run
python train.py
- For multi gpu run
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=4 --use_env train.py
This command is for a single node with 4 GPUs. It has been tested only on a single node.
- You can watch loss graphs and samples in tensorboard by running the tensorboard command in the log folder.
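With the --use_env flag, torch.distributed.launch passes each worker its rank through environment variables such as LOCAL_RANK and WORLD_SIZE instead of a command-line argument. How train.py consumes them is not shown here, so the following is only a sketch of reading those variables; `read_launch_env` is a hypothetical helper, and deriving a `distributed` flag from WORLD_SIZE is an illustration (the repo sets that flag in options/base_options.py).

```python
import os


def read_launch_env(env=None):
    """Read the launcher's variables, defaulting to a single-process run."""
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", "1"))
    return {
        "local_rank": int(env.get("LOCAL_RANK", "0")),  # GPU index on this node
        "world_size": world_size,                       # total worker processes
        "distributed": world_size > 1,                  # illustrative only
    }
```

Run under plain python train.py, neither variable is set by the launcher, so the defaults describe one process on GPU 0.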
Testing/Inference
- Download the pretrained model from this link (165 MB) into the trained_checkpoint folder.
- Put input images in the input_images folder.
- Run python infer.py for inference.
- Output will be saved in output_images
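Conceptually, turning the 4-channel output into a coloured mask means picking the highest-scoring channel at each pixel and looking up its colour. The sketch below illustrates that step in plain Python and is not the code in infer.py: the class order and the black background colour in `PALETTE` are assumptions, while red, green and yellow follow the colour coding given at the top of this README.

```python
# Assumed class order: 0 background, 1 upper body, 2 lower body, 3 full body.
PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (255, 255, 0)]


def colorize(score_map):
    """Pick the highest-scoring channel per pixel and return its RGB colour."""
    return [
        [PALETTE[max(range(len(scores)), key=scores.__getitem__)] for scores in row]
        for row in score_map
    ]
```

For example, a pixel with scores [0.1, 0.9, 0.0, 0.0] is assigned to the upper body class and painted red.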
Acknowledgements
- The U2net model is from the original u2net repo. Thanks to Xuebin Qin for the amazing repo.
- The complete repo follows the structure of the Pix2pixHD repo.