language-identifier / modeling_lang.py
Gleb Vinarskis
added emas pipeline
ae8276c
raw
history blame contribute delete
1.98 kB
import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret
import os
from huggingface_hub import hf_hub_download
from .configuration_lang import ImpressoConfig
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class LangDetectorModel(PreTrainedModel):
    """Language-identification model that delegates prediction to floret.

    The class exposes the ``PreTrainedModel`` interface so it can be loaded
    through ``from_pretrained``, but it holds no trained torch weights: the
    only torch parameter is a one-element dummy used to report a device, and
    all prediction is done by the floret model loaded in ``__init__``.
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Load the floret binary named by the configuration.

        Args:
            config: An ``ImpressoConfig``. If it carries a nested ``config``
                attribute (what ``from_pretrained`` produces when it receives
                a ``config=`` keyword), that nested object supplies
                ``filename`` and ``_name_or_path``; otherwise ``config``
                itself is read.
        """
        super().__init__(config)
        self.config = config
        # Dummy for device checking: gives ``parameters()`` one element.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # Presumably the Hub config arrives through from_pretrained's
        # ``config=`` keyword and ends up nested as ``self.config.config``.
        # Fall back to the outer config so direct construction also works.
        source_config = getattr(self.config, "config", None)
        if source_config is None:
            source_config = self.config

        bin_filename = source_config.filename
        # Prefer a local copy of the binary; only hit the Hub when missing.
        if not os.path.exists(bin_filename):
            logger.info(
                "%s not found locally, downloading from Hugging Face hub...",
                bin_filename,
            )
            bin_filename = hf_hub_download(
                repo_id=source_config._name_or_path,
                filename=bin_filename,
            )
        # Load the floret model from the local or downloaded path.
        self.model_floret = floret.load_model(bin_filename)

    def forward(self, input_ids, **kwargs):
        """Predict the top-1 language label for one or more texts.

        Args:
            input_ids: A single string, or a list/tuple of strings. Despite
                the name, raw text is expected, not token ids.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``(predictions, probabilities)`` exactly as returned by the
            floret model's ``predict(texts, k=1)`` for a list of texts. A
            single string is treated as a one-element list.

        Raises:
            ValueError: If ``input_ids`` is not a string or a list/tuple of
                strings.
        """
        if isinstance(input_ids, str):
            raw_texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(text, str) for text in input_ids
        ):
            raw_texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # floret's predict (forked from fastText's Python wrapper) handles one
        # line per text and raises ValueError on an embedded '\n', so flatten
        # each text onto a single line before predicting.
        texts = [text.replace("\n", " ") for text in raw_texts]

        predictions, probabilities = self.model_floret.predict(texts, k=1)
        return predictions, probabilities

    @property
    def device(self):
        """Device of the first registered parameter (the dummy parameter)."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from an ``ImpressoConfig``, ignoring weight files.

        No checkpoint weights are loaded. The keyword arguments are used to
        construct an ``ImpressoConfig``, and the floret binary is resolved
        later, in ``__init__``.

        Args:
            *args: Positional arguments. The first one, if given, is used as
                the config's ``name_or_path`` (the Hub repo id), but only when
                no ``config`` keyword is supplied and no ``name_or_path`` is
                already set.
            **kwargs: Forwarded to ``ImpressoConfig``.

        Returns:
            A new ``LangDetectorModel``.
        """
        # Without a nested ``config=``, __init__ reads ``_name_or_path`` from
        # the ImpressoConfig itself, so keep the repo id instead of dropping it.
        if args and "config" not in kwargs:
            kwargs.setdefault("name_or_path", args[0])
        config = ImpressoConfig(**kwargs)
        return cls(config)