File size: 10,065 Bytes
ece7754 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
#!/usr/bin/env python3
import os
import glob
import torch
import numpy as np
from PIL import Image
from skimage import io, transform
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import argparse
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET
from model import U2NETP
# Optional dependency: the HuggingFace download helper is only needed when the
# U2NET weights are missing locally. HF_AVAILABLE records whether it imported.
try:
    from download_from_hf import download_u2net_model
except ImportError:
    HF_AVAILABLE = False
else:
    HF_AVAILABLE = True
def normPRED(d):
    """Min-max normalise a floating-point tensor to the range [0, 1].

    Args:
        d: Tensor of predictions, any shape. The min and max are taken over
            all elements, so pass a single image's map to normalise it
            independently of other images in a batch.

    Returns:
        Tensor of the same shape as ``d`` with values in [0, 1]. A constant
        input (max == min) yields all zeros instead of NaN.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span <= 0:
        # Constant map: (d - mi) / 0 would be 0/0 = NaN everywhere, and NaN
        # never passes a "> threshold" test, silently producing an empty mask.
        return torch.zeros_like(d)
    return (d - mi) / span
def save_output(image_name, pred, d_dir, threshold=0.5):
    """Write a binary foreground mask for ``image_name`` into ``d_dir``.

    The prediction map is upsampled to the source image's resolution
    *before* thresholding, so the saved PNG contains only the values 0 and
    255. Binarising first and then resizing bilinearly would reintroduce
    grey values along the mask edges.

    Args:
        image_name: Path of the source image. Its width/height determine the
            mask size, and its file name without the final extension names
            the output file.
        pred: Tensor holding one prediction map. Singleton dimensions are
            squeezed away, so (1, H, W) and (H, W) both work. Values are
            presumably normalised to [0, 1] (e.g. by ``normPRED``).
        d_dir: Output directory; a trailing separator is optional.
        threshold: Pixels with a prediction strictly greater than this become
            foreground (255); all others become background (0).

    Writes ``<d_dir>/<stem>.png`` as an 8-bit grayscale ('L') image.
    """
    predict_np = pred.squeeze().detach().cpu().numpy().astype(np.float32)

    # Only the header is needed for the size; avoids decoding the full image.
    with Image.open(image_name) as src:
        width, height = src.size

    # Resize the continuous map (PIL mode 'F'), then threshold at full size.
    prob = Image.fromarray(predict_np).resize((width, height), resample=Image.BILINEAR)
    binary_mask = (np.asarray(prob) > threshold).astype(np.uint8) * 255
    mask_img = Image.fromarray(binary_mask).convert('L')

    # splitext keeps inner dots ("a.b.png" -> "a.b") and, unlike manual
    # splitting, does not fail on names without an extension.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    mask_img.save(os.path.join(d_dir, stem + '.png'))
def _load_u2net(model_path, device):
    """Build U2NET(3, 1), load weights from ``model_path`` onto ``device``, set eval mode."""
    net = U2NET(3, 1)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()
    return net


def _segment_images(net, image_list, output_path, threshold, batch_size, num_workers, device):
    """Run ``net`` over ``image_list`` in order and save one mask per image into ``output_path``."""
    dataset = SalObjDataset(img_name_list=image_list,
                            lbl_name_list=[],
                            transform=transforms.Compose([RescaleT(320),
                                                          ToTensorLab(flag=0)]))
    # shuffle=False keeps batches in image_list order, so a running counter
    # maps each prediction back to its source file.
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers)
    total = len(image_list)
    done = 0
    with torch.no_grad():  # inference only: no autograd graph / extra memory
        for data_test in loader:
            inputs = data_test['image'].to(device=device, dtype=torch.float32)
            d1 = net(inputs)[0]  # first output is used as the final prediction
            for pred in d1[:, 0, :, :]:
                name = image_list[done]
                if (done + 1) % 20 == 0 or done == 0:  # Print progress every 20 images
                    print(f" Processing {done + 1}/{total}: {os.path.basename(name)}")
                # Normalise each image on its own, not across the whole batch.
                save_output(name, normPRED(pred), output_path + os.sep, threshold)
                done += 1


def process_mvtec_loco_dataset(dataset_path, model_path, output_dir='fg_mask',
                               threshold=0.5, categories=None, splits=None,
                               batch_size=1, num_workers=1):
    """Generate binary foreground masks for MVTec LOCO images with U2NET.

    For every ``<dataset_path>/<category>/<split>/<subdir>/*.png`` a mask is
    written to ``<dataset_path>/<output_dir>/<category>/<split>/<subdir>/``
    under the same file stem.

    Args:
        dataset_path: Root directory of the MVTec LOCO dataset.
        model_path: Path to a U2NET state_dict file.
        output_dir: Name of the mask directory created under ``dataset_path``.
        threshold: Binarisation threshold applied to the normalised prediction.
        categories: Category folder names; ``None`` means all five categories.
        splits: Split folder names; ``None`` means ``['test', 'train']``.
        batch_size: Images per forward pass (any value >= 1).
        num_workers: DataLoader worker processes.
    """
    print('...load U2NET---')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = _load_u2net(model_path, device)

    if categories is None:
        categories = ['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins']
    if splits is None:
        splits = ['test', 'train']

    mask_root = os.path.join(dataset_path, output_dir)
    os.makedirs(mask_root, exist_ok=True)

    for category in categories:
        print(f"Processing category: {category}")
        category_path = os.path.join(dataset_path, category)
        for split in splits:
            split_path = os.path.join(category_path, split)
            if not os.path.exists(split_path):
                print(f"Skipping {category}/{split} - directory not found")
                continue
            # Subdirectories such as good / logical_anomalies / structural_anomalies;
            # sorted because os.listdir order is arbitrary.
            subdirs = sorted(d for d in os.listdir(split_path)
                             if os.path.isdir(os.path.join(split_path, d)))
            for subdir in subdirs:
                subdir_path = os.path.join(split_path, subdir)
                output_path = os.path.join(mask_root, category, split, subdir)
                print(f" Processing {category}/{split}/{subdir}")
                # glob.escape: directory names may contain glob metacharacters.
                image_list = sorted(glob.glob(os.path.join(glob.escape(subdir_path), '*.png')))
                if not image_list:
                    print(f" No PNG images found in {subdir_path}")
                    continue
                print(f" Found {len(image_list)} images")
                os.makedirs(output_path, exist_ok=True)
                _segment_images(net, image_list, output_path, threshold,
                                batch_size, num_workers, device)
    print("All categories and splits processed successfully!")
def main(argv=None):
    """Command-line entry point: parse and validate arguments, then generate masks.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            uses the real command line.

    Exits:
        Status 2 (via argparse) for out-of-range numeric options.
        Status 1 when the dataset path is missing, an unknown category is
        requested, or the model cannot be found or downloaded.
    """
    import sys

    valid_categories = ['breakfast_box', 'screw_bag', 'juice_bottle', 'splicing_connectors', 'pushpins']
    manual_download_url = "https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view"

    parser = argparse.ArgumentParser(description='Generate foreground masks for MVTec LOCO dataset using U2NET',
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     epilog='''
Examples:
# Use default paths
python mvtec_loco_fg_segmentation.py
# Specify custom dataset and model paths
python mvtec_loco_fg_segmentation.py --dataset_path /path/to/mvtec_loco --model_path /path/to/u2net.pth
# Process specific categories only
python mvtec_loco_fg_segmentation.py --categories breakfast_box juice_bottle
# Use different threshold for binary mask
python mvtec_loco_fg_segmentation.py --threshold 0.3
''')
    parser.add_argument('--dataset_path', type=str, default='/root/hy-data/datasets/mvtec_loco_anomaly_detection',
                        help='Path to MVTec LOCO dataset root directory (default: /root/hy-data/datasets/mvtec_loco_anomaly_detection)')
    parser.add_argument('--model_path', type=str, default='./saved_models/u2net/u2net.pth',
                        help='Path to U2NET model weights file (default: ./saved_models/u2net/u2net.pth)')
    parser.add_argument('--output_dir', type=str, default='fg_mask',
                        help='Output directory name for generated masks (default: fg_mask)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Threshold for binary mask generation (default: 0.5)')
    parser.add_argument('--categories', nargs='+',
                        default=list(valid_categories),
                        help='Categories to process (default: all 5 categories)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for processing (default: 1)')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of data loading workers (default: 1)')
    parser.add_argument('--splits', nargs='+', default=['test', 'train'],
                        help='Dataset splits to process (default: test train)')
    args = parser.parse_args(argv)

    # Predictions are min-max normalised to [0, 1], so a threshold outside
    # that range would give an all-foreground or all-background mask.
    if not 0.0 <= args.threshold <= 1.0:
        parser.error(f"--threshold must be in [0, 1], got {args.threshold}")
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1, got {args.batch_size}")
    if args.num_workers < 0:
        parser.error(f"--num_workers must be >= 0, got {args.num_workers}")

    if not os.path.exists(args.dataset_path):
        print(f"ERROR: Dataset path not found: {args.dataset_path}", file=sys.stderr)
        print("Please check the dataset path and make sure MVTec LOCO dataset is properly extracted.",
              file=sys.stderr)
        sys.exit(1)

    # Validate categories before a potentially slow model download.
    invalid_categories = [cat for cat in args.categories if cat not in valid_categories]
    if invalid_categories:
        print(f"ERROR: Invalid categories: {invalid_categories}", file=sys.stderr)
        print(f"Valid categories are: {valid_categories}", file=sys.stderr)
        sys.exit(1)

    if not os.path.exists(args.model_path):
        print(f"Model not found: {args.model_path}")
        if not HF_AVAILABLE:
            print("HuggingFace Hub not available. Please install: pip install huggingface_hub", file=sys.stderr)
            print("Or download manually from:", file=sys.stderr)
            print(manual_download_url, file=sys.stderr)
            sys.exit(1)
        print("Attempting to download model from HuggingFace Hub...")
        downloaded_path = download_u2net_model(args.model_path)
        if downloaded_path is None:
            print("Failed to download from HuggingFace.", file=sys.stderr)
            print("Please download manually from:", file=sys.stderr)
            print(manual_download_url, file=sys.stderr)
            sys.exit(1)
        # If the helper stored the weights somewhere other than the requested
        # path, use the location it returned.
        if not os.path.exists(args.model_path):
            args.model_path = downloaded_path

    print("Configuration:")
    print(f" Dataset path: {args.dataset_path}")
    print(f" Model path: {args.model_path}")
    print(f" Output directory: {args.output_dir}")
    print(f" Binary threshold: {args.threshold}")
    print(f" Categories: {args.categories}")
    print(f" Splits: {args.splits}")
    print(f" Batch size: {args.batch_size}")
    print(f" Workers: {args.num_workers}")
    print()
    process_mvtec_loco_dataset(
        dataset_path=args.dataset_path,
        model_path=args.model_path,
        output_dir=args.output_dir,
        threshold=args.threshold,
        categories=args.categories,
        splits=args.splits,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )


if __name__ == '__main__':
    main()