import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel

import floret

from .configuration_lang import ImpressoConfig

logger = logging.getLogger(__name__)

class LangDetectorModel(PreTrainedModel):
    config_class = ImpressoConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Dummy parameter so the model owns at least one tensor for device checks
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # Load the floret language-identification model from disk
        self.model_floret = floret.load_model(self.config.config.filename)

    def forward(self, input_ids, **kwargs):
        if isinstance(input_ids, str):
            # A single string: wrap it in a list for floret
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # floret returns the top-k labels and their probabilities per text
        predictions, probabilities = self.model_floret.predict(texts, k=1)
        return predictions, probabilities
    @property
    def device(self):
        # Expose device as a property, matching the PreTrainedModel convention
        return next(self.parameters()).device
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Ignore any pretrained weights: manually create the config from kwargs
        # and let __init__ load the floret model instead.
        config = ImpressoConfig(**kwargs)
        # Pass the manually created config to the class
        model = cls(config)
        return model
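
# Minimal usage sketch (illustrative only): it assumes the ImpressoConfig
# defaults resolve `config.config.filename` to a valid floret .bin file on
# disk; adjust the kwargs to match your checkpoint.
#
#     model = LangDetectorModel.from_pretrained()
#     labels, probs = model("Ceci est une phrase de test.")
#     print(labels, probs)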