import os
import sys
sys.path.append(os.getcwd())
import glob
import argparse

import torch
from torchvision import transforms
import torchvision.transforms.functional as F
import numpy as np
from PIL import Image

from ram.models.ram_lora import ram
from ram import inference_ram as inference
from utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix

tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])
ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Resize the shorter side to `size`, then center-crop a `size` x `size` patch."""
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


def get_validation_prompt(args, image, prompt_image_path, dape_model=None, vlm_model=None, device='cuda'):
    """Build the text prompt for SR (empty / DAPE tags / Qwen-VLM description) and
    return it together with the low-res input tensor."""
    # prepare low-res tensor for SR input
    lq = tensor_transforms(image).unsqueeze(0).to(device)

    # select prompt source
    if args.prompt_type == "null":
        prompt_text = args.prompt or ""
    elif args.prompt_type == "dape":
        lq_ram = ram_transforms(lq).to(dtype=weight_dtype)
        captions = inference(lq_ram, dape_model)
        prompt_text = f"{captions[0]}, {args.prompt}," if args.prompt else captions[0]
    elif args.prompt_type == "vlm":
        message_text = None
        if args.rec_type == "recursive":
            message_text = "What is in this image? Give me a set of words."
            print(f'MESSAGE TEXT: {message_text}')
            messages = [
                {"role": "system", "content": f"{message_text}"},
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": prompt_image_path}
                    ]
                }
            ]
            text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = vlm_processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
        elif args.rec_type == "recursive_multiscale":
            start_image_path = prompt_image_path[0]
            input_image_path = prompt_image_path[1]
            message_text = "The second image is a zoom-in of the first image. Based on this knowledge, what is in the second image? Give me a set of words."
            print(f'START IMAGE PATH: {start_image_path}\nINPUT IMAGE PATH: {input_image_path}\nMESSAGE TEXT: {message_text}')
            messages = [
                {"role": "system", "content": f"{message_text}"},
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": start_image_path},
                        {"type": "image", "image": input_image_path}
                    ]
                }
            ]
            print(f'MESSAGES\n{messages}')
            text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = vlm_processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
        else:
            raise ValueError(f"VLM prompt generation not implemented for rec_type: {args.rec_type}")

        inputs = inputs.to("cuda")

        original_sr_devices = {}
        if args.efficient_memory and 'model' in globals() and hasattr(model, 'text_enc_1'):  # Check if SR model is defined
            print("Moving SR model components to CPU for VLM inference.")
            original_sr_devices['text_enc_1'] = model.text_enc_1.device
            original_sr_devices['text_enc_2'] = model.text_enc_2.device
            original_sr_devices['text_enc_3'] = model.text_enc_3.device
            original_sr_devices['transformer'] = model.transformer.device
            original_sr_devices['vae'] = model.vae.device
            model.text_enc_1.to('cpu')
            model.text_enc_2.to('cpu')
            model.text_enc_3.to('cpu')
            model.transformer.to('cpu')
            model.vae.to('cpu')
            vlm_model.to('cuda')  # vlm_model should already be on its device_map="auto" device

        generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = vlm_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        prompt_text = f"{output_text[0]}, {args.prompt}," if args.prompt else output_text[0]

        if args.efficient_memory and 'model' in globals() and hasattr(model, 'text_enc_1'):
            print("Restoring SR model components to original devices.")
            vlm_model.to('cpu')  # If vlm_model was moved to a specific cuda device and needs to be offloaded
            model.text_enc_1.to(original_sr_devices['text_enc_1'])
            model.text_enc_2.to(original_sr_devices['text_enc_2'])
            model.text_enc_3.to(original_sr_devices['text_enc_3'])
            model.transformer.to(original_sr_devices['transformer'])
            model.vae.to(original_sr_devices['vae'])
    else:
        raise ValueError(f"Unknown prompt_type: {args.prompt_type}")

    return prompt_text, lq


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', '-i', type=str, default='preset/datasets/test_dataset/input', help='path to the input image')
    parser.add_argument('--output_dir', '-o', type=str, default='preset/datasets/test_dataset/output', help='the directory to save the output')
    parser.add_argument('--pretrained_model_name_or_path', type=str, default=None, help='sd model path')
    parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
    parser.add_argument('--process_size', type=int, default=512)
    parser.add_argument('--upscale', type=int, default=4)
    parser.add_argument('--align_method', type=str, choices=['wavelet', 'adain', 'nofix'], default='nofix')
    parser.add_argument('--lora_path', type=str, default=None, help='for LoRA of SR model')
    parser.add_argument('--vae_path', type=str, default=None)
    parser.add_argument('--prompt', type=str, default='', help='user prompts')
    parser.add_argument('--prompt_type', type=str, choices=['null', 'dape', 'vlm'], default='dape', help='type of prompt to use')
    parser.add_argument('--ram_path', type=str, default=None)
    parser.add_argument('--ram_ft_path', type=str, default=None)
    parser.add_argument('--save_prompts', type=bool, default=True)
    parser.add_argument('--mixed_precision', type=str, choices=['fp16', 'fp32'], default='fp16')
    parser.add_argument('--merge_and_unload_lora', action='store_true', help='merge lora weights before inference')
    parser.add_argument('--lora_rank', type=int, default=4)
    parser.add_argument('--vae_decoder_tiled_size', type=int, default=224)
    parser.add_argument('--vae_encoder_tiled_size', type=int, default=1024)
    parser.add_argument('--latent_tiled_size', type=int, default=96)
    parser.add_argument('--latent_tiled_overlap', type=int, default=32)
    parser.add_argument('--rec_type', type=str, choices=['nearest', 'bicubic', 'onestep', 'recursive', 'recursive_multiscale'], default='recursive', help='type of inference to use')
    parser.add_argument('--rec_num', type=int, default=4)
    parser.add_argument('--efficient_memory', default=False, action='store_true')
    args = parser.parse_args()

    global weight_dtype
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16

    # initialize SR model
    model = None
    if args.rec_type not in ('nearest', 'bicubic'):
        if not args.efficient_memory:
            from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler
            model = SD3Euler()
            model.text_enc_1.to('cuda')
            model.text_enc_2.to('cuda')
            model.text_enc_3.to('cuda')
            model.transformer.to('cuda', dtype=torch.float32)
            model.vae.to('cuda', dtype=torch.float32)
            for p in [model.text_enc_1, model.text_enc_2, model.text_enc_3, model.transformer, model.vae]:
                p.requires_grad_(False)
            model_test = OSEDiff_SD3_TEST(args, model)
        else:
            # For efficient memory, the text encoders are moved between CPU and GPU on demand
            # in get_validation_prompt; only the transformer and VAE are kept on the GPU from the start.
            from osediff_sd3 import OSEDiff_SD3_TEST_efficient, SD3Euler
            model = SD3Euler()
            model.transformer.to('cuda', dtype=torch.float32)
            model.vae.to('cuda', dtype=torch.float32)
            for p in [model.text_enc_1, model.text_enc_2, model.text_enc_3, model.transformer, model.vae]:
                p.requires_grad_(False)
            model_test = OSEDiff_SD3_TEST_efficient(args, model)

    # gather input images
    if os.path.isdir(args.input_image):
        image_names = sorted(glob.glob(f'{args.input_image}/*.png'))
    else:
        image_names = [args.input_image]

    # load DAPE if needed
    DAPE = None
    if args.prompt_type == "dape":
        DAPE = ram(pretrained=args.ram_path,
                   pretrained_condition=args.ram_ft_path,
                   image_size=384,
                   vit='swin_l')
        DAPE.eval().to("cuda")
        DAPE = DAPE.to(dtype=weight_dtype)

    # load VLM pipeline if needed
    vlm_model = None
    global vlm_processor
    global process_vision_info
    vlm_processor = None
    if args.prompt_type == "vlm":
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
        from qwen_vl_utils import process_vision_info

        vlm_model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
        print(f"Loading base VLM model: {vlm_model_name}")
        vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            vlm_model_name, torch_dtype="auto", device_map="auto"
        )
        vlm_processor = AutoProcessor.from_pretrained(vlm_model_name)
        print('Base VLM LOADING COMPLETE')

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'per-sample'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'per-scale'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'recursive'), exist_ok=True)

    print(f'There are {len(image_names)} images.')
    print(f'Align Method Used: {args.align_method}')
    print(f'Prompt Type: {args.prompt_type}')

    # inference loop
    for image_name in image_names:
        bname = os.path.basename(image_name)
        rec_dir = os.path.join(args.output_dir, 'per-sample', bname[:-4])
        os.makedirs(rec_dir, exist_ok=True)
        if args.save_prompts:
            txt_path = os.path.join(rec_dir, 'txt')
            os.makedirs(txt_path, exist_ok=True)
        print(f'#### IMAGE: {bname}')

        # first image
        os.makedirs(os.path.join(args.output_dir, 'per-scale', 'scale0'), exist_ok=True)
        first_image = Image.open(image_name).convert('RGB')
        first_image = resize_and_center_crop(first_image, args.process_size)
        first_image.save(f'{rec_dir}/0.png')
        first_image.save(os.path.join(args.output_dir, 'per-scale', 'scale0', bname))

        # recursion
        for rec in range(args.rec_num):
            print(f'RECURSION: {rec}')
            os.makedirs(os.path.join(args.output_dir, 'per-scale', f'scale{rec+1}'), exist_ok=True)
            start_image_path = None
            input_image_path = None
            prompt_image_path = None  # this will hold the path(s) for prompt extraction
            current_sr_input_image_pil = None

            if args.rec_type in ('nearest', 'bicubic', 'onestep'):
                start_image_pil_path = f'{rec_dir}/0.png'
                start_image_pil = Image.open(start_image_pil_path).convert('RGB')
                rscale = pow(args.upscale, rec+1)
                w, h = start_image_pil.size
                new_w, new_h = w // rscale, h // rscale
                # crop from the original highest-res image available for this step
                cropped_region = start_image_pil.crop(((w-new_w)//2, (h-new_h)//2, (w+new_w)//2, (h+new_h)//2))

                if args.rec_type == 'onestep':
                    current_sr_input_image_pil = cropped_region.resize((w, h), Image.BICUBIC)
                    prompt_image_path = f'{rec_dir}/0_input_for_{rec+1}.png'
                    current_sr_input_image_pil.save(prompt_image_path)
                elif args.rec_type == 'bicubic':
                    current_sr_input_image_pil = cropped_region.resize((w, h), Image.BICUBIC)
                    current_sr_input_image_pil.save(f'{rec_dir}/{rec+1}.png')
                    current_sr_input_image_pil.save(os.path.join(args.output_dir, 'per-scale', f'scale{rec+1}', bname))
                    continue
                elif args.rec_type == 'nearest':
                    current_sr_input_image_pil = cropped_region.resize((w, h), Image.NEAREST)
                    current_sr_input_image_pil.save(f'{rec_dir}/{rec+1}.png')
                    current_sr_input_image_pil.save(os.path.join(args.output_dir, 'per-scale', f'scale{rec+1}', bname))
                    continue

            elif args.rec_type == 'recursive':
                # input for SR is based on the previous SR output, cropped and resized
                prev_sr_output_path = f'{rec_dir}/{rec}.png'
                prev_sr_output_pil = Image.open(prev_sr_output_path).convert('RGB')
                rscale = args.upscale
                w, h = prev_sr_output_pil.size
                if rscale != 0:
                    new_w, new_h = w // rscale, h // rscale
                else:
                    new_w, new_h = w, h
                cropped_region = prev_sr_output_pil.crop(((w-new_w)//2, (h-new_h)//2, (w+new_w)//2, (h+new_h)//2))
                current_sr_input_image_pil = cropped_region.resize((w, h), Image.BICUBIC)
                # this resized image is also the input for the VLM
                input_image_path = f'{rec_dir}/{rec+1}_input.png'
                current_sr_input_image_pil.save(input_image_path)
                prompt_image_path = input_image_path

            elif args.rec_type == 'recursive_multiscale':
                prev_sr_output_path = f'{rec_dir}/{rec}.png'
                prev_sr_output_pil = Image.open(prev_sr_output_path).convert('RGB')
                rscale = args.upscale
                w, h = prev_sr_output_pil.size
                if rscale != 0:
                    new_w, new_h = w // rscale, h // rscale
                else:
                    new_w, new_h = w, h
                cropped_region = prev_sr_output_pil.crop(((w-new_w)//2, (h-new_h)//2, (w+new_w)//2, (h+new_h)//2))
                current_sr_input_image_pil = cropped_region.resize((w, h), Image.BICUBIC)
                # save the SR input image (which is the "zoomed-in" image for the VLM)
                zoomed_image_path = f'{rec_dir}/{rec+1}_input.png'
                current_sr_input_image_pil.save(zoomed_image_path)
                prompt_image_path = [prev_sr_output_path, zoomed_image_path]

            else:
                raise ValueError(f"Unknown rec_type: {args.rec_type}")
            # generate prompts
            validation_prompt, lq = get_validation_prompt(args, current_sr_input_image_pil, prompt_image_path, DAPE, vlm_model)
            if args.save_prompts:
                with open(os.path.join(txt_path, f'{rec}.txt'), 'w', encoding='utf-8') as f:
                    f.write(validation_prompt)
            print(f'TAG: {validation_prompt}')

            # super-resolution
            with torch.no_grad():
                lq = lq * 2 - 1
                if args.efficient_memory and model is not None:
                    print("Ensuring SR model components are on CUDA for SR inference.")
                    if not isinstance(model_test, OSEDiff_SD3_TEST_efficient):
                        model.text_enc_1.to('cuda')
                        model.text_enc_2.to('cuda')
                        model.text_enc_3.to('cuda')
                    # transformer and VAE should already be on CUDA per initialization
                    model.transformer.to('cuda', dtype=torch.float32)
                    model.vae.to('cuda', dtype=torch.float32)
                output_image = model_test(lq, prompt=validation_prompt)
                output_image = torch.clamp(output_image[0].cpu(), -1.0, 1.0)
                output_pil = transforms.ToPILImage()(output_image * 0.5 + 0.5)
                if args.align_method == 'adain':
                    output_pil = adain_color_fix(target=output_pil, source=current_sr_input_image_pil)
                elif args.align_method == 'wavelet':
                    output_pil = wavelet_color_fix(target=output_pil, source=current_sr_input_image_pil)

            output_pil.save(f'{rec_dir}/{rec+1}.png')  # this is the SR output
            output_pil.save(os.path.join(args.output_dir, 'per-scale', f'scale{rec+1}', bname))

        # concatenate and save
        imgs = [Image.open(os.path.join(rec_dir, f'{i}.png')).convert('RGB') for i in range(args.rec_num+1)]
        concat = Image.new('RGB', (sum(im.width for im in imgs), max(im.height for im in imgs)))
        x_off = 0
        for im in imgs:
            concat.paste(im, (x_off, 0))
            x_off += im.width
        concat.save(os.path.join(rec_dir, bname))
        concat.save(os.path.join(args.output_dir, 'recursive', bname))
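
# Example invocation (a sketch: the script filename below is a placeholder, and the
# input/output paths simply reuse the argparse defaults; all flags shown are defined above).
#
#   python inference_script.py \
#       -i preset/datasets/test_dataset/input \
#       -o preset/datasets/test_dataset/output \
#       --rec_type recursive_multiscale --rec_num 4 --upscale 4 \
#       --prompt_type vlm --align_method wavelet --efficient_memory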