|
from abc import ABC, abstractmethod |
|
|
|
import gradio as gr |
|
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp |
|
from mammal.model import Mammal |
|
|
|
|
|
class MammalObjectBroker:
    """Lazy holder for a Mammal model and its tokenizer operator.

    Construction only records the model path, a name and a task list. The
    model and the tokenizer op are created from ``model_path`` the first time
    the matching property is read, and the same objects are returned afterwards.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
    ) -> None:
        """Record where to load from; nothing is loaded here.

        Args:
            model_path: value handed to ``from_pretrained`` for both the model
                and the tokenizer op.
            name: name of this entry; falls back to ``model_path`` when ``None``.
            task_list: entries stored as ``tasks``; an empty list when ``None``.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name

        # Copy the list so that later changes to the caller's list do not
        # silently change this broker (and the other way around).
        self.tasks: list[str] = [] if task_list is None else list(task_list)

        # Filled in lazily by the properties below.
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None

    @property
    def model(self) -> Mammal:
        """Model loaded from ``model_path`` on first access, with ``eval()`` called on it."""
        if self._model is None:
            loaded = Mammal.from_pretrained(self.model_path)
            loaded.eval()
            # Cache only after eval() succeeded, so a failure above cannot
            # leave a half-prepared model behind for later callers.
            self._model = loaded
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """Tokenizer op loaded from ``model_path`` on first access, then cached."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
|
|
|
|
|
class MammalTask(ABC):
    """Base class for a single task shown in the demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    ``run_model`` and ``create_demo`` are not abstract, but raise
    ``NotImplementedError`` until a subclass overrides them.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """Store the task name and the shared model dictionary.

        Args:
            name: name of the task.
            model_dict: mapping from model name to its broker; kept by reference.
        """
        self.name = name
        self.description: str | None = None
        # Built on the first call to ``demo`` and reused afterwards.
        self._demo: gr.Group | None = None
        self.model_dict = model_dict

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Formatting prompt to match pre-training syntax

        NOTE: the method name is misspelled ("crate" for "create"); it is kept
        as is because subclasses implement it under this name.

        Args:
            sample_inputs (dict): inputs for one sample; the expected keys are
                defined by the concrete task.
            model_holder (MammalObjectBroker): broker of the model the sample
                is prepared for.

        Returns:
            dict: sample_dict for feeding into model
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: gr.components.Component) -> gr.Group:
        """create a gradio demo group

        Args:
            model_name_widget (gr.components.Component): widget holding the model
                name to use. This is needed to create gradio actions with the
                current model name as an input

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.components.Component | None = None):
        """Return the demo group, creating it on the first call.

        Args:
            model_name_widgit: forwarded to ``create_demo`` as
                ``model_name_widget``. It is only used by the call that
                actually builds the demo; later calls return the cached
                object and ignore it. (The parameter name is misspelled but
                kept so keyword callers keep working.)
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Turn the model output held in ``batch_dict`` into a list of results.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TaskRegistry(dict[str, MammalTask]):
    """Dictionary of tasks keyed by task name, with a registration helper."""

    def register_task(self, task: MammalTask) -> str:
        """Store ``task`` under its ``name`` and return that name.

        An entry already stored under the same name is replaced.
        """
        task_name = task.name
        self[task_name] = task
        return task_name
|
|
|
|
|
class ModelRegistry(dict[str, MammalObjectBroker]):
    """Dictionary of model brokers keyed by model name, with a registration helper."""

    def register_model(self, model_path, task_list=None, name=None):
        """Wrap a model path in a broker, store it by name and return the name.

        Args:
            model_path: passed to ``MammalObjectBroker`` as ``model_path``.
            task_list (optional): passed to ``MammalObjectBroker`` as ``task_list``.
            name (optional str): explicit name for the model

        Returns:
            str: model name (the ``name`` attribute of the new broker, which
            is also the key it is stored under)
        """
        broker = MammalObjectBroker(
            name=name,
            task_list=task_list,
            model_path=model_path,
        )
        registered_name = broker.name
        self[registered_name] = broker
        return registered_name
|
|