"""Compute FID on MS-COCO: score 30k images generated from val2014 captions
against the real val2014 images, distributed across GPUs with accelerate."""
import json
import os
import random

import numpy as np
import torch
from PIL import Image
from accelerate import Accelerator
from omegaconf import OmegaConf
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.transforms import functional as F
from tqdm import tqdm

from app_model import AppModel
from app_utils import randomize_seed_fn
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf


class COCO_Dataset_Image(torch.utils.data.Dataset):
    """Loads images from disk and center-crops them to 256x256 for FID."""

    def __init__(self, files):
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        filename = self.files[index]
        real_image = np.array(Image.open(filename).convert('RGB'))
        real_image = torch.tensor(real_image)
        # HWC uint8 -> CHW float in [0, 1]
        real_image = real_image.permute(2, 0, 1) / 255.0
        # Resize the short side to 256, then take a 256x256 center crop
        real_image = F.resize(real_image, 256)
        real_image = F.center_crop(real_image, (256, 256))
        return real_image


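# Assumed contract for preprocess_fn (AppModel.kosmosg_preprocess as used in
# main): (prompt, negative_prompt, image, single_batch) ->
# (src_tokens, _, img_gpt_input_mask, negative_tokens).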
class COCO_Dataset_Caption(torch.utils.data.Dataset):
    """Tokenized captions for a fixed random 30k subset of COCO val2014."""

    def __init__(self, args, preprocess_fn):
        self.args = args
        self.preprocess_fn = preprocess_fn

        with open(os.path.join(args.data_dir, 'annotations', 'captions_val2014.json'), 'r') as f:
            self.coco = json.load(f)
        self.files = self.coco['annotations']

        # Seed before sampling so every process draws the same 30k captions
        random.seed(args.seed)
        self.files = random.sample(self.files, 30000)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        prompt = self.files[index]['caption']

        # NOTE: both branches below are the empty string, so args.negative_prompt
        # currently has no effect; put a real negative prompt in the first branch
        # to enable it.
        src_tokens, _, img_gpt_input_mask, negative_tokens = \
            self.preprocess_fn(prompt,
                               "" if self.args.negative_prompt else "",
                               None, single_batch=False)

        return src_tokens, img_gpt_input_mask, negative_tokens


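# Collation sketch: for a batch of N caption samples, the DataLoader yields
#   src_tokens         [N, T_max]  right-padded with 1 (the fairseq pad index)
#   img_gpt_input_mask [N, T_max]  right-padded with 0
#   negative_tokens    [1, T_neg]  a single copy shared by the whole batch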
def collate_fn(batch):
    src_tokens = [x[0] for x in batch]
    img_gpt_input_mask = [x[1] for x in batch]
    # The negative prompt is the same for every sample, so keep one copy
    negative_tokens = batch[0][2].unsqueeze(0)
    src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=1)
    img_gpt_input_mask = pad_sequence(img_gpt_input_mask, batch_first=True, padding_value=0)

    return src_tokens, img_gpt_input_mask, negative_tokens


def main(cfg):
    cfg.model.pretrained_ckpt_path = "/path/to/checkpoint_final.pt"

    # Evaluation settings
    args = OmegaConf.create()
    args.data_dir = "/path/to/coco"
    args.batch_size = 16
    args.num_workers = 4
    args.scheduler = "ddim"
    args.num_inference_steps = 250
    args.guidance_scale = 3.0
    args.num_images_per_prompt = 1
    args.seed = 0
    args.negative_prompt = False
    args.override = False
    # Encode the checkpoint name and sampler settings into the output directory
    args.output_dir = "/path/to/output-dir/" + cfg.model.pretrained_ckpt_path.split('/')[-2] + '_' + \
        cfg.model.pretrained_ckpt_path.split('/')[-1].split('.')[0].split('_')[-1] + '_' + args.scheduler + \
        '_' + str(args.num_inference_steps) + '_' + str(args.negative_prompt)

    accelerator = Accelerator()
    if accelerator.is_main_process and not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Accumulate Inception statistics over the real val2014 images
    fid = FrechetInceptionDistance(normalize=True)
    fid = accelerator.prepare_model(fid, evaluation_mode=True)
    with open(os.path.join(args.data_dir, 'annotations', 'captions_val2014.json'), 'r') as f:
        files = json.load(f)['images']
    files = [os.path.join(args.data_dir, 'val2014', file['file_name']) for file in files]
    image_dataset = COCO_Dataset_Image(files)
    image_dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=16, num_workers=args.num_workers,
                                                   shuffle=False, pin_memory=True, drop_last=False,
                                                   persistent_workers=True)
    image_dataloader = accelerator.prepare(image_dataloader)
    accelerator.print("Number of real images: ", len(image_dataset))

    for batch in tqdm(image_dataloader):
        fid.update(batch, real=True)

    # If a previous run already produced 30k generated images, score them directly
    image_paths = list()
    for root, dirs, files in os.walk(args.output_dir):
        for file in files:
            if file.endswith(".png"):
                image_paths.append(os.path.join(root, file))
    if len(image_paths) >= 30000 and not args.override:
        accelerator.print("Already generated enough images")
        image_dataset = COCO_Dataset_Image(image_paths)
        image_dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=128, num_workers=args.num_workers,
                                                       shuffle=False, pin_memory=True, drop_last=False,
                                                       persistent_workers=True)
        image_dataloader = accelerator.prepare(image_dataloader)
        accelerator.print("Number of fake images: ", len(image_dataset))

        for batch in tqdm(image_dataloader):
            fid.update(batch, real=False)
        accelerator.print("FID: ", fid.compute())
        return
    else:
        # Stale or partial output: remove leftover images before regenerating
        if accelerator.is_main_process:
            for root, dirs, files in os.walk(args.output_dir):
                for file in files:
                    if file.endswith(".png"):
                        os.remove(os.path.join(root, file))

    model = AppModel(cfg)
    model.set_ckpt_scheduler_fn(cfg.model.pretrained_ckpt_path, args.scheduler)

    caption_dataset = COCO_Dataset_Caption(args, model.kosmosg_preprocess)
    caption_dataloader = torch.utils.data.DataLoader(caption_dataset, batch_size=args.batch_size,
                                                     num_workers=args.num_workers, shuffle=False, pin_memory=True,
                                                     drop_last=False, persistent_workers=True, collate_fn=collate_fn)
    accelerator.print("Number of prompts: ", len(caption_dataset))

    model, caption_dataloader = accelerator.prepare(model, caption_dataloader)

    kwargs = {
        'num_inference_steps': args.num_inference_steps,
        'text_guidance_scale': args.guidance_scale,
        'num_images_per_prompt': args.num_images_per_prompt,
        'lora_scale': 0.0,
        'output_type': 'numpy'
    }

    for batch_id, batch in tqdm(enumerate(caption_dataloader), total=len(caption_dataloader)):
        src_tokens, img_gpt_input_mask, negative_tokens = batch

        # Fixed seed per batch keeps generation reproducible across runs
        randomize_seed_fn(args.seed, False)
        images = model.model.sample(src_tokens, None, img_gpt_input_mask, negative_tokens, **kwargs)

        for image_id, image in enumerate(images):
            # Interleave filenames across processes so ranks never collide
            pos = batch_id * accelerator.num_processes * args.batch_size * args.num_images_per_prompt + \
                  image_id * accelerator.num_processes + accelerator.process_index
            model.model.vae.numpy_to_pil(image)[0].save(os.path.join(args.output_dir, "{:05d}.png".format(pos)))

        # Feed the generated batch to FID as NCHW float tensors
        images = np.stack(images, axis=0)
        images = torch.tensor(images).to(accelerator.device)
        images = images.permute(0, 3, 1, 2)
        fid.update(images, real=False)

    # Feature counts are per process; multiply by world size for global totals
    accelerator.print("Number of Real Images: ", (fid.real_features_num_samples * accelerator.num_processes).item())
    accelerator.print("Number of Fake Images: ", (fid.fake_features_num_samples * accelerator.num_processes).item())
    accelerator.print("FID: ", fid.compute())


if __name__ == "__main__":
    parser = options.get_training_parser()
    cfg = options.parse_args_and_arch(parser, modify_parser=None)
    cfg = convert_namespace_to_omegaconf(cfg)
    main(cfg)
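# Example launch (hypothetical script name and paths; the fairseq training
# parser expects the same positional data dir and task flags used in training):
#   accelerate launch eval_fid_coco.py /path/to/data --task <your-task>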