import sys
import time

print('\033[1m' + "\n🪄✨ Starting Magic Crop ✨🪄\n" + '\033[0m')

print("📚 Importing libraries...")

start_time = time.time()

import argparse
import os
import json
import math
import shutil

from tqdm import tqdm
from PIL import Image

print(f"📚 Basic libraries imported in {time.time() - start_time:.2f} seconds\n")

print("📦 Importing PyTorch and transformers (this may take a while)...")

import torch
from transformers import AutoProcessor, AutoModelForCausalLM

print(f"✅ All libraries imported in {time.time() - start_time:.2f} seconds\n")

def get_corner_distance(box, image_width, image_height):
    x1, y1, x2, y2 = box
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2

    corners = [(0, 0), (image_width, 0), (0, image_height), (image_width, image_height)]
    return min(math.sqrt((cx - center_x)**2 + (cy - center_y)**2) for cx, cy in corners)


def get_box_size(box):
    x1, y1, x2, y2 = box
    return (x2 - x1) * (y2 - y1)

def process_batch(model, processor, images, device, torch_dtype, prompt):
    prompts = [f"<OPEN_VOCABULARY_DETECTION> {prompt}"] * len(images)
    inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)

    parsed_answers = [processor.post_process_generation(text, task="<OPEN_VOCABULARY_DETECTION>", image_size=(img.width, img.height))
                      for text, img in zip(generated_texts, images)]

    return parsed_answers
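
# For reference, each entry of `parsed_answers` is expected to look roughly like the
# following (illustrative values only; the two keys shown are the ones this script reads):
#
#   {
#       '<OPEN_VOCABULARY_DETECTION>': {
#           'bboxes': [[412.0, 37.5, 622.0, 91.0]],   # [x1, y1, x2, y2] in pixels
#           'bboxes_labels': ['watermark'],
#       }
#   }
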
def get_object_detection(model, processor, image, device, torch_dtype):
    prompt = "<OD>"
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))

    return parsed_answer['<OD>']['bboxes']

def calculate_crop(image, detected_box, object_boxes=None):
    x1, y1, x2, y2 = detected_box
    width, height = image.size

    # Area kept by each candidate crop: everything above, below, left of, or
    # right of the detected box.
    crop_above = y1 * width
    crop_below = (height - y2) * width
    crop_left = x1 * height
    crop_right = (width - x2) * height

    if crop_above >= crop_below:
        vertical_crop = ("above", (0, 0, width, y1), crop_above)
    else:
        vertical_crop = ("below", (0, y2, width, height), crop_below)

    if crop_left >= crop_right:
        horizontal_crop = ("left", (0, 0, x1, height), crop_left)
    else:
        horizontal_crop = ("right", (x2, 0, width, height), crop_right)

    if object_boxes:
        def calculate_affected_pixels(crop_box):
            # Count how many detected-object pixels fall inside the region that
            # would be kept by this crop.
            affected_pixels = 0
            for obj_box in object_boxes:
                ox1, oy1, ox2, oy2 = obj_box
                cx1, cy1, cx2, cy2 = crop_box

                intersection_area = max(0, min(ox2, cx2) - max(ox1, cx1)) * max(0, min(oy2, cy2) - max(oy1, cy1))
                affected_pixels += intersection_area

            return affected_pixels

        vertical_affected = calculate_affected_pixels(vertical_crop[1])
        horizontal_affected = calculate_affected_pixels(horizontal_crop[1])

        # Object-aware mode: prefer the crop that keeps more object pixels.
        if vertical_affected <= horizontal_affected:
            best_crop = horizontal_crop
        else:
            best_crop = vertical_crop
    else:
        # Otherwise simply keep the larger region.
        best_crop = vertical_crop if vertical_crop[2] >= horizontal_crop[2] else horizontal_crop

    return best_crop[0], best_crop[1]
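
# Illustrative example (made-up numbers): for a 1000x800 image with a detected
# watermark box at (100, 700, 300, 780), the kept areas are
#   above: 700 * 1000 = 700,000 px    below: 20 * 1000 = 20,000 px
#   left:  100 * 800  =  80,000 px    right: 700 * 800 = 560,000 px
# so without object boxes calculate_crop() keeps the region above the watermark,
# i.e. crop box (0, 0, 1000, 700), preserving 700,000 of 800,000 pixels.
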
def process_image(model, processor, image_path, output_folder, prompt, object_aware, crop_threshold, debug, device, torch_dtype):
    """Single-image variant of the batch pipeline used by crop_images()."""
    image = Image.open(image_path)
    if image.mode != 'RGB':
        # JPEG output below requires an RGB image.
        image = image.convert('RGB')
    parsed_answers = process_batch(model, processor, [image], device, torch_dtype, prompt)
    parsed_answer = parsed_answers[0]

    if debug:
        print("\nOPEN_VOCABULARY_DETECTION output:")
        print(json.dumps(parsed_answer, indent=2))

    if '<OPEN_VOCABULARY_DETECTION>' in parsed_answer and 'bboxes' in parsed_answer['<OPEN_VOCABULARY_DETECTION>']:
        bboxes = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes']
        labels = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']

        detected_boxes = [box for box, label in zip(bboxes, labels) if prompt.lower() in label.lower()]

        if detected_boxes:
            sorted_boxes = sorted(detected_boxes,
                                  key=lambda box: (get_corner_distance(box, image.width, image.height),
                                                   get_box_size(box)))

            detected_box = sorted_boxes[0]

            object_boxes = None
            if object_aware:
                object_boxes = get_object_detection(model, processor, image, device, torch_dtype)
                if debug:
                    print("Object Detection output:")
                    print(json.dumps(object_boxes, indent=2))

            crop_type, crop_box = calculate_crop(image, detected_box, object_boxes)

            crop_area = (crop_box[2] - crop_box[0]) * (crop_box[3] - crop_box[1])
            total_area = image.width * image.height
            crop_percentage = (total_area - crop_area) / total_area * 100

            if crop_percentage > crop_threshold:
                print(f"Skipping {image_path} due to large crop area: {crop_percentage:.2f}%")
                return False

            if debug:
                print(f"Chosen crop type: {crop_type}")
                print(f"Pixels preserved: {crop_area}")

            cropped_image = image.crop(crop_box)

            filename = os.path.basename(image_path)
            name, ext = os.path.splitext(filename)
            output_filename = f"{name}_crop_{crop_type}.jpg"
            output_path = os.path.join(output_folder, output_filename)

            os.makedirs(output_folder, exist_ok=True)

            cropped_image.save(output_path, 'JPEG', quality=98)
            print(f"Cropped image saved as: {output_path}")
            return True
        else:
            print(f"No {prompt} found in the image.")
    else:
        print(f"No {prompt} detected in the image.")
    return False

def load_model_and_processor(device, torch_dtype):
    print("Initializing model and processor...")
    start_time = time.time()

    print("🖥️ Loading processor...")
    try:
        processor = AutoProcessor.from_pretrained("./Florence-2-large", trust_remote_code=True, clean_up_tokenization_spaces=True, local_files_only=True)
    except Exception as e:
        print(f"Error loading processor: {str(e)}")
        raise  # Nothing else can run without the processor, so re-raise instead of continuing.
    processor_time = time.time() - start_time
    print(f"⏱️ Processor loaded in {processor_time:.2f} seconds\n")

    print("🤖 Loading model...")
    model = AutoModelForCausalLM.from_pretrained("./Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True, local_files_only=True).to(device)
    total_time = time.time() - start_time
    print(f"⏱️ Model loaded and moved to device in {total_time:.2f} seconds\n")

    return model, processor

def crop_images(input_paths, output_folder, batch_size, prompt, object_aware, crop_threshold, recursive, debug, move_skipped, move_errored):
    print(f"📥 Input paths: {', '.join(input_paths)}")
    print(f"📤 Output folder: {output_folder}")
    print(f"🗂️ Batch size: {batch_size}")
    print(f"💬 Prompt: {prompt}")
    print(f"🎯 Object-aware: {'Yes' if object_aware else 'No'}")
    print(f"🔄 Recursive: {'Yes' if recursive else 'No'}")
    print(f"🐛 Debug mode: {'On' if debug else 'Off'}")
    print(f"📏 Crop threshold: {crop_threshold}%")

    print("\n💡 Initializing...")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"🚀 Using device: {device}\n")

    model, processor = load_model_and_processor(device, torch_dtype)
    print("🟢 Initialization complete.")

    def get_image_files(folder):
        return [f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]

    total_images = 0
    folders_to_process = []
    errored_files = []
    skipped_files = []
    skipped_dirs = set()
    errored_dirs = set()

    for input_path in input_paths:
        if os.path.isfile(input_path):
            total_images += 1
            folders_to_process.append((os.path.dirname(input_path), [os.path.basename(input_path)]))
        else:
            if recursive:
                for root, _, files in os.walk(input_path):
                    image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
                    if image_files:
                        total_images += len(image_files)
                        folders_to_process.append((root, image_files))
            else:
                image_files = get_image_files(input_path)
                total_images += len(image_files)
                folders_to_process.append((input_path, image_files))

    print(f"\n🔢 Total images to process: {total_images}")

    use_overall_progress = len(folders_to_process) > 1 or recursive

    overall_progress = tqdm(total=total_images, desc="🖼️ Processing images", position=1, leave=True) if use_overall_progress else None

    for folder_path, image_files in folders_to_process:
        print(f"\n\nProcessing folder: {folder_path}")
        print(f"Images in this folder: {len(image_files)}")

        for i in tqdm(range(0, len(image_files), batch_size), desc="🧺 Processing batches", position=0, leave=True):
            batch_files = image_files[i:i+batch_size]
            images = []
            opened_files = []  # filenames that opened successfully, kept aligned with `images`
            failed_files = []

            for img_file in batch_files:
                try:
                    img = Image.open(os.path.join(folder_path, img_file))
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    images.append(img)
                    opened_files.append(img_file)
                except OSError as e:
                    print(f"Error opening image file: {img_file} - {e}")
                    failed_files.append(img_file)
                    errored_files.append(os.path.join(folder_path, img_file))
                    if move_errored:
                        rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                        target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                        os.makedirs(target_dir, exist_ok=True)
                        shutil.copy(os.path.join(folder_path, img_file), target_dir)
                        errored_dirs.add(target_dir)
                    if use_overall_progress:
                        overall_progress.update(1)

            if not images:
                continue

            try:
                parsed_answers = process_batch(model, processor, images, device, torch_dtype, prompt)
            except Exception as e:
                print(f"Error processing batch: {e}")

                if not failed_files:
                    # The batch failed with nothing to exclude, so record every file as
                    # errored and move on to the next batch.
                    errored_files.extend([os.path.join(folder_path, f) for f in batch_files])
                    if move_errored:
                        rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                        target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                        os.makedirs(target_dir, exist_ok=True)
                        for f in batch_files:
                            shutil.copy(os.path.join(folder_path, f), target_dir)
                        errored_dirs.add(target_dir)
                    if use_overall_progress:
                        overall_progress.update(len(batch_files))
                    continue

                # Retry once with only the files that opened successfully.
                retry_files = [f for f in batch_files if f not in failed_files]
                retry_images = []
                for img_file in retry_files:
                    try:
                        img = Image.open(os.path.join(folder_path, img_file))
                        if img.mode != 'RGB':
                            img = img.convert('RGB')
                        retry_images.append(img)
                    except OSError as e:
                        print(f"Error opening image file on retry: {img_file} - {e}")
                        errored_files.append(os.path.join(folder_path, img_file))
                        if move_errored:
                            rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                            target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                            os.makedirs(target_dir, exist_ok=True)
                            shutil.copy(os.path.join(folder_path, img_file), target_dir)
                            errored_dirs.add(target_dir)
                        if use_overall_progress:
                            overall_progress.update(1)

                if retry_images:
                    try:
                        parsed_answers = process_batch(model, processor, retry_images, device, torch_dtype, prompt)
                        images = retry_images
                        opened_files = retry_files
                    except Exception as e:
                        print(f"Error processing batch on retry: {e}")
                        errored_files.extend([os.path.join(folder_path, f) for f in retry_files])
                        if move_errored:
                            for f in retry_files:
                                rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                                os.makedirs(target_dir, exist_ok=True)
                                shutil.copy(os.path.join(folder_path, f), target_dir)
                                errored_dirs.add(target_dir)
                        if use_overall_progress:
                            overall_progress.update(len(retry_files))
                        continue
                else:
                    continue

            # Iterate over the successfully opened files so names stay aligned with images/answers.
            for img_file, image, parsed_answer in zip(opened_files, images, parsed_answers):
                if debug:
                    print(f"\nProcessing: {img_file}")
                    print("OPEN_VOCABULARY_DETECTION output:")
                    print(json.dumps(parsed_answer, indent=2))

                if '<OPEN_VOCABULARY_DETECTION>' in parsed_answer and 'bboxes' in parsed_answer['<OPEN_VOCABULARY_DETECTION>']:
                    bboxes = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes']
                    labels = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']

                    detected_boxes = [box for box, label in zip(bboxes, labels) if prompt.lower() in label.lower()]

                    if detected_boxes:
                        sorted_boxes = sorted(detected_boxes,
                                              key=lambda box: (get_corner_distance(box, image.width, image.height),
                                                               get_box_size(box)))

                        detected_box = sorted_boxes[0]

                        object_boxes = None
                        if object_aware:
                            object_boxes = get_object_detection(model, processor, image, device, torch_dtype)
                            if debug:
                                print("Object Detection output:")
                                print(json.dumps(object_boxes, indent=2))

                        crop_type, crop_box = calculate_crop(image, detected_box, object_boxes)

                        crop_area = (crop_box[2] - crop_box[0]) * (crop_box[3] - crop_box[1])
                        total_area = image.width * image.height
                        crop_percentage = (total_area - crop_area) / total_area * 100

                        if crop_percentage > crop_threshold:
                            print(f"Skipping {img_file} due to large crop area: {crop_percentage:.2f}%")
                            skipped_files.append(os.path.join(folder_path, img_file))
                            if move_skipped:
                                rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                target_dir = os.path.join(output_folder, rel_path, "_Skipped_")
                                os.makedirs(target_dir, exist_ok=True)
                                shutil.copy(os.path.join(folder_path, img_file), target_dir)
                                skipped_dirs.add(target_dir)
                        else:
                            try:
                                cropped_image = image.crop(crop_box)
                                filename, ext = os.path.splitext(img_file)
                                output_filename = f"{filename}_crop_{crop_type}.jpg"

                                # If the output target is actually a file path, fall back to its parent folder.
                                if os.path.isfile(output_folder):
                                    output_folder = os.path.dirname(output_folder)

                                rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                output_path = os.path.join(output_folder, rel_path, output_filename)

                                os.makedirs(os.path.dirname(output_path), exist_ok=True)

                                cropped_image.save(output_path, 'JPEG', quality=98)
                                if debug:
                                    print(f"Cropped image saved as: {output_path}")
                            except OSError as e:
                                print(f"Error cropping image file: {img_file} - {e}")
                                errored_files.append(os.path.join(folder_path, img_file))
                                if move_errored:
                                    rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                    target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                                    os.makedirs(target_dir, exist_ok=True)
                                    shutil.copy(os.path.join(folder_path, img_file), target_dir)
                                    errored_dirs.add(target_dir)
                    else:
                        print(f"No {prompt} found in {img_file}")
                else:
                    print(f"No {prompt} detected in {img_file}")

                if use_overall_progress:
                    overall_progress.update(1)

    if use_overall_progress:
        overall_progress.close()

print("\nβ
Processing complete!")
|
|
if errored_files:
|
|
print("\nErrored files:")
|
|
for file in errored_files:
|
|
print(file)
|
|
|
|
if skipped_files:
|
|
print("\nFiles skipped due to large areas being cropped:")
|
|
for file in skipped_files:
|
|
print(file)
|
|
|
|
if move_errored and errored_dirs:
|
|
print("\nErrored directories:")
|
|
for dir_path in errored_dirs:
|
|
print(dir_path)
|
|
|
|
if move_skipped and skipped_dirs:
|
|
print("\nSkipped directories:")
|
|
for dir_path in skipped_dirs:
|
|
print(dir_path)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Crop images to preserve maximum pixels around specified prompt")
    parser.add_argument("input_paths", nargs='+', type=str, help="Paths to the input images or folders containing input images")
    parser.add_argument("-r", "--recursive", action="store_true", help="Process folders recursively")
    parser.add_argument("-o", "--output_folder", type=str, help="Path to the output folder")
    parser.add_argument("--bs", type=int, default=1, help="Batch size for processing images (default: 1)")
    parser.add_argument("--prompt", type=str, default="Watermark", help="Prompt for object detection (default: Watermark)")
    parser.add_argument("--object-aware", action="store_true", help="Enable object-aware cropping")
    parser.add_argument("--crop-threshold", type=float, default=20.0, help="Threshold for maximum allowed crop area percentage (default: 20%%)")
    parser.add_argument("--move-skipped", action="store_true", help="Copy skipped files to '_Skipped_' folder")
    parser.add_argument("--move-errored", action="store_true", help="Copy errored files to '_Errored_' folder")
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    args = parser.parse_args()

    if args.output_folder:
        output_folder = args.output_folder
    else:
        output_folder = os.path.commonpath(args.input_paths)

    crop_images(args.input_paths, output_folder, args.bs, args.prompt, args.object_aware, args.crop_threshold, args.recursive, args.debug, args.move_skipped, args.move_errored)
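
# Example invocations (illustrative paths and script name; adjust to your setup):
#
#   python magic_crop.py ./photos -o ./cropped --bs 4 --prompt "Watermark" --object-aware
#   python magic_crop.py ./archive -r --crop-threshold 15 --move-skipped --move-errored --debug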