Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| from pie_modules.models import SimpleSequenceClassificationModel | |
| from pie_modules.models.simple_sequence_classification import InputType, OutputType, TargetType | |
| from pytorch_ie import PyTorchIEModel | |
| from torch import nn | |
| from transformers import BertModel | |
| from transformers.utils import is_accelerate_available | |
| if is_accelerate_available(): | |
| from accelerate.hooks import add_hook_to_module | |
class SimpleSequenceClassificationModelWithInputTypeIds(SimpleSequenceClassificationModel):
    """Sequence classification model with a resizable token type embedding table.

    On construction the token type embedding matrix of the wrapped model
    (``self.model``) is resized to ``num_token_type_ids`` rows. In ``forward``
    the batch entry stored under ``use_as_token_type_ids`` is passed to the
    wrapped model as ``token_type_ids``.

    NOTE(review): ``self.model`` is presumably a Hugging Face ``transformers``
    model created by the parent class -- this class relies on its
    ``base_model_prefix``, ``_get_resized_embeddings``, ``config`` and
    ``tie_weights`` members; confirm against the parent implementation.

    Args:
        num_token_type_ids: Number of rows the token type embedding matrix is
            resized to.
        use_as_token_type_ids: Key of the model input that is renamed to
            ``token_type_ids`` before calling the wrapped model.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(
        self, num_token_type_ids: int, use_as_token_type_ids: str = "token_type_ids", **kwargs
    ):
        super().__init__(**kwargs)
        self.num_token_type_ids = num_token_type_ids
        self.token_type_ids_key = use_as_token_type_ids
        # The parent constructor has set up self.model by now, so the
        # embedding table can be resized right away.
        self.resize_type_embeddings(num_token_type_ids)

    def _get_base_model_or_raise(self) -> BertModel:
        """Return the base model stored under ``self.model.base_model_prefix``.

        Raises:
            ValueError: If the wrapped model has no attribute with that name
                or the attribute is ``None``.
        """
        # Default of None so a missing attribute reaches the ValueError below
        # instead of surfacing as an unrelated AttributeError.
        base_model = getattr(self.model, self.model.base_model_prefix, None)
        if base_model is None:
            raise ValueError("Model has no base model.")
        return base_model

    def get_input_type_embeddings(self) -> nn.Module:
        """Return the token type embedding module of the base model.

        Raises:
            ValueError: If the wrapped model has no base model.
        """
        return self._get_base_model_or_raise().embeddings.token_type_embeddings

    def set_input_type_embeddings(self, value):
        """Replace the token type embedding module of the base model with ``value``.

        Raises:
            ValueError: If the wrapped model has no base model.
        """
        self._get_base_model_or_raise().embeddings.token_type_embeddings = value

    def _resize_type_embeddings(
        self, new_num_tokens: Optional[int], pad_to_multiple_of: Optional[int] = None
    ) -> nn.Module:
        """Build resized token type embeddings and install them on the base model.

        The new module is created by ``self.model._get_resized_embeddings``.

        Returns:
            The token type embedding module that is installed afterwards.
        """
        old_embeddings = self.get_input_type_embeddings()
        new_embeddings = self.model._get_resized_embeddings(
            old_embeddings, new_num_tokens, pad_to_multiple_of
        )
        # Carry a hook attached to the old module over to the new one.
        # `add_hook_to_module` is only imported when accelerate is available;
        # presumably `_hf_hook` is only ever set in that case.
        if hasattr(old_embeddings, "_hf_hook"):
            add_hook_to_module(new_embeddings, old_embeddings._hf_hook)
        # Keep the trainable/frozen state of the original embeddings.
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        self.set_input_type_embeddings(new_embeddings)
        return self.get_input_type_embeddings()

    def resize_type_embeddings(
        self, new_num_types: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """Resize the token type embedding matrix of the wrapped model.

        Counterpart of ``resize_token_embeddings`` for the token type
        embeddings. After resizing, ``self.model.config.type_vocab_size`` is
        set to the new number of rows and ``self.model.tie_weights()`` is
        called.

        Arguments:
            new_num_types (`int`, *optional*):
                The number of token types in the resized embedding matrix.
                If both `new_num_types` and `pad_to_multiple_of` are `None`,
                the config is not updated and weights are not re-tied; the
                current token type embedding module is returned.
            pad_to_multiple_of (`int`, *optional*):
                Forwarded together with `new_num_types` to
                `self.model._get_resized_embeddings`, which handles padding
                the matrix size to a multiple of this value. For background
                on choosing a value see
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: The token type embedding module of the model.
        """
        model_embeds = self._resize_type_embeddings(new_num_types, pad_to_multiple_of)
        if new_num_types is None and pad_to_multiple_of is None:
            return model_embeds

        # Update the model config to the size that was actually produced.
        self.model.config.type_vocab_size = model_embeds.weight.shape[0]

        # Tie weights again if needed
        self.model.tie_weights()

        return model_embeds

    def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
        """Call the wrapped model with the configured input passed as ``token_type_ids``.

        Args:
            inputs: Keyword inputs for the wrapped model; must contain the key
                given as ``use_as_token_type_ids`` at construction.
            targets: Optional additional keyword arguments merged on top of
                ``inputs``.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            KeyError: If the configured key is in neither ``inputs`` nor
                ``targets``.
        """
        # Merge into a fresh dict so the caller's mappings are not modified.
        kwargs = {**inputs, **(targets or {})}
        try:
            type_ids = kwargs.pop(self.token_type_ids_key)
        except KeyError:
            raise KeyError(
                f"Model inputs are missing the key '{self.token_type_ids_key}' "
                "that is configured to be used as token_type_ids."
            ) from None
        # Rename the configured key to the argument name the model expects.
        kwargs["token_type_ids"] = type_ids
        return self.model(**kwargs)