import datetime import os import os.path as osp import random import subprocess from functools import partial from typing import Optional import time import pytz from infinity.dataset.webdataset import WDSEditDataset try: from grp import getgrgid from pwd import getpwuid except: pass import PIL.Image as PImage from PIL import ImageFile import numpy as np from torchvision.transforms import transforms from torchvision.transforms.functional import resize, to_tensor import torch.distributed as tdist from torchvision.transforms import InterpolationMode bicubic = InterpolationMode.BICUBIC lanczos = InterpolationMode.LANCZOS PImage.MAX_IMAGE_PIXELS = (1024 * 1024 * 1024 // 4 // 3) * 5 ImageFile.LOAD_TRUNCATED_IMAGES = False def time_str(fmt='[%m-%d %H:%M:%S]'): return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt) def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1 return x.add(x).add_(-1) def denormalize_pm1_into_01(x): # denormalize x from [-1, 1] to [0, 1] return x.add(1).mul_(0.5) def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize( tuple(x // 2 for x in pil_image.size), resample=PImage.BOX ) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize( tuple(round(x * scale) for x in pil_image.size), resample=PImage.LANCZOS ) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return PImage.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) class RandomResize: def __init__(self, mid_reso, final_reso, interpolation): ub = max(round((mid_reso + (mid_reso-final_reso) / 8) / 4) * 4, mid_reso) self.reso_lb, self.reso_ub = final_reso, ub self.interpolation = interpolation def __call__(self, img): return resize(img, size=random.randint(self.reso_lb, self.reso_ub), interpolation=self.interpolation) def __repr__(self): return f'RandomResize(reso=({self.reso_lb}, {self.reso_ub}), interpolation={self.interpolation})' def load_save(reso=512): import os from PIL import Image as PImage from torchvision.transforms import transforms, InterpolationMode aug = transforms.Compose([ transforms.Resize(512, interpolation=InterpolationMode.LANCZOS), transforms.CenterCrop((512, 512)) ]) src_folder = r'C:\Users\16333\Pictures\imgs_to_visual_v2' ls = [os.path.join(src_folder, x) for x in ('1.jpg', '2.jpg', '3.png', '4.png', '5.png')] print(ls) imgs = [] for i, fname in enumerate(ls): assert os.path.exists(fname) with PImage.open(fname) as img: img = img.convert('RGB') img = aug(img) imgs.append(img) dst_d, dst_f = os.path.split(fname) dst = os.path.join(dst_d, f'crop{dst_f.replace(".jpg", ".png")}') img.save(dst) W, H = imgs[0].size WW = W * len(imgs) new_im = PImage.new('RGB', (WW, H)) x_offset = 0 for img in imgs: new_im.paste(img, (x_offset, 0)) x_offset += W dst = os.path.join(src_folder, f'junfeng.png') new_im.save(dst) def print_aug(transform, label): print(f'Transform {label} = ') if hasattr(transform, 'transforms'): for t in transform.transforms: print(t) else: print(transform) print('---------------------------\n') def build_t2i_dataset( args, data_path: str, data_load_reso: int, max_caption_len: int, short_prob=0.2, load_vae_instead_of_image=False ): if args.use_streaming_dataset: # return T2IIterableDataset( # data_path, # max_caption_len=max_caption_len, # short_prob=short_prob, # load_vae_instead_of_image=load_vae_instead_of_image, # buffersize=args.iterable_data_buffersize, # pn=args.pn, # online_t5=args.online_t5, # batch_size=args.batch_size, # num_replicas=tdist.get_world_size(), # 1, # rank=tdist.get_rank(), # 0 # dataloader_workers=args.workers, # dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus, # enable_dynamic_length_prompt=args.enable_dynamic_length_prompt, # seed=args.seed if args.seed is not None else int(time.time()), # ) return WDSEditDataset( data_path, buffersize=args.iterable_data_buffersize, pn=args.pn, batch_size=args.batch_size, num_replicas=tdist.get_world_size(), # 1, rank=tdist.get_rank(), # 0 # dataloader_workers=args.workers, # dynamic_resolution_across_gpus=args.dynamic_resolution_across_gpus, # enable_dynamic_length_prompt=args.enable_dynamic_length_prompt, # seed=args.seed if args.seed is not None else int(time.time()), ) else: raise ValueError(f'args.use_streaming_dataset={args.use_streaming_dataset} unsupported') def pil_load(path: str, proposal_size): with open(path, 'rb') as f: img: PImage.Image = PImage.open(f) w: int = img.width h: int = img.height sh: int = min(h, w) if sh > proposal_size: ratio: float = proposal_size / sh w = round(ratio * w) h = round(ratio * h) img.draft('RGB', (w, h)) img = img.convert('RGB') return img def rewrite(im: PImage, file: str, info: str): kw = dict(quality=100) if file.lower().endswith('.tif') or file.lower().endswith('.tiff'): kw['compression'] = 'none' elif file.lower().endswith('.webp'): kw['lossless'] = True st = os.stat(file) uname = getpwuid(st.st_uid).pw_name gname = getgrgid(st.st_gid).gr_name mode = oct(st.st_mode)[-3:] local_file = osp.basename(file) im.save(local_file, **kw) print(f'************* ************* @ {file}') subprocess.call(f'sudo mv {local_file} {file}; sudo chown {uname}:{gname} {file}; sudo chmod {mode} {file}', shell=True)