""" This is the script of scaled cross-reference customization inference with Calligrapher. """ import os import json import random from PIL import Image import numpy as np from datetime import datetime import torch from diffusers.utils import load_image from pipeline_calligrapher import CalligrapherPipeline from models.calligrapher import Calligrapher from models.transformer_flux_inpainting import FluxTransformer2DModel from utils import resize_img_and_pad, generate_context_reference_image def infer_calligrapher(test_image_dir, result_save_dir, target_h=512, target_w=512, gen_num_per_case=2): # Set job dir. job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") result_save_path = os.path.join(result_save_dir, job_name) if not os.path.exists(result_save_path): os.makedirs(result_save_path, exist_ok=True) # Load models. base_model_path = path_dict['base_model_path'] image_encoder_path = path_dict['image_encoder_path'] calligrapher_path = path_dict['calligrapher_path'] transformer = FluxTransformer2DModel.from_pretrained( base_model_path, subfolder="transformer", torch_dtype=torch.bfloat16 ) pipe = CalligrapherPipeline.from_pretrained(base_model_path, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda") model = Calligrapher(pipe, image_encoder_path, calligrapher_path, device="cuda", num_tokens=128) source_image_names = [i for i in os.listdir(test_image_dir) if 'source.png' in i] # Loading prompts from the bench txt and printing them. info_dict = {} with open(os.path.join(test_image_dir, 'cross_bench.txt'), 'r') as file: for line in file: line = line.strip() if line: key, value = line.split('-', 1) info_dict[int(key)] = value print(info_dict) i = 0 print('Printing given prompts...') for img_id in sorted(info_dict.keys()): i += 1 info = info_dict[img_id] print(f'Sample #{i}: {img_id} - {info}') count = 0 for source_image_name in sorted(source_image_names): count += 1 img_id = int(source_image_name.split("test")[1].split("_")[0]) if img_id not in info_dict.keys(): continue info = info_dict[img_id] ref_ids, text = info.split('-') ref_ids = ref_ids.split(',') prompt = f"The text is '{text}'." source_image_path = os.path.join(test_image_dir, source_image_name) mask_image_name = source_image_name.replace('source', 'mask') mask_image_path = os.path.join(test_image_dir, mask_image_name) for ref_id in ref_ids: reference_image_name = source_image_name.replace('source', 'ref').replace(f'{img_id}', f'{ref_id}') reference_image_path = os.path.join(test_image_dir, reference_image_name) print('source_image_path: ', source_image_path) print('mask_image_path: ', mask_image_path) print('reference_image_path: ', reference_image_path) print(f'prompt: {prompt}') source_image = load_image(source_image_path) mask_image = load_image(mask_image_path) # Resize source and mask. source_image = source_image.resize((target_w, target_h)) mask_image = mask_image.resize((target_w, target_h), Image.NEAREST) mask_np = np.array(mask_image) mask_np[mask_np > 0] = 255 mask_image = Image.fromarray(mask_np.astype(np.uint8)) source_img_w, source_img_h = source_image.size # resize reference to fit CLIP. reference_image = Image.open(reference_image_path).convert("RGB") reference_image_to_encoder = resize_img_and_pad(reference_image, target_size=[512, 512]) reference_context = generate_context_reference_image(reference_image, source_img_w) # Concat the context image on the top. source_with_context = Image.new(source_image.mode, (source_img_w, reference_context.size[1] + source_img_h)) source_with_context.paste(reference_context, (0, 0)) source_with_context.paste(source_image, (0, reference_context.size[1])) # Concat the 0 mask on the top of the mask image. mask_with_context = Image.new(mask_image.mode, (mask_image.size[0], reference_context.size[1] + mask_image.size[0]), color=0) mask_with_context.paste(mask_image, (0, reference_context.size[1])) # Identifiers in filename. ref_id = reference_image_name.split('_')[0] safe_prompt = prompt.replace(" ", "_").replace("'", "").replace(",", "").replace('"', '').replace('?', '')[:50] for i in range(gen_num_per_case): seed = random.randint(0, 2 ** 32 - 1) images = model.generate( image=source_with_context, mask_image=mask_with_context, ref_image=reference_image_to_encoder, prompt=prompt, scale=1.0, num_inference_steps=50, width=source_with_context.size[0], height=source_with_context.size[1], seed=seed, ) index = len(os.listdir(result_save_path)) output_filename = f"result_{index}_{ref_id}_{safe_prompt}_{seed}.png" result_img = images[0] result_img_vis = result_img.crop((0, reference_context.size[1], result_img.width, result_img.height)) result_img_vis.save(os.path.join(result_save_path, output_filename)) target_size = (source_image.size[0], source_image.size[1]) vis_img = Image.new('RGB', (source_image.size[0] * 3, source_image.size[1])) vis_img.paste(source_image.resize(target_size), (0, 0)) vis_img.paste(reference_context.resize(target_size), (source_image.size[0], 0)) vis_img.paste(result_img_vis.resize(target_size), (source_image.size[0] * 2, 0)) vis_img_save_path = os.path.join(result_save_path, f'vis_{output_filename}'.replace('.png', '.jpg')) vis_img.save(vis_img_save_path) print(f"Generated images saved to {vis_img_save_path}.") if __name__ == '__main__': with open(os.path.join(os.path.dirname(__file__), 'path_dict.json'), 'r') as f: path_dict = json.load(f) # Set directory paths. test_image_dir = path_dict['data_dir'] result_save_dir = path_dict['cli_save_dir'] infer_calligrapher(test_image_dir, result_save_dir, target_h=512, target_w=512, gen_num_per_case=2) print('Finished!')