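"""Standalone inference script for One-DM handwritten text generation.

Loads a trained One-DM UNet checkpoint together with the Stable Diffusion VAE
and samples handwriting images conditioned on a target string, a style
reference image, and its Laplacian edge map.

Example invocation (file and path names are illustrative, not shipped with the repo):

    python one_dm_inference.py \
        --model_path ckpt/One-DM-ckpt.pt \
        --cfg_path configs/One-DM.yml \
        --text "hello" \
        --style_path samples/style.png \
        --laplace_path samples/laplace.png \
        --output_dir outputs
"""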
import argparse
import os
import torch
from data_loader.loader import ContentData
from models.unet import UNetModel
from diffusers import AutoencoderKL
from models.diffusion import Diffusion
import torchvision
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
from utils.util import fix_seed
from PIL import Image
import torchvision.transforms as transforms

class OneDMInference:
    """Wraps a trained One-DM UNet and the Stable Diffusion VAE for handwriting generation."""

    def __init__(self, model_path, cfg_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        
        # Load config
        cfg_from_file(cfg_path)
        assert_and_infer_cfg()
        fix_seed(cfg.TRAIN.SEED)
        
        # Initialize models
        self.unet = self._initialize_unet(model_path)
        self.vae = self._initialize_vae()
        self.diffusion = Diffusion(device=self.device)
        self.content_loader = ContentData()
        
        # Define transform
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()
        ])
    
    def _initialize_unet(self, model_path):
        unet = UNetModel(
            in_channels=cfg.MODEL.IN_CHANNELS,
            model_channels=cfg.MODEL.EMB_DIM,
            out_channels=cfg.MODEL.OUT_CHANNELS,
            num_res_blocks=cfg.MODEL.NUM_RES_BLOCKS,
            attention_resolutions=(1,1),
            channel_mult=(1, 1),
            num_heads=cfg.MODEL.NUM_HEADS,
            context_dim=cfg.MODEL.EMB_DIM
        ).to(self.device)
        
        # Load checkpoint weights; weights_only=True avoids unpickling arbitrary objects
        unet.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
        unet.eval()
        return unet
    
    def _initialize_vae(self):
        vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
        vae = vae.to(self.device)
        vae.requires_grad_(False)
        return vae

    def _load_image(self, image_path):
        image = Image.open(image_path)
        image_tensor = self.transform(image)
        return image_tensor
    
    def generate(self, text, style_path, laplace_path, output_dir,
                 sample_method='ddim', sampling_timesteps=50, eta=0.0):
        """Generate handwritten images of `text` in the given style.

        Returns a list of file paths to the saved PNG images.
        """
        # Load style and laplace images
        style_input = self._load_image(style_path).unsqueeze(0).to(self.device)
        laplace = self._load_image(laplace_path).unsqueeze(0).to(self.device)
        
        # Prepare text reference
        text_ref = self.content_loader.get_content(text)
        text_ref = text_ref.to(self.device).repeat(1, 1, 1, 1)
        
        # Initialize latent noise: 4-channel VAE latent at 1/8 the pixel
        # resolution; latent width scales with the number of characters
        x = torch.randn((text_ref.shape[0], 4, style_input.shape[2]//8,
                        (text_ref.shape[1]*32)//8)).to(self.device)
        
        # Generate image
        if sample_method == 'ddim':
            sampled_images = self.diffusion.ddim_sample(
                self.unet, self.vae, style_input.shape[0],
                x, style_input, laplace, text_ref,
                sampling_timesteps, eta
            )
        elif sample_method == 'ddpm':
            sampled_images = self.diffusion.ddpm_sample(
                self.unet, self.vae, style_input.shape[0],
                x, style_input, laplace, text_ref
            )
        else:
            raise ValueError(f"Unsupported sample_method: {sample_method!r}; use 'ddim' or 'ddpm'")
        
        # Save generated images as grayscale PNGs
        os.makedirs(output_dir, exist_ok=True)
        output_paths = []
        for idx, image in enumerate(sampled_images):
            pil_image = torchvision.transforms.ToPILImage()(image).convert("L")
            output_path = os.path.join(output_dir, f"{text}_{idx}.png")
            pil_image.save(output_path)
            output_paths.append(output_path)
            
        return output_paths

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', required=True, help='Path to the One-DM model checkpoint')
    parser.add_argument('--cfg_path', required=True, help='Path to the config file')
    parser.add_argument('--text', required=True, help='Text to generate')
    parser.add_argument('--style_path', required=True, help='Path to style image')
    parser.add_argument('--laplace_path', required=True, help='Path to laplace image')
    parser.add_argument('--output_dir', required=True, help='Output directory')
    args = parser.parse_args()
    
    model = OneDMInference(args.model_path, args.cfg_path)
    output_paths = model.generate(
        args.text,
        args.style_path,
        args.laplace_path,
        args.output_dir
    )
    print(f"Generated images saved at: {output_paths}")

if __name__ == "__main__":
    main()
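
# Programmatic use is also possible; a minimal sketch (paths are illustrative):
#
#   model = OneDMInference("ckpt/One-DM-ckpt.pt", "configs/One-DM.yml")
#   paths = model.generate(
#       text="handwriting",
#       style_path="samples/style.png",
#       laplace_path="samples/laplace.png",
#       output_dir="outputs",
#       sample_method="ddim",
#       sampling_timesteps=50,
#   )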