import sys
import time

print('\033[1m' + "\n🪄✨ Starting Magic Crop ✨🪄\n" + '\033[0m')

print("📚 Importing libraries...")
start_time = time.time()

import argparse
import os
import json
import math
import shutil
from tqdm import tqdm
from PIL import Image

print(f"🏃 Basic libraries imported in {time.time() - start_time:.2f} seconds\n")

print("🔦 Importing PyTorch and transformers (this may take a while)...")
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

print(f"⌛ All libraries imported in {time.time() - start_time:.2f} seconds\n")


def get_corner_distance(box, image_width, image_height):
    """Distance from the box center to the nearest image corner."""
    x1, y1, x2, y2 = box
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2
    corners = [(0, 0), (image_width, 0), (0, image_height), (image_width, image_height)]
    return min(math.sqrt((cx - center_x) ** 2 + (cy - center_y) ** 2) for cx, cy in corners)


def get_box_size(box):
    x1, y1, x2, y2 = box
    return (x2 - x1) * (y2 - y1)


def process_batch(model, processor, images, device, torch_dtype, prompt):
    # Florence-2 expects the task token to prefix the text prompt.
    prompts = [f"<OPEN_VOCABULARY_DETECTION> {prompt}"] * len(images)
    inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
    parsed_answers = [processor.post_process_generation(text, task="<OPEN_VOCABULARY_DETECTION>", image_size=(img.width, img.height))
                      for text, img in zip(generated_texts, images)]
    return parsed_answers


def get_object_detection(model, processor, image, device, torch_dtype):
    prompt = "<OD>"
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
    return parsed_answer['<OD>']['bboxes']


def calculate_crop(image, detected_box, object_boxes=None):
    x1, y1, x2, y2 = detected_box
    width, height = image.size

    # Pixels preserved by each candidate crop that removes the detected box.
    crop_above = y1 * width
    crop_below = (height - y2) * width
    crop_left = x1 * height
    crop_right = (width - x2) * height

    if crop_above >= crop_below:
        vertical_crop = ("above", (0, 0, width, y1), crop_above)
    else:
        vertical_crop = ("below", (0, y2, width, height), crop_below)

    if crop_left >= crop_right:
        horizontal_crop = ("left", (0, 0, x1, height), crop_left)
    else:
        horizontal_crop = ("right", (x2, 0, width, height), crop_right)

    if object_boxes:
        def calculate_affected_pixels(crop_box):
            # Total object area that survives inside the kept region.
            affected_pixels = 0
            for obj_box in object_boxes:
                ox1, oy1, ox2, oy2 = obj_box
                cx1, cy1, cx2, cy2 = crop_box
                intersection_area = max(0, min(ox2, cx2) - max(ox1, cx1)) * max(0, min(oy2, cy2) - max(oy1, cy1))
                affected_pixels += intersection_area
            return affected_pixels

        vertical_affected = calculate_affected_pixels(vertical_crop[1])
        horizontal_affected = calculate_affected_pixels(horizontal_crop[1])

        # Prefer the crop that keeps more of the detected objects.
        if vertical_affected <= horizontal_affected:
            best_crop = horizontal_crop
        else:
            best_crop = vertical_crop
    else:
        best_crop = vertical_crop if vertical_crop[2] >= horizontal_crop[2] else horizontal_crop

    return best_crop[0], best_crop[1]
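
# Worked example of the heuristic above, for a 1000x800 image with a detected
# box at (100, 700, 300, 780). Pixels preserved by each candidate crop:
#   above: y1 * width            = 700 * 1000 = 700,000
#   below: (height - y2) * width =  20 * 1000 =  20,000
#   left:  x1 * height           = 100 *  800 =  80,000
#   right: (width - x2) * height = 700 *  800 = 560,000
# Without object_boxes, "above" wins (most pixels kept). With object-aware
# cropping, the best vertical and best horizontal candidates are instead
# compared by how much of the other detected objects each one retains.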

def process_image(model, processor, image_path, output_folder, prompt, object_aware, crop_threshold, debug, device, torch_dtype):
    """Detect the prompt in a single image and save a crop that removes it."""
    image = Image.open(image_path)
    parsed_answers = process_batch(model, processor, [image], device, torch_dtype, prompt)
    parsed_answer = parsed_answers[0]

    if debug:
        print("\nOPEN_VOCABULARY_DETECTION output:")
        print(json.dumps(parsed_answer, indent=2))

    if '<OPEN_VOCABULARY_DETECTION>' in parsed_answer and 'bboxes' in parsed_answer['<OPEN_VOCABULARY_DETECTION>']:
        bboxes = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes']
        labels = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']
        detected_boxes = [box for box, label in zip(bboxes, labels) if prompt.lower() in label.lower()]

        if detected_boxes:
            # Prefer the box nearest an image corner, breaking ties by size.
            sorted_boxes = sorted(detected_boxes, key=lambda box: (get_corner_distance(box, image.width, image.height), get_box_size(box)))
            detected_box = sorted_boxes[0]

            object_boxes = None
            if object_aware:
                object_boxes = get_object_detection(model, processor, image, device, torch_dtype)
                if debug:
                    print("Object Detection output:")
                    print(json.dumps(object_boxes, indent=2))

            crop_type, crop_box = calculate_crop(image, detected_box, object_boxes)

            crop_area = (crop_box[2] - crop_box[0]) * (crop_box[3] - crop_box[1])
            total_area = image.width * image.height
            crop_percentage = (total_area - crop_area) / total_area * 100

            if crop_percentage > crop_threshold:
                print(f"Skipping {image_path} due to large crop area: {crop_percentage:.2f}%")
                return False

            if debug:
                print(f"Chosen crop type: {crop_type}")
                print(f"Pixels preserved: {crop_area}")

            cropped_image = image.crop(crop_box)
            filename = os.path.basename(image_path)
            name, ext = os.path.splitext(filename)
            output_filename = f"{name}_crop_{crop_type}.jpg"
            output_path = os.path.join(output_folder, output_filename)
            os.makedirs(output_folder, exist_ok=True)
            cropped_image.save(output_path, 'JPEG', quality=98)
            print(f"Cropped image saved as: {output_path}")
            return True
        else:
            print(f"No {prompt} found in the image.")
    else:
        print(f"No {prompt} detected in the image.")
    return False


def load_model_and_processor(device, torch_dtype):
    print("Initializing model and processor...")
    start_time = time.time()

    print("🖥️ Loading processor...")
    try:
        processor = AutoProcessor.from_pretrained("./Florence-2-large", trust_remote_code=True, clean_up_tokenization_spaces=True, local_files_only=True)
    except Exception as e:
        print(f"Error loading processor: {str(e)}")
        raise
    processor_time = time.time() - start_time
    print(f"⏱️ Processor loaded in {processor_time:.2f} seconds\n")

    print("🤖 Loading model...")
    model = AutoModelForCausalLM.from_pretrained("./Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True, local_files_only=True).to(device)
    total_time = time.time() - start_time
    print(f"⏱️ Model loaded and moved to device in {total_time:.2f} seconds\n")
    return model, processor
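
# Minimal standalone sketch of a single detection round, using only the
# helpers defined above (assumes the same local ./Florence-2-large checkpoint;
# "test.jpg" is a placeholder path):
#
#   model, processor = load_model_and_processor("cpu", torch.float32)
#   img = Image.open("test.jpg").convert("RGB")
#   answer = process_batch(model, processor, [img], "cpu", torch.float32, "Watermark")[0]
#   print(answer["<OPEN_VOCABULARY_DETECTION>"]["bboxes"])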

def crop_images(input_paths, output_folder, batch_size, prompt, object_aware, crop_threshold, recursive, debug, move_skipped, move_errored):
    print(f"📥 Input paths: {', '.join(input_paths)}")
    print(f"📤 Output folder: {output_folder}")
    print(f"🎞️ Batch size: {batch_size}")
    print(f"💬 Prompt: {prompt}")
    print(f"🎯 Object-aware: {'Yes' if object_aware else 'No'}")
    print(f"🔀 Recursive: {'Yes' if recursive else 'No'}")
    print(f"🐛 Debug mode: {'On' if debug else 'Off'}")
    print(f"📏 Crop threshold: {crop_threshold}%")

    print("\n🟡 Initializing...")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"📟 Using device: {device}\n")
    model, processor = load_model_and_processor(device, torch_dtype)
    print("🟢 Initialization complete.")

    def get_image_files(folder):
        return [f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]

    total_images = 0
    folders_to_process = []
    errored_files = []
    skipped_files = []
    skipped_dirs = set()
    errored_dirs = set()

    for input_path in input_paths:
        if os.path.isfile(input_path):
            total_images += 1
            folders_to_process.append((os.path.dirname(input_path), [os.path.basename(input_path)]))
        else:
            if recursive:
                for root, _, files in os.walk(input_path):
                    image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
                    if image_files:
                        total_images += len(image_files)
                        folders_to_process.append((root, image_files))
            else:
                image_files = get_image_files(input_path)
                total_images += len(image_files)
                folders_to_process.append((input_path, image_files))

    print(f"\n🔢 Total images to process: {total_images}")

    use_overall_progress = len(folders_to_process) > 1 or recursive
    overall_progress = tqdm(total=total_images, desc="🖼️ Processing images", position=1, leave=True) if use_overall_progress else None

    for folder_path, image_files in folders_to_process:
        print(f"\n\nProcessing folder: {folder_path}")
        print(f"Images in this folder: {len(image_files)}")

        for i in tqdm(range(0, len(image_files), batch_size), desc="🧺 Processing batches", position=0, leave=True):
            batch_files = image_files[i:i + batch_size]
            images = []
            failed_files = []

            # Try to open images, collecting any that fail
            for img_file in batch_files:
                try:
                    img = Image.open(os.path.join(folder_path, img_file))
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    images.append(img)
                except OSError as e:
                    print(f"Error opening image file: {img_file} - {e}")
                    failed_files.append(img_file)
                    errored_files.append(os.path.join(folder_path, img_file))
                    if move_errored:
                        rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                        target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                        os.makedirs(target_dir, exist_ok=True)
                        shutil.copy(os.path.join(folder_path, img_file), target_dir)
                        errored_dirs.add(target_dir)
                    if use_overall_progress:
                        overall_progress.update(1)

            if not images:
                continue

            # Keep batch_files aligned with the images that actually opened,
            # so the zip over (batch_files, images, parsed_answers) below
            # pairs each filename with the right image.
            batch_files = [f for f in batch_files if f not in failed_files]

            # Attempt to process the batch
            try:
                parsed_answers = process_batch(model, processor, images, device, torch_dtype, prompt)
            except Exception as e:
                print(f"Error processing batch: {e}")
                # Retry once, re-opening the surviving files individually
                retry_files = []
                retry_images = []
                for img_file in batch_files:
                    try:
                        img = Image.open(os.path.join(folder_path, img_file))
                        if img.mode != 'RGB':
                            img = img.convert('RGB')
                        retry_images.append(img)
                        retry_files.append(img_file)
                    except OSError as e:
                        print(f"Error opening image file on retry: {img_file} - {e}")
                        errored_files.append(os.path.join(folder_path, img_file))
                        if move_errored:
                            rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                            target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                            os.makedirs(target_dir, exist_ok=True)
                            shutil.copy(os.path.join(folder_path, img_file), target_dir)
                            errored_dirs.add(target_dir)
                        if use_overall_progress:
                            overall_progress.update(1)

                if not retry_images:
                    continue

                try:
                    parsed_answers = process_batch(model, processor, retry_images, device, torch_dtype, prompt)
                    batch_files = retry_files  # Exclude files that failed on retry
                    images = retry_images
                except Exception as e:
                    print(f"Error processing batch on retry: {e}")
                    errored_files.extend([os.path.join(folder_path, f) for f in retry_files])
                    if move_errored:
                        rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                        target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                        os.makedirs(target_dir, exist_ok=True)
                        for f in retry_files:
                            shutil.copy(os.path.join(folder_path, f), target_dir)
                        errored_dirs.add(target_dir)
                    if use_overall_progress:
                        overall_progress.update(len(retry_files))
                    continue

            for img_file, image, parsed_answer in zip(batch_files, images, parsed_answers):
                if debug:
                    print(f"\nProcessing: {img_file}")
                    print("OPEN_VOCABULARY_DETECTION output:")
                    print(json.dumps(parsed_answer, indent=2))

                if '<OPEN_VOCABULARY_DETECTION>' in parsed_answer and 'bboxes' in parsed_answer['<OPEN_VOCABULARY_DETECTION>']:
                    bboxes = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes']
                    labels = parsed_answer['<OPEN_VOCABULARY_DETECTION>']['bboxes_labels']
                    detected_boxes = [box for box, label in zip(bboxes, labels) if prompt.lower() in label.lower()]

                    if detected_boxes:
                        sorted_boxes = sorted(detected_boxes, key=lambda box: (get_corner_distance(box, image.width, image.height), get_box_size(box)))
                        detected_box = sorted_boxes[0]

                        object_boxes = None
                        if object_aware:
                            object_boxes = get_object_detection(model, processor, image, device, torch_dtype)
                            if debug:
                                print("Object Detection output:")
                                print(json.dumps(object_boxes, indent=2))

                        crop_type, crop_box = calculate_crop(image, detected_box, object_boxes)

                        crop_area = (crop_box[2] - crop_box[0]) * (crop_box[3] - crop_box[1])
                        total_area = image.width * image.height
                        crop_percentage = (total_area - crop_area) / total_area * 100

                        if crop_percentage > crop_threshold:
                            print(f"Skipping {img_file} due to large crop area: {crop_percentage:.2f}%")
                            skipped_files.append(os.path.join(folder_path, img_file))
                            if move_skipped:
                                rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                target_dir = os.path.join(output_folder, rel_path, "_Skipped_")
                                os.makedirs(target_dir, exist_ok=True)
                                shutil.copy(os.path.join(folder_path, img_file), target_dir)
                                skipped_dirs.add(target_dir)
                        else:
                            try:
                                cropped_image = image.crop(crop_box)
                                filename, ext = os.path.splitext(img_file)
                                output_filename = f"{filename}_crop_{crop_type}.jpg"
                                # Ensure output_folder is a directory path
                                if os.path.isfile(output_folder):
                                    output_folder = os.path.dirname(output_folder)
                                rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                output_path = os.path.join(output_folder, rel_path, output_filename)
                                # Create the directory path, not the file path
                                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                                cropped_image.save(output_path, 'JPEG', quality=98)
                                if debug:
                                    print(f"Cropped image saved as: {output_path}")
                            except OSError as e:
                                print(f"Error cropping image file: {img_file} - {e}")
                                errored_files.append(os.path.join(folder_path, img_file))
                                if move_errored:
                                    rel_path = os.path.relpath(folder_path, os.path.commonpath(input_paths))
                                    target_dir = os.path.join(output_folder, rel_path, "_Errored_")
                                    os.makedirs(target_dir, exist_ok=True)
                                    shutil.copy(os.path.join(folder_path, img_file), target_dir)
                                    errored_dirs.add(target_dir)
                    else:
                        print(f"No {prompt} found in {img_file}")
                else:
                    print(f"No {prompt} detected in {img_file}")

                if use_overall_progress:
                    overall_progress.update(1)

    if use_overall_progress:
        overall_progress.close()

    print("\n✅ Processing complete!")

    if errored_files:
        print("\nErrored files:")
        for file in errored_files:
            print(file)
    if skipped_files:
        print("\nFiles skipped due to large areas being cropped:")
        for file in skipped_files:
            print(file)
    if move_errored and errored_dirs:
        print("\nErrored directories:")
        for dir_path in errored_dirs:
            print(dir_path)
    if move_skipped and skipped_dirs:
        print("\nSkipped directories:")
        for dir_path in skipped_dirs:
            print(dir_path)
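
# Output layout: crops mirror the input tree under output_folder and are named
# <stem>_crop_<above|below|left|right>.jpg. With --move-skipped / --move-errored,
# copies of the affected originals also land in _Skipped_ / _Errored_ subfolders;
# the source files themselves are never modified.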

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Crop images to preserve maximum pixels around specified prompt")
    parser.add_argument("input_paths", nargs='+', type=str, help="Paths to the input images or folders containing input images")
    parser.add_argument("-r", "--recursive", action="store_true", help="Process folders recursively")
    parser.add_argument("-o", "--output_folder", type=str, help="Path to the output folder")
    parser.add_argument("--bs", type=int, default=1, help="Batch size for processing images (default: 1)")
    parser.add_argument("--prompt", type=str, default="Watermark", help="Prompt for object detection (default: Watermark)")
    parser.add_argument("--object-aware", action="store_true", help="Enable object-aware cropping")
    # Note: a literal '%' must be escaped as '%%' in argparse help strings.
    parser.add_argument("--crop-threshold", type=float, default=20.0, help="Threshold for maximum allowed crop area percentage (default: 20%%)")
    parser.add_argument("--move-skipped", action="store_true", help="Copy skipped files to '_Skipped_' folder")
    parser.add_argument("--move-errored", action="store_true", help="Copy errored files to '_Errored_' folder")
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    args = parser.parse_args()

    if args.output_folder:
        output_folder = args.output_folder
    else:
        output_folder = os.path.commonpath(args.input_paths)

    crop_images(args.input_paths, output_folder, args.bs, args.prompt, args.object_aware, args.crop_threshold, args.recursive, args.debug, args.move_skipped, args.move_errored)
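
# Example invocations (the script filename "magic_crop.py" is assumed):
#   python magic_crop.py ./photos -o ./cropped
#   python magic_crop.py ./archive -r --bs 4 --object-aware --crop-threshold 15 --move-skipped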