language-detection-pipeline / modeling_lang.py
Gleb Vinarskis
initial commit adding files
8473922
import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret
from .configuration_lang import ImpressoConfig
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class LangDetectorModel(PreTrainedModel):
    """Hugging Face model wrapper around a floret model.

    The wrapper holds no trainable weights of its own; predictions are
    delegated to the floret model loaded in ``__init__``.
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Initialise the wrapper and load the floret model.

        Args:
            config: Configuration object. The path of the floret model file
                is read from ``config.config.filename``. NOTE(review): the
                nested ``config.config`` presumably exists because
                ``from_pretrained`` builds the config from its keyword
                arguments, one of which is expected to be ``config`` --
                confirm against the caller.
        """
        super().__init__(config)
        self.config = config

        # Dummy parameter so that ``device`` always has a parameter to inspect.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # Load floret model
        self.model_floret = floret.load_model(self.config.config.filename)

    def forward(self, input_ids, **kwargs):
        """Predict the top-1 label for one or more raw text strings.

        Args:
            input_ids: Despite the name, raw text: either a single ``str``
                or a list/tuple of ``str``. A single string is treated as a
                batch of one.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A ``(predictions, probabilities)`` tuple exactly as returned by
            the floret model's ``predict(texts, k=1)``; both are in batch
            form even when a single string was passed.

        Raises:
            ValueError: If ``input_ids`` is neither a string nor a
                list/tuple made up only of strings.
        """
        if isinstance(input_ids, str):
            # If the input is a single string, make it a list for floret
            texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(t, str) for t in input_ids
        ):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # The fastText-style ``predict`` used by floret processes one line at
        # a time and raises ValueError on embedded newlines, so flatten them.
        # Building a fresh plain ``list`` also matters: the bindings only use
        # the batch path for an exact ``list``, not subclasses or tuples.
        texts = [text.replace("\n", " ") for text in texts]

        predictions, probabilities = self.model_floret.predict(texts, k=1)
        return (
            predictions,
            probabilities,
        )

    @property
    def device(self):
        """Return the device of the model's first parameter."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a model from keyword arguments instead of stored weights.

        Positional arguments are accepted but ignored. All keyword arguments
        are forwarded to ``ImpressoConfig`` and the resulting config is used
        to construct the model.

        Returns:
            A new instance of ``cls``.
        """
        # Manually create the config
        config = ImpressoConfig(**kwargs)
        # Pass the manually created config to the class
        return cls(config)