# -*- coding: utf-8 -*-
# file: omnigenome_wrapper.py
# time: 14:30 27/10/2024
# author: YANG, HENG (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.
import itertools

import torch
from transformers import AutoTokenizer

from omnigenome import OmniSingleNucleotideTokenizer
from omnigenome.src.misc.utils import RNA2StructureCache


class Tokenizer(OmniSingleNucleotideTokenizer):
    def __init__(self, base_tokenizer=None, **kwargs):
        super(Tokenizer, self).__init__(base_tokenizer, **kwargs)
        self.metadata["tokenizer_name"] = self.__class__.__name__
        # Cached secondary-structure predictor (ViennaRNA folding under the hood).
        self.rna2str = RNA2StructureCache()
        # Window size and overlap used by tokenize(); 3 and 0 are assumed
        # defaults chosen to match the 3-mer table built below.
        self.k = kwargs.get("k", 3)
        self.overlap = kwargs.get("overlap", 0)
        # Map every 3-mer of nucleotide token ids (base ids 4-7) to a dense
        # index in [0, 63]; index 64 is reserved for unknown/special positions.
        bases = [4, 5, 6, 7]
        triplet_combinations = list(itertools.product(bases, repeat=3))
        self.kmer_to_index = {
            triplet: i for i, triplet in enumerate(triplet_combinations)
        }

    def process_input_ids(self, input_ids, k=3):
        # Pad the first position (e.g. the [CLS] token) with the unknown index 64.
        kmer_input_ids = [64]
        for i in range(len(input_ids) - k + 1):
            kmer = tuple(input_ids[i:i + k].tolist())
            # k-mers containing special tokens fall back to the unknown index 64.
            kmer_input_ids.append(self.kmer_to_index.get(kmer, 64))
        # Pad the last position as well, so the k-mer ids align one-to-one
        # with the original input ids.
        kmer_input_ids.append(64)
        return torch.tensor(kmer_input_ids)

    def __call__(self, sequence, **kwargs):
        # Normalize RNA to the DNA alphabet used by the base vocabulary.
        sequence = sequence.replace("U", "T")
        # Predict the dot-bracket secondary structure of the sequence.
        structure, mfe = self.rna2str.fold(sequence, return_mfe=True)
        structure_inputs = self.base_tokenizer(structure, **kwargs)
        tokenized_inputs = self.base_tokenizer(sequence, **kwargs)
        # Derive overlapping 3-mer ids from the single-nucleotide input ids.
        kmer_ids = self.process_input_ids(tokenized_inputs["input_ids"][0], k=3)
        tokenized_inputs["kmer_ids"] = kmer_ids.unsqueeze(0)
        tokenized_inputs["str_ids"] = structure_inputs["input_ids"]
        return tokenized_inputs

    @classmethod
    def from_pretrained(cls, model_name_or_path, **kwargs):
        return cls(AutoTokenizer.from_pretrained(model_name_or_path, **kwargs))

    def tokenize(self, sequence, **kwargs):
        if isinstance(sequence, str):
            sequences = [sequence]
        else:
            sequences = sequence
        sequence_tokens = []
        for seq in sequences:
            # Slide a window of size k with the configured overlap.
            tokens = []
            for j in range(0, len(seq), self.k - self.overlap):
                tokens.append(seq[j: j + self.k])
            sequence_tokens.append(tokens)
        return sequence_tokens

    def encode(self, input_ids, **kwargs):
        return self.base_tokenizer.encode(input_ids, **kwargs)

    def decode(self, input_ids, **kwargs):
        return self.base_tokenizer.decode(input_ids, **kwargs)

    def encode_plus(self, sequence, **kwargs):
        raise NotImplementedError("The encode_plus() function is not implemented yet.")
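

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The checkpoint
    # name below is a placeholder (assumption); any Hugging Face tokenizer
    # compatible with OmniSingleNucleotideTokenizer should work in its place.
    tokenizer = Tokenizer.from_pretrained("path/or/name-of-checkpoint")
    inputs = tokenizer(
        "GGGAAACUCC",
        padding="max_length",
        max_length=32,
        truncation=True,
        return_tensors="pt",
    )
    # Besides the usual input_ids, __call__ attaches the auxiliary channels:
    # overlapping 3-mer ids ("kmer_ids") and secondary-structure ids ("str_ids").
    print(inputs["input_ids"].shape, inputs["kmer_ids"].shape, inputs["str_ids"].shape)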