|
|
|
|
|
import os |
|
import glob |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from skimage import io, transform |
|
from torch.autograd import Variable |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
import argparse |
|
|
|
from data_loader import RescaleT |
|
from data_loader import ToTensor |
|
from data_loader import ToTensorLab |
|
from data_loader import SalObjDataset |
|
|
|
from model import U2NET |
|
from model import U2NETP |
|
|
|
try: |
|
from download_from_hf import download_u2net_model |
|
HF_AVAILABLE = True |
|
except ImportError: |
|
HF_AVAILABLE = False |
|
|
|
def normPRED(d):
    """Min-max normalize a prediction tensor into the range [0, 1].

    Args:
        d: Tensor of raw prediction values.

    Returns:
        A tensor shaped like ``d`` in which the smallest input value maps to
        0 and the largest to 1. When every element of ``d`` is identical the
        range is zero, so an all-zero tensor is returned instead of the NaNs
        a division by zero would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi

    # A constant prediction has no dynamic range; dividing by the zero span
    # would fill the result with NaN.
    if span == 0:
        return torch.zeros_like(d)

    return (d - mi) / span
|
|
|
def save_output(image_name, pred, d_dir, threshold=0.5):
    """Binarize a prediction and save it as a PNG sized like the source image.

    Args:
        image_name: Path of the source image. The image is read to obtain the
            target width/height, and its file name (minus the final
            extension) is used to name the output file.
        pred: Prediction tensor; it is squeezed, moved to the CPU and
            converted to a NumPy array before thresholding.
        d_dir: Output location prefix. The file name is appended by plain
            string concatenation, so a directory must end with a path
            separator.
        threshold: Values strictly greater than this become 255, all others 0.
    """
    predict_np = pred.squeeze().cpu().data.numpy()

    # Hard threshold into a 0/255 single-channel mask.
    binary_mask = (predict_np > threshold).astype(np.uint8) * 255
    mask_img = Image.fromarray(binary_mask).convert('L')

    # The source image is loaded only to learn its height and width.
    source = io.imread(image_name)
    # NOTE(review): bilinear resampling of a 0/255 mask can introduce
    # intermediate gray values along edges - confirm this is intended.
    resized = mask_img.resize((source.shape[1], source.shape[0]),
                              resample=Image.BILINEAR)

    # splitext strips only the last extension ("a.b.png" -> "a.b") and, unlike
    # manual splitting on ".", also copes with names that have no extension.
    stem = os.path.splitext(os.path.basename(image_name))[0]

    resized.save(d_dir + stem + '.png')
|
|
|
def process_mvtec_loco_dataset(dataset_path, model_path, output_dir='fg_mask',
                               threshold=0.5, categories=None, splits=None,
                               batch_size=1, num_workers=1):
    """Generate binary foreground masks for a dataset tree using U2NET.

    The function walks ``dataset_path/<category>/<split>/<subdir>/*.png`` and
    writes one mask per image to
    ``dataset_path/<output_dir>/<category>/<split>/<subdir>/<name>.png``.

    Args:
        dataset_path: Root directory that contains the category folders.
        model_path: Path of the U2NET weights file passed to ``torch.load``.
        output_dir: Name of the folder, created under ``dataset_path``, that
            receives the masks.
        threshold: Binarization threshold forwarded to ``save_output``.
        categories: Category folder names to process; ``None`` selects the
            five built-in category names.
        splits: Split folder names to process; ``None`` selects
            ``['test', 'train']``.
        batch_size: Batch size given to the ``DataLoader``.
        num_workers: Worker count given to the ``DataLoader``.
    """
    print('...load U2NET---')
    net = U2NET(3, 1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.load_state_dict(torch.load(model_path))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
    net.eval()

    if categories is None:
        categories = ['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins']

    if splits is None:
        splits = ['test', 'train']

    mask_root = os.path.join(dataset_path, output_dir)
    os.makedirs(mask_root, exist_ok=True)

    for category in categories:
        print(f"Processing category: {category}")
        category_path = os.path.join(dataset_path, category)

        for split in splits:
            split_path = os.path.join(category_path, split)
            if not os.path.exists(split_path):
                print(f"Skipping {category}/{split} - directory not found")
                continue

            subdirs = [d for d in os.listdir(split_path)
                       if os.path.isdir(os.path.join(split_path, d))]

            for subdir in subdirs:
                subdir_path = os.path.join(split_path, subdir)
                output_path = os.path.join(mask_root, category, split, subdir)

                print(f" Processing {category}/{split}/{subdir}")

                # glob returns files in arbitrary order; sort for reproducible
                # progress output across runs and platforms.
                image_list = sorted(glob.glob(os.path.join(subdir_path, '*.png')))

                if not image_list:
                    print(f" No PNG images found in {subdir_path}")
                    continue

                print(f" Found {len(image_list)} images")

                os.makedirs(output_path, exist_ok=True)

                test_salobj_dataset = SalObjDataset(img_name_list=image_list,
                                                    lbl_name_list=[],
                                                    transform=transforms.Compose([RescaleT(320),
                                                                                  ToTensorLab(flag=0)]))
                test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=False,
                                                    num_workers=num_workers)

                # Inference only: disabling autograd avoids building a graph
                # (and holding activations) for every forward pass.
                with torch.no_grad():
                    for i, data_test in enumerate(test_salobj_dataloader):
                        inputs_test = data_test['image'].type(torch.FloatTensor)
                        if use_cuda:
                            inputs_test = inputs_test.cuda()

                        # Only the first of the network's outputs is used.
                        outputs = net(inputs_test)
                        preds = outputs[0][:, 0, :, :]

                        # With shuffle=False, batch i holds the images starting
                        # at index i * batch_size (assumes the dataset yields
                        # items in img_name_list order - TODO confirm).
                        first_idx = i * batch_size
                        for offset in range(preds.shape[0]):
                            idx = first_idx + offset
                            if (idx + 1) % 20 == 0 or idx == 0:
                                print(f" Processing {idx+1}/{len(image_list)}: {os.path.basename(image_list[idx])}")

                            # Normalize each image on its own so one image's
                            # value range cannot affect another's mask.
                            pred = normPRED(preds[offset])
                            save_output(image_list[idx], pred, output_path + os.sep, threshold)

                        # Drop references to the outputs before the next batch.
                        del outputs, preds

    print("All categories and splits processed successfully!")
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse arguments, validate the dataset path,
    # model file and category names, print the configuration, then run the
    # mask generation.

    # Imported here because only this script entry point needs it. sys.exit
    # is used instead of the site-provided exit() helper, which is meant for
    # interactive sessions and is not guaranteed to exist.
    import sys

    valid_categories = ['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins']
    manual_download_url = "https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view"

    parser = argparse.ArgumentParser(description='Generate foreground masks for MVTec LOCO dataset using U2NET',
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     epilog='''
Examples:
# Use default paths
python mvtec_loco_fg_segmentation.py

# Specify custom dataset and model paths
python mvtec_loco_fg_segmentation.py --dataset_path /path/to/mvtec_loco --model_path /path/to/u2net.pth

# Process specific categories only
python mvtec_loco_fg_segmentation.py --categories breakfast_box juice_bottle

# Use different threshold for binary mask
python mvtec_loco_fg_segmentation.py --threshold 0.3
''')

    parser.add_argument('--dataset_path', type=str, default='/root/hy-data/datasets/mvtec_loco_anomaly_detection',
                        help='Path to MVTec LOCO dataset root directory (default: /root/hy-data/datasets/mvtec_loco_anomaly_detection)')
    parser.add_argument('--model_path', type=str, default='./saved_models/u2net/u2net.pth',
                        help='Path to U2NET model weights file (default: ./saved_models/u2net/u2net.pth)')
    parser.add_argument('--output_dir', type=str, default='fg_mask',
                        help='Output directory name for generated masks (default: fg_mask)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Threshold for binary mask generation (default: 0.5)')
    parser.add_argument('--categories', nargs='+',
                        default=list(valid_categories),
                        help='Categories to process (default: all 5 categories)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for processing (default: 1)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of data loading workers (default: 1)')
    parser.add_argument('--splits', nargs='+', default=['test', 'train'],
                        help='Dataset splits to process (default: test train)')

    args = parser.parse_args()

    # The dataset root must exist before anything else is attempted.
    if not os.path.exists(args.dataset_path):
        print(f"ERROR: Dataset path not found: {args.dataset_path}")
        print("Please check the dataset path and make sure MVTec LOCO dataset is properly extracted.")
        sys.exit(1)

    # Missing weights: try the optional downloader, otherwise point the user
    # at the manual download link and stop.
    if not os.path.exists(args.model_path):
        print(f"Model not found: {args.model_path}")

        if HF_AVAILABLE:
            print("Attempting to download model from HuggingFace Hub...")
            downloaded_path = download_u2net_model(args.model_path)
            if downloaded_path is None:
                print("Failed to download from HuggingFace.")
                print("Please download manually from:")
                print(manual_download_url)
                sys.exit(1)
        else:
            print("HuggingFace Hub not available. Please install: pip install huggingface_hub")
            print("Or download manually from:")
            print(manual_download_url)
            sys.exit(1)

    # Reject category names outside the known set.
    invalid_categories = [cat for cat in args.categories if cat not in valid_categories]
    if invalid_categories:
        print(f"ERROR: Invalid categories: {invalid_categories}")
        print(f"Valid categories are: {valid_categories}")
        sys.exit(1)

    print("Configuration:")
    print(f" Dataset path: {args.dataset_path}")
    print(f" Model path: {args.model_path}")
    print(f" Output directory: {args.output_dir}")
    print(f" Binary threshold: {args.threshold}")
    print(f" Categories: {args.categories}")
    print(f" Splits: {args.splits}")
    print(f" Batch size: {args.batch_size}")
    print(f" Workers: {args.num_workers}")
    print()

    process_mvtec_loco_dataset(
        dataset_path=args.dataset_path,
        model_path=args.model_path,
        output_dir=args.output_dir,
        threshold=args.threshold,
        categories=args.categories,
        splits=args.splits,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )