Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import torch | |
from data_loader.loader import ContentData | |
from models.unet import UNetModel | |
from diffusers import AutoencoderKL | |
from models.diffusion import Diffusion | |
import torchvision | |
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg | |
from utils.util import fix_seed | |
from PIL import Image | |
import torchvision.transforms as transforms | |
class OneDMInference:
    """Run One-DM handwriting generation from a trained checkpoint.

    The constructor loads the project config, builds the UNet and the VAE, and
    creates the diffusion sampler and the content loader. ``generate`` then
    samples image(s) for a text string, conditioned on a style image and a
    laplace image, and writes them to disk as grayscale PNG files.
    """

    # Sampler names understood by ``generate``.
    SAMPLE_METHODS = ('ddim', 'ddpm')

    def __init__(self, model_path, cfg_path,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 vae_model_id="runwayml/stable-diffusion-v1-5"):
        """Load the config and instantiate every component used for sampling.

        Args:
            model_path: Path to the UNet checkpoint (a state dict readable by
                ``torch.load``).
            cfg_path: Path to the config file handed to ``cfg_from_file``.
            device: Torch device string. The default is computed once, when
                the class is defined: ``'cuda'`` if CUDA is available,
                otherwise ``'cpu'``.
            vae_model_id: Model id or local directory whose ``vae`` subfolder
                holds the autoencoder weights.
        """
        self.device = device

        # The config has to be loaded before anything below reads ``cfg``.
        cfg_from_file(cfg_path)
        assert_and_infer_cfg()
        fix_seed(cfg.TRAIN.SEED)

        self.unet = self._initialize_unet(model_path)
        self.vae = self._initialize_vae(vae_model_id)
        self.diffusion = Diffusion(device=self.device)
        self.content_loader = ContentData()

        # Style/laplace images are reduced to one grayscale channel and then
        # converted to a tensor.
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    def _initialize_unet(self, model_path):
        """Build the UNet from ``cfg.MODEL`` and load its weights in eval mode.

        Args:
            model_path: Path to the checkpoint containing the UNet state dict.

        Returns:
            The UNet, moved to ``self.device`` and switched to eval mode.
        """
        unet = UNetModel(
            in_channels=cfg.MODEL.IN_CHANNELS,
            model_channels=cfg.MODEL.EMB_DIM,
            out_channels=cfg.MODEL.OUT_CHANNELS,
            num_res_blocks=cfg.MODEL.NUM_RES_BLOCKS,
            attention_resolutions=(1, 1),
            channel_mult=(1, 1),
            num_heads=cfg.MODEL.NUM_HEADS,
            context_dim=cfg.MODEL.EMB_DIM,
        ).to(self.device)

        # weights_only=True restricts unpickling to tensors/primitive
        # containers, so an untrusted checkpoint cannot execute arbitrary code.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        unet.load_state_dict(state_dict)
        unet.eval()
        return unet

    def _initialize_vae(self, vae_model_id="runwayml/stable-diffusion-v1-5"):
        """Load the pretrained VAE and freeze its parameters.

        Args:
            vae_model_id: Model id or local directory whose ``vae`` subfolder
                is loaded.

        Returns:
            The autoencoder on ``self.device`` with gradients disabled.
        """
        vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder="vae")
        vae = vae.to(self.device)
        vae.requires_grad_(False)
        return vae

    def _load_image(self, image_path):
        """Read an image file and return it as a grayscale tensor.

        Args:
            image_path: Path of the image to read.

        Returns:
            The result of applying ``self.transform`` to the image.
        """
        # The context manager closes the underlying file handle; the
        # transform decodes the pixel data while the file is still open.
        with Image.open(image_path) as image:
            return self.transform(image)

    @staticmethod
    def _safe_stem(text):
        """Return ``text`` with path separators replaced by underscores.

        Keeps the generated file name a single path component, so text that
        contains ``/`` or ``\\`` cannot point into another directory.
        """
        return text.replace('/', '_').replace('\\', '_')

    def generate(self, text, style_path, laplace_path, output_dir,
                 sample_method='ddim', sampling_timesteps=50, eta=0.0):
        """Generate handwritten text with the specified style.

        Args:
            text: Text to render; also used as the stem of the output file
                names (path separators are replaced by ``_``).
            style_path: Path to the style reference image.
            laplace_path: Path to the laplace image.
            output_dir: Directory for the PNG files; created if missing.
            sample_method: ``'ddim'`` or ``'ddpm'``.
            sampling_timesteps: Step count forwarded to the DDIM sampler
                (not used for ``'ddpm'``).
            eta: ``eta`` value forwarded to the DDIM sampler (not used for
                ``'ddpm'``).

        Returns:
            List of paths of the images written, one per sampled image.

        Raises:
            ValueError: If ``sample_method`` is not a supported sampler name.
        """
        # Fail fast, before any image loading or sampling, instead of hitting
        # an unbound ``sampled_images`` further down.
        if sample_method not in self.SAMPLE_METHODS:
            raise ValueError(
                f"Unknown sample_method {sample_method!r}; "
                f"expected one of {self.SAMPLE_METHODS}"
            )

        # Load style and laplace images; unsqueeze adds the batch dimension.
        style_input = self._load_image(style_path).unsqueeze(0).to(self.device)
        laplace = self._load_image(laplace_path).unsqueeze(0).to(self.device)

        # Prepare text reference.
        text_ref = self.content_loader.get_content(text)
        text_ref = text_ref.to(self.device).repeat(1, 1, 1, 1)

        # Initial noise: 4 channels, style-image height // 8, and a width
        # derived from text_ref.shape[1].
        # NOTE(review): presumably shape[1] is the character count at 32 px
        # per character, and 4 / 8 are the VAE latent channels / downscale
        # factor -- confirm against ContentData and the VAE config.
        # Drawn exactly as before (default device, then moved) so seeded
        # outputs stay unchanged.
        x = torch.randn((text_ref.shape[0], 4, style_input.shape[2] // 8,
                         (text_ref.shape[1] * 32) // 8)).to(self.device)

        # Inference only: make sure no autograd graph is built while sampling
        # (harmless if the sampler already disables gradients itself).
        with torch.no_grad():
            if sample_method == 'ddim':
                sampled_images = self.diffusion.ddim_sample(
                    self.unet, self.vae, style_input.shape[0],
                    x, style_input, laplace, text_ref,
                    sampling_timesteps, eta
                )
            else:  # 'ddpm' -- guaranteed by the validation above
                sampled_images = self.diffusion.ddpm_sample(
                    self.unet, self.vae, style_input.shape[0],
                    x, style_input, laplace, text_ref
                )

        # Save generated images as single-channel PNGs.
        os.makedirs(output_dir, exist_ok=True)
        to_pil = torchvision.transforms.ToPILImage()
        stem = self._safe_stem(text)
        output_paths = []
        for idx, image in enumerate(sampled_images):
            output_path = os.path.join(output_dir, f"{stem}_{idx}.png")
            to_pil(image).convert("L").save(output_path)
            output_paths.append(output_path)
        return output_paths
def main(argv=None):
    """Command-line entry point: parse arguments, generate, print the paths.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', required=True, help='Path to the One-DM model checkpoint')
    parser.add_argument('--cfg_path', required=True, help='Path to the config file')
    parser.add_argument('--text', required=True, help='Text to generate')
    parser.add_argument('--style_path', required=True, help='Path to style image')
    parser.add_argument('--laplace_path', required=True, help='Path to laplace image')
    parser.add_argument('--output_dir', required=True, help='Output directory')
    # Optional knobs that generate() already supports; the defaults mirror
    # generate()'s own defaults so existing invocations behave the same.
    parser.add_argument('--sample_method', default='ddim', choices=['ddim', 'ddpm'],
                        help='Sampler to use (default: ddim)')
    parser.add_argument('--sampling_timesteps', type=int, default=50,
                        help='Step count passed to the DDIM sampler (default: 50)')
    parser.add_argument('--eta', type=float, default=0.0,
                        help='eta passed to the DDIM sampler (default: 0.0)')
    parser.add_argument('--device', default=None,
                        help='Torch device string; omit to use the class default')
    args = parser.parse_args(argv)

    # Only forward --device when given, so the class keeps choosing its own
    # default (cuda if available, else cpu) otherwise.
    if args.device is None:
        model = OneDMInference(args.model_path, args.cfg_path)
    else:
        model = OneDMInference(args.model_path, args.cfg_path, device=args.device)

    output_paths = model.generate(
        args.text,
        args.style_path,
        args.laplace_path,
        args.output_dir,
        sample_method=args.sample_method,
        sampling_timesteps=args.sampling_timesteps,
        eta=args.eta,
    )
    print(f"Generated images saved at: {output_paths}")


if __name__ == "__main__":
    main()