import sys
import time
print('\033[1m' + "\nπŸͺ„βœ¨ Starting Magic Crop ✨πŸͺ„\n" + '\033[0m')
print("πŸ“š Importing libraries...")
start_time = time.time()
import argparse
import os
import json
import math
import shutil
from tqdm import tqdm
from PIL import Image
print(f"πŸƒ Basic libraries imported in {time.time() - start_time:.2f} seconds\n")
print("πŸ”¦ Importing PyTorch and transformers (this may take a while)...")
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
print(f"βŒ› All libraries imported in {time.time() - start_time:.2f} seconds\n")
def get_corner_distance(box, image_width, image_height):
x1, y1, x2, y2 = box
center_x = (x1 + x2) / 2
center_y = (y1 + y2) / 2
corners = [(0, 0), (image_width, 0), (0, image_height), (image_width, image_height)]
return min(math.sqrt((cx - center_x)**2 + (cy - center_y)**2) for cx, cy in corners)
def get_box_size(box):
x1, y1, x2, y2 = box
return (x2 - x1) * (y2 - y1)
def process_batch(model, processor, images, device, torch_dtype, prompt):
prompts = [f"<OPEN_VOCABULARY_DETECTION> {prompt}"] * len(images)
inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(device, torch_dtype)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3
)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
parsed_answers = [processor.post_process_generation(text, task="<OPEN_VOCABULARY_DETECTION>", image_size=(img.width, img.height))
for text, img in zip(generated_texts, images)]
return parsed_answers
def get_object_detection(model, processor, image, device, torch_dtype):
prompt = "<OD>"
inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device, torch_dtype)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
return parsed_answer['<OD>']['bboxes']
def calculate_crop(image, detected_box, object_boxes=None):
x1, y1, x2, y2 = detected_box
width, height = image.size
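    # Pixels preserved by each candidate: e.g. cropping "above" keeps the
    # full-width strip from the top edge down to the detected box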
crop_above = y1 * width
crop_below = (height - y2) * width
crop_left = x1 * height
crop_right = (width - x2) * height
if crop_above >= crop_below:
vertical_crop = ("above", (0, 0, width, y1), crop_above)
else:
vertical_crop = ("below", (0, y2, width, height), crop_below)
if crop_left >= crop_right:
horizontal_crop = ("left", (0, 0, x1, height), crop_left)
else:
horizontal_crop = ("right", (x2, 0, width, height), crop_right)
if object_boxes:
def calculate_affected_pixels(crop_box):
affected_pixels = 0
for obj_box in object_boxes:
ox1, oy1, ox2, oy2 = obj_box
cx1, cy1, cx2, cy2 = crop_box
intersection_area = max(0, min(ox2, cx2) - max(ox1, cx1)) * max(0, min(oy2, cy2) - max(oy1, cy1))
affected_pixels += intersection_area
return affected_pixels
vertical_affected = calculate_affected_pixels(vertical_crop[1])
horizontal_affected = calculate_affected_pixels(horizontal_crop[1])
if vertical_affected <= horizontal_affected:
best_crop = horizontal_crop
else:
best_crop = vertical_crop
else:
best_crop = vertical_crop if vertical_crop[2] >= horizontal_crop[2] else horizontal_crop
return best_crop[0], best_crop[1]
def process_image(model, processor, image_path, output_folder, prompt, object_aware, crop_threshold, debug, device, torch_dtype):
    image = Image.open(image_path)
    if image.mode != 'RGB':
        image = image.convert('RGB')  # JPEG output cannot carry an alpha channel
parsed_answers = process_batch(model, processor, [image], device, torch_dtype, prompt)
parsed_answer = parsed_answers[0]
if debug:
print("\nOPEN_VOCABULARY_DETECTION output:")
print(json.dumps(parsed_answer, indent=2))
if '<OPEN_VOCABULARY_DETECTION>' in parsed_answer and 'bboxes' in parsed_answer['<OPEN_VOCABULARY_DETECTION>']:
bboxes = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes']
labels = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']
detected_boxes = [box for box, label in zip(bboxes, labels) if prompt.lower() in label.lower()]
if detected_boxes:
sorted_boxes = sorted(detected_boxes,
key=lambda box: (get_corner_distance(box, image.width, image.height),
get_box_size(box)))
detected_box = sorted_boxes[0]
object_boxes = None
if object_aware:
object_boxes = get_object_detection(model, processor, image, device, torch_dtype)
if debug:
print("Object Detection output:")
print(json.dumps(object_boxes, indent=2))
crop_type, crop_box = calculate_crop(image, detected_box, object_boxes)
crop_area = (crop_box[2] - crop_box[0]) * (crop_box[3] - crop_box[1])
total_area = image.width * image.height
crop_percentage = (total_area - crop_area) / total_area * 100
if crop_percentage > crop_threshold:
print(f"Skipping {image_path} due to large crop area: {crop_percentage:.2f}%")
return False
if debug:
print(f"Chosen crop type: {crop_type}")
print(f"Pixels preserved: {crop_area}")
cropped_image = image.crop(crop_box)
filename = os.path.basename(image_path)
name, ext = os.path.splitext(filename)
output_filename = f"{name}_crop_{crop_type}.jpg"
output_path = os.path.join(output_folder, output_filename)
os.makedirs(output_folder, exist_ok=True)
cropped_image.save(output_path, 'JPEG', quality=98)
print(f"Cropped image saved as: {output_path}")
return True
else:
print(f"No {prompt} found in the image.")
else:
print(f"No {prompt} detected in the image.")
return False
def load_model_and_processor(device, torch_dtype):
print("Initializing model and processor...")
start_time = time.time()
print("πŸ–₯️ Loading processor...")
try:
processor = AutoProcessor.from_pretrained("./Florence-2-large", trust_remote_code=True, clean_up_tokenization_spaces=True, local_files_only=True)
    except Exception as e:
        print(f"Error loading processor: {str(e)}")
        sys.exit(1)  # the processor is required below; abort instead of crashing later
processor_time = time.time() - start_time
print(f"⏱️ Processor loaded in {processor_time:.2f} seconds\n")
print("πŸ€– Loading model...")
model = AutoModelForCausalLM.from_pretrained("./Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True, local_files_only=True).to(device)
total_time = time.time() - start_time
print(f"⏱️ Model loaded and moved to device in {total_time:.2f} seconds\n")
return model, processor
def crop_images(input_paths, output_folder, batch_size, prompt, object_aware, crop_threshold, recursive, debug, move_skipped, move_errored):
print(f"πŸ“₯ Input paths: {', '.join(input_paths)}")
print(f"πŸ“€ Output folder: {output_folder}")
print(f"🎞️ Batch size: {batch_size}")
print(f"πŸ’¬ Prompt: {prompt}")
print(f"🎯 Object-aware: {'Yes' if object_aware else 'No'}")
print(f"πŸ”€ Recursive: {'Yes' if recursive else 'No'}")
print(f"πŸ› Debug mode: {'On' if debug else 'Off'}")
print(f"πŸ“ Crop threshold: {crop_threshold}%")
print("\n🟑 Initializing...")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"πŸ“Ÿ Using device: {device}\n")
model, processor = load_model_and_processor(device, torch_dtype)
print("🟒 Initialization complete.")
def get_image_files(folder):
return [f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
total_images = 0
folders_to_process = []
errored_files = []
skipped_files = []
skipped_dirs = set()
errored_dirs = set()
for input_path in input_paths:
if os.path.isfile(input_path):
total_images += 1
folders_to_process.append((os.path.dirname(input_path), [os.path.basename(input_path)]))
else:
if recursive:
for root, _, files in os.walk(input_path):
                    image_files = get_image_files(root)
if image_files:
total_images += len(image_files)
folders_to_process.append((root, image_files))
else:
image_files = get_image_files(input_path)
total_images += len(image_files)
folders_to_process.append((input_path, image_files))
print(f"\nπŸ”’ Total images to process: {total_images}")
use_overall_progress = len(folders_to_process) > 1 or recursive
overall_progress = tqdm(total=total_images, desc="πŸ–ΌοΈ Processing images", position=1, leave=True) if use_overall_progress else None
for folder_path, image_files in folders_to_process:
print(f"\n\nProcessing folder: {folder_path}")
print(f"Images in this folder: {len(image_files)}")
for i in tqdm(range(0, len(image_files), batch_size), desc="🧺 Processing batches", position=0, leave=True):
batch_files = image_files[i:i+batch_size]
            images = []
            opened_files = []  # filenames that opened successfully, parallel to `images`
# Try to open images, collecting any that fail
for img_file in batch_files:
try:
img = Image.open(os.path.join(folder_path, img_file))
if img.mode != 'RGB':
img = img.convert('RGB')
                    images.append(img)
                    opened_files.append(img_file)
                except OSError as e:
                    print(f"Error opening image file: {img_file} - {e}")
                    errored_files.append(os.path.join(folder_path, img_file))
if move_errored:
rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
target_dir = os.path.join(output_folder, rel_path, "_Errored_")
os.makedirs(target_dir, exist_ok=True)
shutil.copy(os.path.join(folder_path, img_file), target_dir)
errored_dirs.add(target_dir)
if use_overall_progress:
overall_progress.update(1)
if not images:
continue
            # Attempt to process the batch; retry once, since the failure may
            # be transient (files that failed to open are already excluded
            # from `images`)
            try:
                parsed_answers = process_batch(model, processor, images, device, torch_dtype, prompt)
            except Exception as e:
                print(f"Error processing batch: {e}")
                try:
                    parsed_answers = process_batch(model, processor, images, device, torch_dtype, prompt)
                except Exception as e:
                    print(f"Error processing batch on retry: {e}")
                    errored_files.extend(os.path.join(folder_path, f) for f in opened_files)
                    if move_errored:
                        rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                        target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                        os.makedirs(target_dir, exist_ok=True)
                        for f in opened_files:
                            shutil.copy(os.path.join(folder_path, f), target_dir)
                        errored_dirs.add(target_dir)
                    if use_overall_progress:
                        overall_progress.update(len(opened_files))
                    continue
            # Zip over opened_files (not batch_files) so filenames stay
            # aligned with images when some files failed to open
            for img_file, image, parsed_answer in zip(opened_files, images, parsed_answers):
if debug:
print(f"\nProcessing: {img_file}")
print("OPEN_VOCABULARY_DETECTION output:")
print(json.dumps(parsed_answer, indent=2))
if '<OPEN_VOCABULARY_DETECTION>' in parsed_answer and 'bboxes' in parsed_answer['<OPEN_VOCABULARY_DETECTION>']:
bboxes = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes']
labels = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']
detected_boxes = [box for box, label in zip(bboxes, labels) if prompt.lower() in label.lower()]
if detected_boxes:
sorted_boxes = sorted(detected_boxes,
key=lambda box: (get_corner_distance(box, image.width, image.height),
get_box_size(box)))
detected_box = sorted_boxes[0]
object_boxes = None
if object_aware:
object_boxes = get_object_detection(model, processor, image, device, torch_dtype)
if debug:
print("Object Detection output:")
print(json.dumps(object_boxes, indent=2))
crop_type, crop_box = calculate_crop(image, detected_box, object_boxes)
crop_area = (crop_box[2] - crop_box[0]) * (crop_box[3] - crop_box[1])
total_area = image.width * image.height
crop_percentage = (total_area - crop_area) / total_area * 100
if crop_percentage > crop_threshold:
print(f"Skipping {img_file} due to large crop area: {crop_percentage:.2f}%")
skipped_files.append(os.path.join(folder_path, img_file))
if move_skipped:
rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
target_dir = os.path.join(output_folder, rel_path, "_Skipped_")
os.makedirs(target_dir, exist_ok=True)
shutil.copy(os.path.join(folder_path, img_file), target_dir)
skipped_dirs.add(target_dir)
else:
try:
cropped_image = image.crop(crop_box)
filename, ext = os.path.splitext(img_file)
output_filename = f"{filename}_crop_{crop_type}.jpg"
# Ensure output_folder is a directory path
if os.path.isfile(output_folder):
output_folder = os.path.dirname(output_folder)
rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
output_path = os.path.join(output_folder, rel_path, output_filename)
# Create the directory path, not the file path
os.makedirs(os.path.dirname(output_path), exist_ok=True)
cropped_image.save(output_path, 'JPEG', quality=98)
if debug:
print(f"Cropped image saved as: {output_path}")
except OSError as e:
print(f"Error cropping image file: {img_file} - {e}")
errored_files.append(os.path.join(folder_path, img_file))
if move_errored:
rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
target_dir = os.path.join(output_folder, rel_path, "_Errored_")
os.makedirs(target_dir, exist_ok=True)
shutil.copy(os.path.join(folder_path, img_file), target_dir)
errored_dirs.add(target_dir)
else:
print(f"No {prompt} found in {img_file}")
else:
print(f"No {prompt} detected in {img_file}")
if use_overall_progress:
overall_progress.update(1)
if use_overall_progress:
overall_progress.close()
print("\nβœ… Processing complete!")
if errored_files:
print("\nErrored files:")
for file in errored_files:
print(file)
if skipped_files:
print("\nFiles skipped due to large areas being cropped:")
for file in skipped_files:
print(file)
if move_errored and errored_dirs:
print("\nErrored directories:")
for dir_path in errored_dirs:
print(dir_path)
if move_skipped and skipped_dirs:
print("\nSkipped directories:")
for dir_path in skipped_dirs:
print(dir_path)
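
# Example invocation (the filename "magic_crop.py" is an assumption; substitute
# this script's actual name):
#   python magic_crop.py ./photos -o ./cropped --bs 4 --object-aware --move-skipped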
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Crop images so that a detected object (e.g. a watermark) is removed while preserving as many pixels as possible")
parser.add_argument("input_paths", nargs='+', type=str, help="Paths to the input images or folders containing input images")
parser.add_argument("-r", "--recursive", action="store_true", help="Process folders recursively")
parser.add_argument("-o", "--output_folder", type=str, help="Path to the output folder")
parser.add_argument("--bs", type=int, default=1, help="Batch size for processing images (default: 1)")
parser.add_argument("--prompt", type=str, default="Watermark", help="Prompt for object detection (default: Watermark)")
parser.add_argument("--object-aware", action="store_true", help="Enable object-aware cropping")
parser.add_argument("--crop-threshold", type=float, default=20.0, help="Threshold for maximum allowed crop area percentage (default: 20%)")
parser.add_argument("--move-skipped", action="store_true", help="Copy skipped files to '_Skipped_' folder")
parser.add_argument("--move-errored", action="store_true", help="Copy errored files to '_Errored_' folder")
parser.add_argument("--debug", action="store_true", help="Enable debug output")
args = parser.parse_args()
if args.output_folder:
output_folder = args.output_folder
else:
output_folder = os.path.commonpath(args.input_paths)
crop_images(args.input_paths, output_folder, args.bs, args.prompt, args.object_aware, args.crop_threshold, args.recursive, args.debug, args.move_skipped, args.move_errored)