File size: 6,686 Bytes
aa4fdd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import datetime
import os
import os.path as osp
import random
import subprocess
from functools import partial
from typing import Optional
import time

import pytz
from infinity.dataset.webdataset import WDSEditDataset

try:
    # grp/pwd exist only on POSIX systems; rewrite() below needs them to read
    # a file's owner and group names. Elsewhere (e.g. Windows) the module still imports.
    from grp import getgrgid
    from pwd import getpwuid
except ImportError:
    pass
import PIL.Image as PImage
from PIL import ImageFile
import numpy as np
from torchvision.transforms import transforms
from torchvision.transforms.functional import resize, to_tensor
import torch.distributed as tdist

from torchvision.transforms import InterpolationMode
# Short aliases for the torchvision interpolation modes.
bicubic, lanczos = InterpolationMode.BICUBIC, InterpolationMode.LANCZOS
# 1024**3 // 4 // 3 is the expression Pillow uses for its default decompression-bomb
# limit; allow five times that many pixels so very large source images still open.
PImage.MAX_IMAGE_PIXELS = 5 * (1024 ** 3 // 4 // 3)
# Stay strict: truncated image files raise an error instead of loading partially.
ImageFile.LOAD_TRUNCATED_IMAGES = False


def time_str(fmt='[%m-%d %H:%M:%S]'):
    """Return the current wall-clock time in the Asia/Shanghai timezone, formatted with *fmt*."""
    shanghai = pytz.timezone('Asia/Shanghai')
    now = datetime.datetime.now(tz=shanghai)
    return now.strftime(fmt)


def normalize_01_into_pm1(x):
    """Map values from [0, 1] to [-1, 1] as ``2*x - 1``.

    The input is not modified: ``x.add(x)`` allocates a new tensor, and only that
    new tensor is shifted in place.
    """
    doubled = x.add(x)
    return doubled.add_(-1)


def denormalize_pm1_into_01(x):
    """Map values from [-1, 1] back to [0, 1] as ``(x + 1) * 0.5``.

    The input is not modified: ``x.add(1)`` allocates a new tensor, which is then
    scaled in place.
    """
    shifted = x.add(1)
    return shifted.mul_(0.5)


def center_crop_arr(pil_image, image_size):
    """Resize *pil_image* and cut out its central ``image_size x image_size`` square.

    Center cropping implementation from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Steps:
      1. While the shorter side is at least ``2 * image_size``, halve both sides with
         BOX filtering (a cheap anti-aliased downscale).
      2. LANCZOS-resize so the shorter side becomes ``image_size`` (sides are rounded).
      3. Slice the centered square out of the pixel array and return it as a PIL image.
    """
    img = pil_image
    while min(img.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=PImage.BOX)

    ratio = image_size / min(img.size)
    target = tuple(round(side * ratio) for side in img.size)
    img = img.resize(target, resample=PImage.LANCZOS)

    pixels = np.array(img)
    # the numpy array is indexed (rows, cols), i.e. (height, width)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    square = pixels[top:top + image_size, left:left + image_size]
    return PImage.fromarray(square)


class RandomResize:
    """Resize an image to a size drawn uniformly from ``[reso_lb, reso_ub]`` (inclusive).

    ``reso_lb`` is ``final_reso``. ``reso_ub`` is ``mid_reso`` pushed up by one eighth
    of ``mid_reso - final_reso``, rounded to a multiple of 4, and never less than
    ``mid_reso``. The drawn integer is passed as ``size`` to torchvision's ``resize``,
    which matches the image's shorter edge to it.

    Args:
        mid_reso: resolution the upper bound is derived from.
        final_reso: lower bound of the sampled size.
        interpolation: interpolation mode forwarded to ``resize``.
    """

    def __init__(self, mid_reso, final_reso, interpolation):
        stretched = mid_reso + (mid_reso - final_reso) / 8
        upper = round(stretched / 4) * 4
        self.reso_lb = final_reso
        # same as max(upper, mid_reso): on a tie, max() keeps its first argument
        self.reso_ub = upper if upper >= mid_reso else mid_reso
        self.interpolation = interpolation

    def __call__(self, img):
        side = random.randint(self.reso_lb, self.reso_ub)  # both ends inclusive
        return resize(img, size=side, interpolation=self.interpolation)

    def __repr__(self):
        return f'RandomResize(reso=({self.reso_lb}, {self.reso_ub}), interpolation={self.interpolation})'


def load_save(reso=512, src_folder=r'C:\Users\16333\Pictures\imgs_to_visual_v2'):
    """Center-crop a fixed set of images, save each crop, and save a side-by-side strip.

    Each of ``1.jpg, 2.jpg, 3.png, 4.png, 5.png`` in *src_folder* is processed as follows:
      - convert to RGB;
      - LANCZOS-resize so its shorter side is *reso*;
      - center-crop to ``reso x reso``;
      - save next to the source as ``crop<name>``, with ``.jpg`` replaced by ``.png``.
    All crops are then pasted left to right into ``junfeng.png`` in *src_folder*.

    Args:
        reso: side length of the square crops. This was previously ignored (512 was
            hard-coded); the default keeps the old output.
        src_folder: directory holding the input images and receiving the outputs.

    Raises:
        FileNotFoundError: if an expected input image is missing.
    """
    import os
    from PIL import Image as PImage
    from torchvision.transforms import transforms, InterpolationMode
    aug = transforms.Compose([
        transforms.Resize(reso, interpolation=InterpolationMode.LANCZOS),
        transforms.CenterCrop((reso, reso)),
    ])
    ls = [os.path.join(src_folder, x) for x in ('1.jpg', '2.jpg', '3.png', '4.png', '5.png')]
    print(ls)
    imgs = []
    for fname in ls:
        # explicit check: `assert` is stripped under `python -O`
        if not os.path.exists(fname):
            raise FileNotFoundError(fname)
        with PImage.open(fname) as img:
            # convert() returns a new image, so the crop stays usable after the file closes
            img = aug(img.convert('RGB'))
        imgs.append(img)
        dst_d, dst_f = os.path.split(fname)
        dst = os.path.join(dst_d, f'crop{dst_f.replace(".jpg", ".png")}')
        img.save(dst)

    # every crop is reso x reso, so tiles sit at multiples of the first crop's width
    W, H = imgs[0].size
    new_im = PImage.new('RGB', (W * len(imgs), H))
    for idx, img in enumerate(imgs):
        new_im.paste(img, (idx * W, 0))
    new_im.save(os.path.join(src_folder, 'junfeng.png'))


def print_aug(transform, label):
    """Print *transform* under a ``Transform <label> =`` header, then a dashed separator.

    If the object has a ``transforms`` attribute (a composite such as torchvision's
    ``Compose``), each element is printed on its own line. Otherwise the object itself
    is printed.
    """
    print(f'Transform {label} = ')
    components = transform.transforms if hasattr(transform, 'transforms') else [transform]
    for component in components:
        print(component)
    print('---------------------------\n')


def build_t2i_dataset(
    args,
    data_path: str,
    data_load_reso: int,
    max_caption_len: int,
    short_prob=0.2,
    load_vae_instead_of_image=False,
):
    """Build the training dataset. Only the streaming (``WDSEditDataset``) path is supported.

    Args:
        args: config namespace. Reads ``use_streaming_dataset``,
            ``iterable_data_buffersize``, ``pn`` and ``batch_size``.
        data_path: forwarded as the first positional argument of ``WDSEditDataset``.
        data_load_reso, max_caption_len, short_prob, load_vae_instead_of_image:
            accepted for call-site compatibility; not used by this implementation.

    Returns:
        A ``WDSEditDataset`` given the current ``torch.distributed`` world size as
        ``num_replicas`` and this process's rank as ``rank``.

    Raises:
        ValueError: if ``args.use_streaming_dataset`` is falsy.
    """
    if not args.use_streaming_dataset:
        raise ValueError(f'args.use_streaming_dataset={args.use_streaming_dataset} unsupported')
    # NOTE(review): tdist.get_world_size()/get_rank() presumably need an initialised
    # default process group; other loader options (workers, seed, ...) are not forwarded.
    return WDSEditDataset(
        data_path,
        buffersize=args.iterable_data_buffersize,
        pn=args.pn,
        batch_size=args.batch_size,
        num_replicas=tdist.get_world_size(),
        rank=tdist.get_rank(),
    )


def pil_load(path: str, proposal_size):
    """Open *path* as an RGB PIL image, asking the decoder to downscale when possible.

    When the shorter side exceeds *proposal_size*, both sides are scaled so the shorter
    one becomes about *proposal_size*, and that size is requested via ``Image.draft``.
    Only formats that support drafting (e.g. JPEG) honour the request, and the result
    may be larger than requested. Otherwise the full-size image is decoded.
    """
    with open(path, 'rb') as fh:
        img: PImage.Image = PImage.open(fh)
        width, height = img.width, img.height
        shorter: int = min(height, width)
        if shorter > proposal_size:
            ratio: float = proposal_size / shorter
            width, height = round(ratio * width), round(ratio * height)
        img.draft('RGB', (width, height))
        # convert inside the `with`: pixel data must be read before the file is closed
        img = img.convert('RGB')
    return img


def rewrite(im: 'PImage.Image', file: str, info: str):
    """Re-save *im* over *file* at maximum quality, keeping the file's owner, group and rwx bits.

    The image is first written to ``basename(file)`` in the current working directory.
    It is then moved over *file* with ``sudo``, and the original ``user:group`` and the
    last three octal mode digits (``oct(st_mode)[-3:]``) are restored. Setuid, setgid
    and sticky bits are not restored.

    Encoder options depend on the extension:
      - ``quality=100`` always;
      - uncompressed TIFF for ``.tif``/``.tiff``;
      - lossless WebP for ``.webp``.

    Args:
        im: image to save.
        file: destination path.
        info: label printed in the log line.

    Notes:
        Needs ``pwd``/``grp`` (POSIX only) and non-interactive ``sudo``.
    """
    kw = dict(quality=100)
    lower = file.lower()
    if lower.endswith(('.tif', '.tiff')):
        kw['compression'] = 'none'
    elif lower.endswith('.webp'):
        kw['lossless'] = True

    st = os.stat(file)
    uname = getpwuid(st.st_uid).pw_name
    gname = getgrgid(st.st_gid).gr_name
    mode = oct(st.st_mode)[-3:]

    local_file = osp.basename(file)
    im.save(local_file, **kw)
    print(f'************* <REWRITE: {info}> *************  @  {file}')
    # Pass argv lists (no shell), so paths containing spaces or shell metacharacters
    # are neither split nor interpreted. `--` ends option parsing for names starting
    # with '-'. As with the original `a; b; c` shell line, every step runs even if an
    # earlier one fails.
    for argv in (
        ['sudo', 'mv', '--', local_file, file],
        ['sudo', 'chown', '--', f'{uname}:{gname}', file],
        ['sudo', 'chmod', '--', mode, file],
    ):
        subprocess.call(argv)