#!/usr/bin/env python3
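"""Generate foreground masks for the MVTec LOCO anomaly detection dataset using U2NET.

The script walks <dataset_path>/<category>/<split>/<subdir>/, runs a pretrained
U2NET salient object detection model on every PNG image, and writes binary 0/255
masks under <dataset_path>/<output_dir>/ with the same directory layout.
"""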
import os
import glob
import torch
import numpy as np
from PIL import Image
from skimage import io
from torch.utils.data import DataLoader
from torchvision import transforms
import argparse
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET
from model import U2NETP
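# Optional helper: when download_from_hf is available, missing U2NET weights
# can be fetched from the HuggingFace Hub instead of failing immediately.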
try:
from download_from_hf import download_u2net_model
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
def normPRED(d):
    """Min-max normalize the predicted saliency map to the [0, 1] range."""
    ma = torch.max(d)
    mi = torch.min(d)
    # A small epsilon guards against division by zero when the prediction is constant.
    dn = (d - mi) / (ma - mi + 1e-8)
    return dn
def save_output(image_name, pred, d_dir, threshold=0.5):
    """Binarize the normalized prediction at `threshold`, resize it to the
    original image size, and save it to `d_dir` as a 0/255 PNG mask."""
    predict_np = pred.squeeze().cpu().detach().numpy()
    # Binarize at the given threshold to produce a standard 0/255 mask.
    binary_mask = (predict_np > threshold).astype(np.uint8) * 255
    im = Image.fromarray(binary_mask).convert('L')  # grayscale (L) mode
    img_name = os.path.basename(image_name)
    image = io.imread(image_name)
    # NEAREST resampling keeps the mask strictly binary when resizing back to the original resolution.
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.NEAREST)
    imidx = os.path.splitext(img_name)[0]
    imo.save(d_dir + imidx + '.png')
def process_mvtec_loco_dataset(dataset_path, model_path, output_dir='fg_mask',
threshold=0.5, categories=None, splits=None,
batch_size=1, num_workers=1):
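    """Run U2NET foreground segmentation over an MVTec LOCO dataset tree.

    For every PNG under <dataset_path>/<category>/<split>/<subdir>/, a binary
    foreground mask with the same file stem is written to
    <dataset_path>/<output_dir>/<category>/<split>/<subdir>/.
    """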
print('...load U2NET---')
net = U2NET(3, 1)
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net.cuda()
else:
net.load_state_dict(torch.load(model_path, map_location='cpu'))
net.eval()
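    # eval() switches BatchNorm and Dropout layers to inference behavior.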
# Use provided categories or default
if categories is None:
categories = ['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins']
# Use provided splits or default
if splits is None:
splits = ['test', 'train']
# Create output directory structure
mask_root = os.path.join(dataset_path, output_dir)
os.makedirs(mask_root, exist_ok=True)
for category in categories:
print(f"Processing category: {category}")
category_path = os.path.join(dataset_path, category)
# Process specified splits
for split in splits:
split_path = os.path.join(category_path, split)
if not os.path.exists(split_path):
print(f"Skipping {category}/{split} - directory not found")
continue
# Get all subdirectories in test/train (e.g., good, logical_anomalies, structural_anomalies)
subdirs = [d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))]
for subdir in subdirs:
subdir_path = os.path.join(split_path, subdir)
output_path = os.path.join(mask_root, category, split, subdir)
print(f" Processing {category}/{split}/{subdir}")
# Get all PNG images in this subdirectory
image_list = glob.glob(os.path.join(subdir_path, '*.png'))
if not image_list:
print(f" No PNG images found in {subdir_path}")
continue
print(f" Found {len(image_list)} images")
# Ensure output directory exists
os.makedirs(output_path, exist_ok=True)
# Create dataset and dataloader
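                # RescaleT(320) resizes inputs to the 320x320 resolution used by the
                # upstream U2NET test pipeline; ToTensorLab(flag=0) applies its normalization.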
test_salobj_dataset = SalObjDataset(img_name_list=image_list,
lbl_name_list=[],
transform=transforms.Compose([RescaleT(320),
ToTensorLab(flag=0)]))
test_salobj_dataloader = DataLoader(test_salobj_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers)
# Process each image
                for i, data_test in enumerate(test_salobj_dataloader):
                    if (i + 1) % 20 == 0 or i == 0:  # Print progress every 20 batches
                        print(f"    Processing {i * batch_size + 1}/{len(image_list)}: {os.path.basename(image_list[i * batch_size])}")
                    inputs_test = data_test['image'].type(torch.FloatTensor)
                    if torch.cuda.is_available():
                        inputs_test = inputs_test.cuda()
                    # Inference only: no_grad avoids building the autograd graph.
                    with torch.no_grad():
                        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
                    # d1 is the fused output; save each item in the batch separately so
                    # filenames stay aligned with image_list when batch_size > 1.
                    for b in range(d1.shape[0]):
                        pred = normPRED(d1[b, 0, :, :])
                        save_output(image_list[i * batch_size + b], pred, output_path + os.sep, threshold)
                    del d1, d2, d3, d4, d5, d6, d7
print("All categories and splits processed successfully!")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate foreground masks for MVTec LOCO dataset using U2NET',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog='''
Examples:
# Use default paths
python mvtec_loco_fg_segmentation.py
# Specify custom dataset and model paths
python mvtec_loco_fg_segmentation.py --dataset_path /path/to/mvtec_loco --model_path /path/to/u2net.pth
# Process specific categories only
python mvtec_loco_fg_segmentation.py --categories breakfast_box juice_bottle
# Use different threshold for binary mask
python mvtec_loco_fg_segmentation.py --threshold 0.3
''')
parser.add_argument('--dataset_path', type=str, default='/root/hy-data/datasets/mvtec_loco_anomaly_detection',
help='Path to MVTec LOCO dataset root directory (default: /root/hy-data/datasets/mvtec_loco_anomaly_detection)')
parser.add_argument('--model_path', type=str, default='./saved_models/u2net/u2net.pth',
help='Path to U2NET model weights file (default: ./saved_models/u2net/u2net.pth)')
parser.add_argument('--output_dir', type=str, default='fg_mask',
help='Output directory name for generated masks (default: fg_mask)')
parser.add_argument('--threshold', type=float, default=0.5,
help='Threshold for binary mask generation (default: 0.5)')
parser.add_argument('--categories', nargs='+',
default=['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins'],
help='Categories to process (default: all 5 categories)')
parser.add_argument('--batch_size', type=int, default=1,
help='Batch size for processing (default: 1)')
parser.add_argument('--num_workers', type=int, default=1,
help='Number of data loading workers (default: 1)')
parser.add_argument('--splits', nargs='+', default=['test', 'train'],
help='Dataset splits to process (default: test train)')
args = parser.parse_args()
# Validate paths
if not os.path.exists(args.dataset_path):
print(f"ERROR: Dataset path not found: {args.dataset_path}")
print("Please check the dataset path and make sure MVTec LOCO dataset is properly extracted.")
exit(1)
if not os.path.exists(args.model_path):
print(f"Model not found: {args.model_path}")
# Try to download from HuggingFace if available
if HF_AVAILABLE:
print("Attempting to download model from HuggingFace Hub...")
downloaded_path = download_u2net_model(args.model_path)
if downloaded_path is None:
print("Failed to download from HuggingFace.")
print("Please download manually from:")
print("https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view")
exit(1)
else:
print("HuggingFace Hub not available. Please install: pip install huggingface_hub")
print("Or download manually from:")
print("https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view")
exit(1)
# Validate categories
valid_categories = ['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins']
invalid_categories = [cat for cat in args.categories if cat not in valid_categories]
if invalid_categories:
print(f"ERROR: Invalid categories: {invalid_categories}")
print(f"Valid categories are: {valid_categories}")
exit(1)
print(f"Configuration:")
print(f" Dataset path: {args.dataset_path}")
print(f" Model path: {args.model_path}")
print(f" Output directory: {args.output_dir}")
print(f" Binary threshold: {args.threshold}")
print(f" Categories: {args.categories}")
print(f" Splits: {args.splits}")
print(f" Batch size: {args.batch_size}")
print(f" Workers: {args.num_workers}")
print()
process_mvtec_loco_dataset(
dataset_path=args.dataset_path,
model_path=args.model_path,
output_dir=args.output_dir,
threshold=args.threshold,
categories=args.categories,
splits=args.splits,
batch_size=args.batch_size,
num_workers=args.num_workers
)