import argparse
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers import AutoencoderKL

from data_loader.loader import ContentData
from models.unet import UNetModel
from models.diffusion import Diffusion
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
from utils.util import fix_seed


class OneDMInference:
    def __init__(self, model_path, cfg_path,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device

        # Load and validate the experiment config, then fix the RNG seed
        # so sampling is reproducible.
        cfg_from_file(cfg_path)
        assert_and_infer_cfg()
        fix_seed(cfg.TRAIN.SEED)

        # Initialize the denoising UNet, the frozen VAE decoder, and the
        # diffusion sampler.
        self.unet = self._initialize_unet(model_path)
        self.vae = self._initialize_vae()
        self.diffusion = Diffusion(device=self.device)
        self.content_loader = ContentData()

        # Style and laplace reference images are loaded as
        # single-channel tensors.
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()
        ])

    def _initialize_unet(self, model_path):
        unet = UNetModel(
            in_channels=cfg.MODEL.IN_CHANNELS,
            model_channels=cfg.MODEL.EMB_DIM,
            out_channels=cfg.MODEL.OUT_CHANNELS,
            num_res_blocks=cfg.MODEL.NUM_RES_BLOCKS,
            attention_resolutions=(1, 1),
            channel_mult=(1, 1),
            num_heads=cfg.MODEL.NUM_HEADS,
            context_dim=cfg.MODEL.EMB_DIM
        ).to(self.device)

        # weights_only=True restricts torch.load to tensor data, guarding
        # against arbitrary code execution from untrusted checkpoints.
        unet.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True))
        unet.eval()
        return unet

    def _initialize_vae(self):
        # The pretrained Stable Diffusion VAE is used only as a frozen
        # decoder, so gradients are disabled.
        vae = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5", subfolder="vae")
        vae = vae.to(self.device)
        vae.requires_grad_(False)
        return vae

    def _load_image(self, image_path):
        image = Image.open(image_path)
        return self.transform(image)

    def generate(self, text, style_path, laplace_path, output_dir,
                 sample_method='ddim', sampling_timesteps=50, eta=0.0):
        """Generate handwritten text in the style of the reference images.

        Returns a list of paths to the saved images.
        """
        # Load style and laplace reference images as 1-image batches.
        style_input = self._load_image(style_path).unsqueeze(0).to(self.device)
        laplace = self._load_image(laplace_path).unsqueeze(0).to(self.device)

        # Encode the target text as a content reference.
        text_ref = self.content_loader.get_content(text)
        text_ref = text_ref.to(self.device).repeat(1, 1, 1, 1)

        # Initial latent noise: 4 channels at 1/8 of the pixel resolution,
        # matching the VAE's downsampling factor; the latent width scales
        # with the number of characters in the text reference.
        x = torch.randn((text_ref.shape[0], 4, style_input.shape[2] // 8,
                         (text_ref.shape[1] * 32) // 8)).to(self.device)

        # Sample latents with the chosen scheduler and decode via the VAE.
        if sample_method == 'ddim':
            sampled_images = self.diffusion.ddim_sample(
                self.unet, self.vae, style_input.shape[0], x,
                style_input, laplace, text_ref, sampling_timesteps, eta)
        elif sample_method == 'ddpm':
            sampled_images = self.diffusion.ddpm_sample(
                self.unet, self.vae, style_input.shape[0], x,
                style_input, laplace, text_ref)
        else:
            raise ValueError(f"Unknown sample_method: {sample_method!r}")

        # Save the generated images as grayscale PNGs.
        os.makedirs(output_dir, exist_ok=True)
        output_paths = []
        for idx, image in enumerate(sampled_images):
            pil_image = transforms.ToPILImage()(image).convert("L")
            output_path = os.path.join(output_dir, f"{text}_{idx}.png")
            pil_image.save(output_path)
            output_paths.append(output_path)
        return output_paths


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', required=True,
                        help='Path to the One-DM model checkpoint')
    parser.add_argument('--cfg_path', required=True,
                        help='Path to the config file')
    parser.add_argument('--text', required=True, help='Text to generate')
    parser.add_argument('--style_path', required=True,
                        help='Path to style image')
    parser.add_argument('--laplace_path', required=True,
                        help='Path to laplace image')
    parser.add_argument('--output_dir', required=True,
                        help='Output directory')
    args = parser.parse_args()

    model = OneDMInference(args.model_path, args.cfg_path)
    output_paths = model.generate(
        args.text, args.style_path, args.laplace_path, args.output_dir)
    print(f"Generated images saved at: {output_paths}")


if __name__ == "__main__":
    main()
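# Example invocation, for reference. The script name and all paths below are
# illustrative; substitute your own checkpoint, config, and reference images:
#
#   python inference.py \
#       --model_path ./checkpoints/One-DM.pt \
#       --cfg_path ./configs/inference.yml \
#       --text "hello world" \
#       --style_path ./data/style.png \
#       --laplace_path ./data/laplace.png \
#       --output_dir ./outputs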