|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
from io import BytesIO |
|
import json |
|
import logging |
|
import base64 |
|
import random |
|
from typing import Callable, List, Tuple, Union, NamedTuple |
|
from PIL import Image |
|
from PIL import ImageFile |
|
import torch.utils.data as data |
|
from .languages.prompt_engineering import prompt_engineering |
|
from .tsv_file import TSVFile, CompositeTSVFile |
|
|
|
# Let PIL decode images whose byte stream is truncated instead of raising,
# since web-scraped TSV data commonly contains partially-downloaded images.
ImageFile.LOAD_TRUNCATED_IMAGES = True



logger = logging.getLogger(__name__)
|
|
|
|
|
class TSVDataset(data.Dataset):
    """Classification dataset backed by TSV files.

    Each underlying row is ``(key, label, base64-encoded image)``. ``tsv_file``
    may be a single plain ``.tsv`` path (served by ``TSVFile``) or a composite
    spec / list of chunk files (served by ``CompositeTSVFile``, optionally
    fetched from blob storage via a SAS token and azcopy).
    """

    def __init__(self,
                 tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 map_file: str = None,
                 token_file: str = None,
                 is_train: bool = True,
                 azcopy_path: str = None):
        """
        Args:
            tsv_file: single ``.tsv`` path, composite spec path, or list of paths.
            transform: optional callable applied to the decoded PIL image.
            map_file: optional TSV file mapping ``label-name<TAB>index``.
            token_file: optional SAS token path for blob-storage access.
            is_train: training flag forwarded to ``CompositeTSVFile``.
            azcopy_path: optional path to the azcopy binary for blob fetching.

        Raises:
            ValueError: if ``tsv_file`` is neither a str nor a list.
        """
        self.transform = transform
        self._chunk_sizes = None
        self.label2idx = self._load_map(map_file)
        # When a map file is given, restrict loading to the mapped classes.
        self.class_selector = list(self.label2idx.keys()) if self.label2idx else None

        if isinstance(tsv_file, str) and os.path.splitext(tsv_file)[1] == '.tsv':
            # A single plain TSV file.
            self.tsv_file = TSVFile(
                tsv_file, class_selector=self.class_selector
            )
        elif isinstance(tsv_file, (str, list)):
            # Composite spec (a str without a .tsv extension) or an explicit
            # list of chunk files; these two cases were previously handled by
            # duplicated branches with identical bodies.
            self.tsv_file = CompositeTSVFile(
                tsv_file,
                class_selector=self.class_selector,
                is_train=is_train,
                sas_token_path=token_file,
                azcopy_path=azcopy_path
            )
            self._chunk_sizes = self.tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames")

        logger.debug('=> {}\titems: {}'.format(tsv_file, len(self.tsv_file)))

    def fetch_blob(self, idx):
        # Pre-fetch the idx-th chunk from blob storage (composite files only).
        image_tsv = self.tsv_file.file_list[idx]
        self.tsv_file.blob_storage.fetch_blob(image_tsv)

    def num_classes(self):
        """Return the number of classes, or 0 when no map file was provided.

        Previously this raised ``TypeError`` (``len(None)``) whenever the
        dataset was constructed without ``map_file``.
        """
        return len(self.class_selector) if self.class_selector else 0

    def get_chunk_sizes(self):
        # None for a single plain .tsv file; per-chunk sizes otherwise.
        return self._chunk_sizes

    def get_class_boundaries(self):
        return self.tsv_file.get_class_boundaries()

    def get_filenames(self):
        """Return the key of every row, in row order."""
        return [
            self.tsv_file.get_key(i)
            for i in range(self.tsv_file.num_rows())
        ]

    def _load_map(self, map_file: str):
        """Load a ``label-name<TAB>index`` map; return None when no file given."""
        if not map_file:
            return None

        label2idx = {}
        with open(map_file) as f:
            for line in f:
                items = line.strip().split('\t')
                label2idx[items[0]] = int(items[1])

        return label2idx

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, target)`` for ``index`` (int or (chunk, row) pair)."""
        items = self.tsv_file[index]
        _, target, img = self._decode_data(items)

        if self.transform:
            img = self.transform(img)

        return img, target

    def _decode_data(self, items: Tuple[str, str, str]):
        """Decode one raw row into ``(key, label, RGB PIL image)``."""
        key = items[0]
        label = self._get_label(items[1])
        image = Image.open(BytesIO(base64.b64decode(items[2]))).convert('RGB')

        return key, label, image

    def _get_label(self, item: str):
        """Parse a label cell: a plain int without a map file, otherwise a
        JSON list whose first entry's ``class`` name is mapped via label2idx."""
        if not self.label2idx:
            return int(item)

        js = json.loads(item)
        return self.label2idx[js[0]['class']]

    def __len__(self):
        return len(self.tsv_file)
|
|
|
|
|
# Describes one TSV data source: where it comes from, how many classes it
# contributes, and which task it serves (e.g. 'classification').
TSVMeta = NamedTuple(
    'TSVMeta',
    [('source', str), ('num_classes', int), ('task', str)]
)
|
|
|
|
|
class TSVImageTextDatasetV2(data.Dataset):
    """
    This class is intended for encapsulating Image/Text pair data for contrastive learning described in
    the following paper,
    "Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP)
    V2: support image text pairs and supervised classification data
    """

    def __init__(self,
                 image_tsv_file: Union[str, List[str]],
                 text_tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 tokenize: Callable = None,
                 context_length: int = 77,
                 num_captions: int = 1,
                 text_format: str = 'txt',
                 is_train: bool = True,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 metas: List[NamedTuple] = None,
                 prompt_engineering=True,
                 concat_queries=False):
        """
        Args:
            image_tsv_file: single ``.tsv`` path or list of image TSV paths.
            text_tsv_file: single ``.tsv`` path or list of text TSV paths,
                paired 1:1 with ``image_tsv_file``.
            transform: optional callable applied to the decoded PIL image.
            tokenize: optional HuggingFace-style tokenizer callable; when None,
                the raw caption string is returned instead of token tensors.
            context_length: max token length passed to the tokenizer.
            num_captions: number of captions to keep per sample.
            text_format: text cell format; only 'json' is supported at decode.
            is_train: training flag forwarded to ``CompositeTSVFile``.
            sas_token_path: optional SAS token path for blob storage.
            azcopy_path: optional azcopy binary path for blob fetching.
            metas: optional list of ``TSVMeta`` used to offset class ids from
                different classification sources into disjoint label ranges.
            prompt_engineering: wrap class names in prompt templates if True.
            concat_queries: prepend ``js['queries']`` to the caption if True.

        Raises:
            ValueError: if the image/text file arguments are malformed.
        """
        self.transform = transform
        self.tokenize = tokenize
        self._chunk_sizes = None
        self.context_length = context_length
        self.num_captions = num_captions
        self.text_format = text_format
        self.tsv_file_list = []
        self.metas = metas
        self.label_offsets = self.build_label_offsets()
        self.prompt_engineering = prompt_engineering
        self.concat_queries = concat_queries

        if isinstance(image_tsv_file, str) and isinstance(text_tsv_file, str):
            # Single image/text pair: both must be plain .tsv files.
            if (
                os.path.splitext(image_tsv_file)[1].lower() == '.tsv'
                and os.path.splitext(text_tsv_file)[1].lower() == '.tsv'
            ):
                self.tsv_file_list.append((image_tsv_file, text_tsv_file))
                self.image_tsv_file = TSVFile(
                    image_tsv_file, if_generate_lineidx=True
                )
                self.text_tsv_file = TSVFile(
                    text_tsv_file, if_generate_lineidx=True
                )
            else:
                raise ValueError("Invalid input! Please check the tsv filenames.")
        elif (
            isinstance(image_tsv_file, list)
            and isinstance(text_tsv_file, list)
        ):
            assert len(image_tsv_file) == len(text_tsv_file), \
                "Inconsistent number of Image/Text tsv files!"
            # Keep (image, text) pair order, consistent with the single-file
            # branch above (the old code stored (text, image) here).
            self.tsv_file_list = list(zip(image_tsv_file, text_tsv_file))
            self.image_tsv_file = CompositeTSVFile(
                image_tsv_file,
                is_train=is_train,
                sas_token_path=sas_token_path,
                azcopy_path=azcopy_path
            )
            self.text_tsv_file = CompositeTSVFile(
                text_tsv_file,
                is_train=is_train,
                sas_token_path=sas_token_path,
                azcopy_path=azcopy_path
            )
            self._chunk_sizes = self.image_tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames.")

        assert len(self.image_tsv_file) == len(self.text_tsv_file), \
            "Inconsistent size of Image/Text ({}/{}) data!".format(
                len(self.image_tsv_file), len(self.text_tsv_file)
            )

    def build_label_offsets(self):
        """Compute a per-source offset mapping class ids from different
        classification sources into disjoint global ranges.

        Returns None when no ``metas`` were provided. Offsets start at 1 so
        label 0 remains reserved for non-classification (caption/tag) rows.
        """
        if self.metas is None:
            return None

        # NOTE: leftover debug print() calls were removed from this loop.
        label_offsets = {}
        offset = 1
        for meta in self.metas:
            label_offsets[meta.source] = offset
            offset += meta.num_classes

        return label_offsets

    def fetch_blob(self, idx):
        # Pre-fetch the idx-th image and text chunks from blob storage
        # (composite files only).
        image_tsv = self.image_tsv_file.file_list[idx]
        text_tsv = self.text_tsv_file.file_list[idx]
        self.image_tsv_file.blob_storage.fetch_blob(image_tsv)
        self.text_tsv_file.blob_storage.fetch_blob(text_tsv)

    def get_chunk_sizes(self):
        # None for single plain .tsv files; per-chunk sizes otherwise.
        return self._chunk_sizes

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, tokens, label)`` for ``index``.

        A None index yields empty tensors so collate functions can skip it.
        """
        if index is None:
            import torch
            return torch.tensor([], dtype=torch.float32), \
                torch.tensor([], dtype=torch.int64), \
                torch.tensor([], dtype=torch.int64)

        items_image = self.image_tsv_file[index]
        items_text = self.text_tsv_file[index]

        assert items_text[0] == items_image[0], \
            'keys do not match for image and text {} vs {}'.format(
                items_text[0], items_image[0]
            )

        _, img = self._decode_image(items_image)
        _, txt, label = self._decode_text(items_text)

        if self.transform:
            img = self.transform(img)

        if self.tokenize:
            tokens = self.tokenize(
                txt, padding='max_length', truncation=True,
                max_length=self.context_length, return_tensors='pt'
            )
            # Drop the batch dimension added by return_tensors='pt'.
            tokens['input_ids'].squeeze_()
            tokens['attention_mask'].squeeze_()
        else:
            # BUGFIX: squeeze_() was previously applied unconditionally and
            # raised TypeError when no tokenizer was supplied (tokens was str).
            tokens = txt

        return img, tokens, label

    def _decode_image(self, items: Tuple[str, str]):
        """Decode one image row into ``(key, RGB PIL image)``."""
        key = items[0]
        image = Image.open(BytesIO(base64.b64decode(items[1]))).convert('RGB')

        return key, image

    def _decode_text(self, items: Tuple[str, Union[str, dict]]):
        """Decode one text row into ``(key, text, label)``.

        Supports three JSON schemas: ``captions`` (str or list), ``tags``
        (prompt-engineered), and classification rows carrying ``class_name`` /
        ``class_id`` (optionally offset by source via ``label_offsets``).

        Raises:
            ValueError: if ``text_format`` is not 'json' or captions have an
                unsupported type.
        """
        key = items[0]
        text = ''

        if self.text_format != 'json':
            raise ValueError('Only support json format')

        try:
            js = json.loads(items[1])
        except Exception as e:
            # Malformed JSON: log and salvage a usable caption from the raw
            # string instead of failing the whole batch.
            js = {}
            logger.info("JSON parsing error on: " + items[1])
            logger.info(str(e))

            # Recover the first double-quoted substring as the caption.
            # BUGFIX: the old code searched for the closing quote starting at
            # the opening quote itself, so it always found the same quote and
            # produced an empty string, falling through to the placeholder.
            start = items[1].find("\"")
            if start < 0:
                # No quotes at all: use the raw string as-is.
                text = items[1]
            else:
                end = items[1].find("\"", start + 1)
                if end < 0:
                    end = len(items[1])
                text = items[1][start + 1:end]
            if len(text) < 2:
                text = "A picture showing some content."

        label = 0

        if 'captions' in js:
            captions = js['captions']
            if isinstance(captions, list):
                if self.num_captions == 1:
                    # Randomly pick one caption as text augmentation.
                    text = random.choice(captions)
                else:
                    text = captions
                    if len(captions) > self.num_captions:
                        text = captions[:self.num_captions]
            elif isinstance(captions, str):
                text = captions
            else:
                raise ValueError('captions should be str or list')
            label = 0
        elif 'tags' in js:
            text = prompt_engineering(js['tags'])
            label = 0
        elif 'task' in js and js['task'] == 'classification':
            if self.prompt_engineering:
                text = prompt_engineering(js['class_name'])
            else:
                text = js['class_name']
            label = js['class_id']

        if self.label_offsets is not None:
            # BUGFIX: use .get() — rows without a 'source' key (e.g. caption
            # rows mixed with classification data) previously raised KeyError.
            if js.get('source') in self.label_offsets:
                label += self.label_offsets[js['source']]

        if self.concat_queries:
            if ('queries' in js) and (len(js['queries']) > 0):
                # Preserve the original formatting: each query followed by a
                # space, then ', ' before the caption.
                q = ''.join(item + ' ' for item in js['queries'])
                text = q + ', ' + text

        return key, text, label

    def __len__(self):
        return len(self.image_tsv_file)
|
|