from __future__ import annotations

import json
import math
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
class EditDataset(Dataset):
    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
    ):
        assert split in ("train", "val", "test")
        # Compare with a tolerance: summing floats such as (0.9, 0.05, 0.05)
        # need not equal 1.0 exactly.
        assert math.isclose(sum(splits), 1.0)
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob

        with open(Path(self.path, "seeds.json")) as f:
            self.seeds = json.load(f)

        # Partition the seed list into contiguous train/val/test slices
        # according to the requested split fractions.
        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]
        idx_0 = math.floor(split_0 * len(self.seeds))
        idx_1 = math.floor(split_1 * len(self.seeds))
        self.seeds = self.seeds[idx_0:idx_1]
    def __len__(self) -> int:
        return len(self.seeds)
    def __getitem__(self, i: int) -> dict[str, Any]:
        name, seeds = self.seeds[i]
        prompt_dir = Path(self.path, name)

        # Each example directory holds several generation seeds; sample one.
        seed = seeds[torch.randint(0, len(seeds), ()).item()]
        with open(prompt_dir.joinpath("prompt.json")) as fp:
            prompt = json.load(fp)["edit"]

        # Convert to RGB so grayscale inputs do not break the "h w c" layout.
        image_0 = Image.open(prompt_dir.joinpath(f"{seed}_0.jpg")).convert("RGB")
        image_1 = Image.open(prompt_dir.joinpath(f"{seed}_1.jpg")).convert("RGB")

        # Resize both images to a randomly sampled resolution, then scale
        # pixel values to [-1, 1] and move channels first.
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")

        # Concatenate along the channel dim so the random crop and flip are
        # applied identically to both images, then split them back apart.
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)

        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
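

# A minimal usage sketch, not part of the original module: it assumes a
# dataset directory containing seeds.json plus per-example subdirectories
# holding prompt.json and {seed}_0.jpg / {seed}_1.jpg pairs. The path
# "data/clip-filtered-dataset" below is a placeholder.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = EditDataset("data/clip-filtered-dataset", split="train", flip_prob=0.5)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    # "edited" is the target image; "edit" carries the conditioning inputs.
    print(batch["edited"].shape)              # e.g. torch.Size([4, 3, 256, 256])
    print(batch["edit"]["c_concat"].shape)    # input images, same shape
    print(batch["edit"]["c_crossattn"][:2])   # edit instructions (strings)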