Kishan11 commited on
Commit
bf15361
·
verified ·
1 Parent(s): 2f22cf7

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +125 -0
inference.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from data_loader.loader import ContentData
5
+ from models.unet import UNetModel
6
+ from diffusers import AutoencoderKL
7
+ from models.diffusion import Diffusion
8
+ import torchvision
9
+ from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
10
+ from utils.util import fix_seed
11
+ from PIL import Image
12
+ import torchvision.transforms as transforms
13
+
14
+ class OneDMInference:
15
+ def __init__(self, model_path, cfg_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
16
+ self.device = device
17
+
18
+ # Load config
19
+ cfg_from_file(cfg_path)
20
+ assert_and_infer_cfg()
21
+ fix_seed(cfg.TRAIN.SEED)
22
+
23
+ # Initialize models
24
+ self.unet = self._initialize_unet(model_path)
25
+ self.vae = self._initialize_vae()
26
+ self.diffusion = Diffusion(device=self.device)
27
+ self.content_loader = ContentData()
28
+
29
+ # Define transform
30
+ self.transform = transforms.Compose([
31
+ transforms.Grayscale(),
32
+ transforms.ToTensor()
33
+ ])
34
+
35
+ def _initialize_unet(self, model_path):
36
+ unet = UNetModel(
37
+ in_channels=cfg.MODEL.IN_CHANNELS,
38
+ model_channels=cfg.MODEL.EMB_DIM,
39
+ out_channels=cfg.MODEL.OUT_CHANNELS,
40
+ num_res_blocks=cfg.MODEL.NUM_RES_BLOCKS,
41
+ attention_resolutions=(1,1),
42
+ channel_mult=(1, 1),
43
+ num_heads=cfg.MODEL.NUM_HEADS,
44
+ context_dim=cfg.MODEL.EMB_DIM
45
+ ).to(self.device)
46
+
47
+ # Load model with weights_only=True
48
+ unet.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
49
+ unet.eval()
50
+ return unet
51
+
52
+ def _initialize_vae(self):
53
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
54
+ vae = vae.to(self.device)
55
+ vae.requires_grad_(False)
56
+ return vae
57
+
58
+ def _load_image(self, image_path):
59
+ image = Image.open(image_path)
60
+ image_tensor = self.transform(image)
61
+ return image_tensor
62
+
63
+ def generate(self, text, style_path, laplace_path, output_dir,
64
+ sample_method='ddim', sampling_timesteps=50, eta=0.0):
65
+ """
66
+ Generate handwritten text with the specified style
67
+ """
68
+ # Load style and laplace images
69
+ style_input = self._load_image(style_path).unsqueeze(0).to(self.device)
70
+ laplace = self._load_image(laplace_path).unsqueeze(0).to(self.device)
71
+
72
+ # Prepare text reference
73
+ text_ref = self.content_loader.get_content(text)
74
+ text_ref = text_ref.to(self.device).repeat(1, 1, 1, 1)
75
+
76
+ # Initialize noise
77
+ x = torch.randn((text_ref.shape[0], 4, style_input.shape[2]//8,
78
+ (text_ref.shape[1]*32)//8)).to(self.device)
79
+
80
+ # Generate image
81
+ if sample_method == 'ddim':
82
+ sampled_images = self.diffusion.ddim_sample(
83
+ self.unet, self.vae, style_input.shape[0],
84
+ x, style_input, laplace, text_ref,
85
+ sampling_timesteps, eta
86
+ )
87
+ elif sample_method == 'ddpm':
88
+ sampled_images = self.diffusion.ddpm_sample(
89
+ self.unet, self.vae, style_input.shape[0],
90
+ x, style_input, laplace, text_ref
91
+ )
92
+
93
+ # Save generated image
94
+ os.makedirs(output_dir, exist_ok=True)
95
+ output_paths = []
96
+ for idx, image in enumerate(sampled_images):
97
+ im = torchvision.transforms.ToPILImage()(image)
98
+ image = im.convert("L")
99
+ output_path = os.path.join(output_dir, f"{text}_{idx}.png")
100
+ image.save(output_path)
101
+ output_paths.append(output_path)
102
+
103
+ return output_paths
104
+
105
+ def main():
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument('--model_path', required=True, help='Path to the One-DM model checkpoint')
108
+ parser.add_argument('--cfg_path', required=True, help='Path to the config file')
109
+ parser.add_argument('--text', required=True, help='Text to generate')
110
+ parser.add_argument('--style_path', required=True, help='Path to style image')
111
+ parser.add_argument('--laplace_path', required=True, help='Path to laplace image')
112
+ parser.add_argument('--output_dir', required=True, help='Output directory')
113
+ args = parser.parse_args()
114
+
115
+ model = OneDMInference(args.model_path, args.cfg_path)
116
+ output_paths = model.generate(
117
+ args.text,
118
+ args.style_path,
119
+ args.laplace_path,
120
+ args.output_dir
121
+ )
122
+ print(f"Generated images saved at: {output_paths}")
123
+
124
+ if __name__ == "__main__":
125
+ main()