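"""TorchServe handler for a Coptic-to-English sequence-to-sequence translation model.

Incoming Coptic text is first mapped to Greek lookalike characters (see
greekify), then tokenized and translated with beam search. Requests whose
inputs exceed the tokenizer's maximum length are answered with
MAX_TOKEN_LENGTH_ERR instead of a translation.
"""
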
from dataclasses import dataclass
import logging
import os
from abc import ABC
from typing import Optional
import torch
import json
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

MAX_TOKEN_LENGTH_ERR = {
    "code": 422,
    "type": "MaxTokenLengthError",
    "message": "Max token length exceeded",
}


class CopEngHandler(BaseHandler, ABC):
    @dataclass
    class GenerationConfig:
        max_length: int = 20
        max_new_tokens: Optional[int] = None
        min_length: int = 0
        min_new_tokens: Optional[int] = None
        early_stopping: bool = True
        do_sample: bool = False
        num_beams: int = 1
        num_beam_groups: int = 1
        top_k: int = 50
        top_p: float = 0.95
        temperature: float = 1.0
        diversity_penalty: float = 0.0

    def __init__(self):
        super(CopEngHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the tokenizer and the seq2seq translation model from the
        model directory and prepare them for inference.

        Args:
            ctx (context): A JSON object containing information pertaining
                to the model artifact parameters.
        """
        logger.info("Start initialize")
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        setup_config_path = os.path.join(model_dir, "setup_self.config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        seed = 42
        torch.manual_seed(seed)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Device: %s", self.device)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.config = CopEngHandler.GenerationConfig(
            max_new_tokens=128,
            min_new_tokens=1,
            num_beams=5,
        )
        self.initialized = True
        logger.info("Init done")

    def preprocess(self, requests):
        preprocessed_data = []
        for data in requests:
            data_item = data.get("data")
            if data_item is None:
                data_item = data.get("body")
            if isinstance(data_item, (bytes, bytearray)):
                data_item = data_item.decode("utf-8")
            # Map Coptic characters to their Greek lookalikes before tokenization.
            preprocessed_data.append(greekify(data_item))
        logger.info("preprocessed_data: %s", preprocessed_data)
        return preprocessed_data

    def inference(self, data):
        # indices maps each original request index to its position in the
        # filtered batch; inputs longer than the model maximum are skipped.
        indices = {}
        batch = []
        for i, item in enumerate(data):
            tokens = self.tokenizer(item, return_tensors="pt", padding=True)
            if len(tokens.input_ids.squeeze()) > self.tokenizer.model_max_length:
                logger.info("Skipping token %s for index %s", tokens, i)
                continue
            indices[i] = len(batch)
            batch.append(data[i])
        logger.info("inference batch: %s", batch)
        result = self.batch_translate(batch)
        # Skipped requests map to None so handle() can return an error for them.
        return [result[indices[i]] if i in indices else None for i in range(len(data))]

    def postprocess(self, output):
        return output

    def handle(self, requests, context):
        preprocessed = self.preprocess(requests)
        inference_data = self.inference(preprocessed)
        postprocessed = self.postprocess(inference_data)
        logger.info("inference result: %s", postprocessed)
        responses = [
            {"code": 200, "translation": translation}
            if translation
            else MAX_TOKEN_LENGTH_ERR
            for translation in postprocessed
        ]
        return responses

    def batch_translate(self, input_sentences, output_confidence=False):
        if len(input_sentences) == 0:
            return []
        inputs = self.tokenizer(input_sentences, return_tensors="pt", padding=True).to(
            self.device
        )
        outputs = self.model.generate(
            **inputs,
            max_length=self.config.max_length,
            max_new_tokens=self.config.max_new_tokens,
            min_length=self.config.min_length,
            min_new_tokens=self.config.min_new_tokens,
            early_stopping=self.config.early_stopping,
            do_sample=self.config.do_sample,
            num_beams=self.config.num_beams,
            num_beam_groups=self.config.num_beam_groups,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            temperature=self.config.temperature,
            diversity_penalty=self.config.diversity_penalty,
            output_scores=output_confidence,
            return_dict_in_generate=True,
        )
        translated_text = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )
        return translated_text


COPTIC_TO_GREEK = {
    "ⲁ": "α",
    "ⲃ": "β",
    "ⲅ": "γ",
    "ⲇ": "δ",
    "ⲉ": "ε",
    "ⲋ": "ϛ",
    "ⲍ": "ζ",
    "ⲏ": "η",
    "ⲑ": "θ",
    "ⲓ": "ι",
    "ⲕ": "κ",
    "ⲗ": "λ",
    "ⲙ": "μ",
    "ⲛ": "ν",
    "ⲝ": "ξ",
    "ⲟ": "ο",
    "ⲡ": "π",
    "ⲣ": "ρ",
    "ⲥ": "σ",
    "ⲧ": "τ",
    "ⲩ": "υ",
    "ⲫ": "φ",
    "ⲭ": "χ",
    "ⲯ": "ψ",
    "ⲱ": "ω",
    "ϣ": "s",
    "ϥ": "f",
    "ϧ": "k",
    "ϩ": "h",
    "ϫ": "j",
    "ϭ": "c",
    "ϯ": "t",
}


def greekify(coptic_text):
    """Replace each Coptic character with its Greek (or Latin) stand-in."""
    chars = []
    for c in coptic_text:
        l_c = c.lower()
        chars.append(COPTIC_TO_GREEK.get(l_c, l_c))
    return "".join(chars)
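

# Minimal local smoke test (outside TorchServe). The handler itself needs a
# TorchServe context to initialize, so only the character mapping is exercised
# here; the sample string is purely illustrative.
if __name__ == "__main__":
    sample = "ⲁⲃⲅ"
    print(greekify(sample))  # -> "αβγ"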