import torch
from transformers import T5Tokenizer


class T5TextConditionProcessor:
    """Tokenizes a prompt (and an optional negative prompt) for a T5 text encoder,
    padding every sequence to a fixed length."""

    def __init__(self, tokens_length, processor_path):
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
        # Tokenize the prompt, truncating it to the fixed sequence length.
        encoded = self.processor(text, max_length=self.tokens_length, truncation=True)

        # Right-pad the token ids and the attention mask up to `tokens_length`.
        pad_length = self.tokens_length - len(encoded['input_ids'])
        input_ids = encoded['input_ids'] + [self.processor.pad_token_id] * pad_length
        attention_mask = encoded['attention_mask'] + [0] * pad_length
        condition_model_input = {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

        if negative_text is not None:
            # Tokenize the negative prompt, then cut it to the positive prompt's
            # length so both sequences line up, and make sure it still ends with EOS.
            negative_encoded = self.processor(negative_text, max_length=self.tokens_length, truncation=True)
            negative_input_ids = negative_encoded['input_ids'][:len(encoded['input_ids'])]
            negative_input_ids[-1] = self.processor.eos_token_id

            # Pad the negative ids to `tokens_length` and reuse the positive
            # prompt's attention mask so both inputs share the same masking.
            negative_pad_length = self.tokens_length - len(negative_input_ids)
            negative_input_ids = negative_input_ids + [self.processor.pad_token_id] * negative_pad_length
            negative_attention_mask = encoded['attention_mask'] + [0] * pad_length
            negative_condition_model_input = {
                'input_ids': torch.tensor(negative_input_ids, dtype=torch.long),
                'attention_mask': torch.tensor(negative_attention_mask, dtype=torch.long)
            }
        else:
            negative_condition_model_input = None

        return condition_model_input, negative_condition_model_input
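

# A minimal usage sketch. The tokenizer path and token length below are
# illustrative assumptions, not values taken from the class above; any
# T5-compatible tokenizer checkpoint and sequence length should work.
if __name__ == "__main__":
    processor = T5TextConditionProcessor(
        tokens_length=128,                      # assumed fixed sequence length
        processor_path="google/flan-t5-large"   # assumed tokenizer checkpoint
    )
    cond, neg_cond = processor.encode(
        text="a photo of a red fox in the snow",
        negative_text="low quality, blurry"
    )
    print(cond['input_ids'].shape)       # torch.Size([128]) -- padded to tokens_length
    print(cond['attention_mask'].sum())  # number of real (non-pad) tokens
    print(neg_cond['input_ids'].shape)   # torch.Size([128]) -- aligned with the prompt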