"""TorchServe handler serving a Hugging Face seq2seq translation model.

Incoming Coptic text is lower-cased and transliterated character by character
(see ``greekify``) before it is tokenized and passed to the model.
"""

import json
import logging
import os
from abc import ABC
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

# Response returned for every request whose tokenized input is longer than
# the tokenizer's ``model_max_length``.
MAX_TOKEN_LENGTH_ERR = {
    "code": 422,
    "type": "MaxTokenLengthError",
    "message": "Max token length exceeded",
}


class CopEngHandler(BaseHandler, ABC):
    """Handler that transliterates, batches and translates request texts."""

    @dataclass
    class GenerationConfig:
        """Keyword arguments forwarded to ``model.generate``."""

        max_length: int = 20
        max_new_tokens: Optional[int] = None
        min_length: int = 0
        min_new_tokens: Optional[int] = None
        early_stopping: bool = True
        do_sample: bool = False
        num_beams: int = 1
        num_beam_groups: int = 1
        top_k: int = 50
        top_p: float = 0.95
        temperature: float = 1.0
        diversity_penalty: float = 0.0

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the seq2seq model and tokenizer from the model directory.

        Args:
            ctx (context): Object exposing ``manifest`` and
                ``system_properties``; the ``"model_dir"`` system property
                is the directory the model and tokenizer are loaded from.
        """
        logger.info("Start initialize")
        self.manifest = ctx.manifest
        model_dir = ctx.system_properties.get("model_dir")

        # NOTE(review): "setup_self.config.json" looks like a search/replace
        # artifact of "setup_config.json" - confirm against the packaged
        # model archive before renaming.
        setup_config_path = os.path.join(model_dir, "setup_self.config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path, encoding="utf-8") as setup_config_file:
                self.setup_config = json.load(setup_config_file)

        torch.manual_seed(42)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Device: %s", self.device)

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        self.config = CopEngHandler.GenerationConfig(
            max_new_tokens=128,
            min_new_tokens=1,
            num_beams=5,
        )
        self.initialized = True
        logger.info("Init done")

    def preprocess(self, requests):
        """Extract the text of each request and transliterate it.

        Args:
            requests: Iterable of mappings; the text is read from ``"data"``
                and, when that is missing, from ``"body"``. Byte payloads
                are decoded as UTF-8.

        Returns:
            list: One transliterated string per request, in request order.
        """
        preprocessed_data = []
        for request in requests:
            payload = request.get("data")
            if payload is None:
                payload = request.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            # NOTE(review): a request with neither "data" nor "body" reaches
            # greekify() as None and raises TypeError for the whole batch.
            preprocessed_data.append(greekify(payload))
        logger.info("preprocessed_data %s: ", preprocessed_data)
        return preprocessed_data

    def inference(self, data):
        """Translate every item that fits the tokenizer's maximum length.

        Args:
            data (list): Preprocessed input strings.

        Returns:
            list: Same length and order as ``data``; the translation for
            each accepted item and ``None`` for each item whose token count
            exceeds ``self.tokenizer.model_max_length``.
        """
        max_tokens = self.tokenizer.model_max_length
        kept_positions = []
        batch = []
        for i, item in enumerate(data):
            tokens = self.tokenizer(item, return_tensors="pt", padding=True)
            # Read the sequence length from the last dimension: squeezing a
            # one-token input yields a 0-d tensor, on which len() raises.
            if tokens.input_ids.shape[-1] > max_tokens:
                logger.info("Skipping token %s for index %s", tokens, i)
                continue
            kept_positions.append(i)
            batch.append(item)

        logger.info("inference batch: %s", batch)
        translations = self.batch_translate(batch)

        # Scatter translations back to their original positions; skipped
        # items stay None.
        results = [None] * len(data)
        for position, translation in zip(kept_positions, translations):
            results[position] = translation
        return results

    def postprocess(self, output):
        """Return the inference output unchanged."""
        return output

    def handle(self, requests, context):
        """Run preprocess, inference and postprocess for a batch of requests.

        Args:
            requests: Batch of request mappings (see ``preprocess``).
            context: Accepted for interface compatibility; not used here.

        Returns:
            list: One dict per request - ``{"code": 200, "translation": ...}``
            for translated items, or a copy of ``MAX_TOKEN_LENGTH_ERR`` for
            items that were skipped as too long.
        """
        preprocessed = self.preprocess(requests)
        inference_data = self.inference(preprocessed)
        postprocessed = self.postprocess(inference_data)
        logger.info("inference result: %s", postprocessed)

        # Only None marks a skipped item; an empty-string translation is a
        # valid result and must not be reported as a length error.
        return [
            {"code": 200, "translation": translation}
            if translation is not None
            else dict(MAX_TOKEN_LENGTH_ERR)
            for translation in postprocessed
        ]

    def batch_translate(self, input_sentences, output_confidence=False):
        """Generate translations for a list of sentences in one batch.

        Args:
            input_sentences (list): Sentences to translate.
            output_confidence (bool): Forwarded to ``generate`` as
                ``output_scores``.

        Returns:
            list: Decoded output strings (special tokens stripped), one per
            input sentence; ``[]`` when ``input_sentences`` is empty.
        """
        if not input_sentences:
            return []

        inputs = self.tokenizer(
            input_sentences, return_tensors="pt", padding=True
        ).to(self.device)

        cfg = self.config
        generate_kwargs = {
            "early_stopping": cfg.early_stopping,
            "do_sample": cfg.do_sample,
            "num_beams": cfg.num_beams,
            "num_beam_groups": cfg.num_beam_groups,
            "top_k": cfg.top_k,
            "top_p": cfg.top_p,
            "temperature": cfg.temperature,
            "diversity_penalty": cfg.diversity_penalty,
            "output_scores": output_confidence,
            # Always request the dict output: ``.sequences`` is read below.
            "return_dict_in_generate": True,
        }
        # ``*_new_tokens`` overrides its ``*_length`` counterpart when both
        # are set, so pass only the one that takes effect.
        if cfg.max_new_tokens is not None:
            generate_kwargs["max_new_tokens"] = cfg.max_new_tokens
        else:
            generate_kwargs["max_length"] = cfg.max_length
        if cfg.min_new_tokens is not None:
            generate_kwargs["min_new_tokens"] = cfg.min_new_tokens
        else:
            generate_kwargs["min_length"] = cfg.min_length

        outputs = self.model.generate(**inputs, **generate_kwargs)
        return self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )


# Lower-case Coptic letters mapped to Greek letters; the final seven entries
# (ϣ ϥ ϧ ϩ ϫ ϭ ϯ) map to Latin letters instead.
COPTIC_TO_GREEK = {
    "ⲁ": "α",
    "ⲃ": "β",
    "ⲅ": "γ",
    "ⲇ": "δ",
    "ⲉ": "ε",
    "ⲋ": "ϛ",
    "ⲍ": "ζ",
    "ⲏ": "η",
    "ⲑ": "θ",
    "ⲓ": "ι",
    "ⲕ": "κ",
    "ⲗ": "λ",
    "ⲙ": "μ",
    "ⲛ": "ν",
    "ⲝ": "ξ",
    "ⲟ": "ο",
    "ⲡ": "π",
    "ⲣ": "ρ",
    "ⲥ": "σ",
    "ⲧ": "τ",
    "ⲩ": "υ",
    "ⲫ": "φ",
    "ⲭ": "χ",
    "ⲯ": "ψ",
    "ⲱ": "ω",
    "ϣ": "s",
    "ϥ": "f",
    "ϧ": "k",
    "ϩ": "h",
    "ϫ": "j",
    "ϭ": "c",
    "ϯ": "t",
}


def greekify(coptic_text):
    """Lower-case ``coptic_text`` and replace each mapped Coptic character.

    Each character is lower-cased on its own and then looked up in
    ``COPTIC_TO_GREEK``; characters without an entry are kept in their
    lower-cased form.

    Args:
        coptic_text (str): Text to transliterate.

    Returns:
        str: The transliterated text.
    """
    return "".join(
        COPTIC_TO_GREEK.get(lowered, lowered)
        for lowered in map(str.lower, coptic_text)
    )