Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 3,699 Bytes
			
			| 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a 19dfa7a 93d0d1a fda141d 93d0d1a 19dfa7a 93d0d1a fda141d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | from abc import ABC, abstractmethod
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal
class MammalObjectBroker:
    """Lazily load and cache a Mammal model and its tokenizer op.

    Neither object is loaded in ``__init__``; each is created from
    ``model_path`` via ``from_pretrained`` the first time its property is
    read, and reused afterwards.

    Args:
        model_path: path/identifier passed to ``Mammal.from_pretrained`` and
            ``ModularTokenizerOp.from_pretrained``.
        name: name for this model; defaults to ``model_path`` when ``None``.
        task_list: names of tasks associated with this model. The list is
            copied, so later changes to ``self.tasks`` never modify the
            caller's list (which may be shared between several brokers).
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
    ) -> None:
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Defensive copy: avoid aliasing a list the caller may reuse elsewhere.
        self.tasks: list[str] = list(task_list) if task_list is not None else []
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None

    @property
    def model(self) -> Mammal:
        """The Mammal model, loaded from ``model_path`` on first access.

        ``eval()`` is re-applied on every access (not only right after
        loading), so the model handed out is always switched to eval mode.
        """
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
        self._model.eval()
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """The tokenizer op, loaded from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
class MammalTask(ABC):
    """Base class for a task that runs a Mammal model and exposes a gradio demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``
    (both abstract). ``run_model`` and ``create_demo`` are not abstract but
    raise ``NotImplementedError`` here, so subclasses that use them must
    override them.

    Args:
        name: task name; ``TaskRegistry.register_task`` stores tasks under it.
        model_dict: mapping from model name to ``MammalObjectBroker``.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        self.name = name
        self.description = None  # NOTE(review): presumably filled in by subclasses
        self._demo = None  # cache for the result of create_demo(); see demo()
        self.model_dict = model_dict

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format the inputs into a sample dict matching the pre-training syntax.

        The method name (including its spelling) is the name subclasses
        override, so it is kept unchanged.

        Args:
            sample_inputs: task-specific inputs (schema defined by the subclass).
            model_holder: broker providing the model and tokenizer op.
        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``.

        Deliberately not abstract: only subclasses that need it override it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: gr.components.Component) -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: widget holding the model name to use. This is
                needed to create gradio actions with the current model name
                as an input.
        Raises:
            NotImplementedError: always, in this base class; subclasses
                override it.
        """
        raise NotImplementedError()

    def demo(
        self, model_name_widgit: gr.components.Component | None = None
    ) -> gr.Group:
        """Return this task's demo group, building it on the first call.

        The group is cached after the first call, so ``model_name_widgit`` is
        only used (forwarded to ``create_demo``) the first time.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Convert a model output batch into a list (format defined by subclasses)."""
        raise NotImplementedError()
class TaskRegistry(dict[str, MammalTask]):
    """Plain dict of tasks keyed by task name, plus a registration helper."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under ``task.name`` and return that name.

        An existing entry with the same name is replaced.
        """
        key = task.name
        self[key] = task
        return key
class ModelRegistry(dict[str, MammalObjectBroker]):
    """Plain dict of model brokers keyed by model name, plus a registration helper."""

    def register_model(self, model_path, task_list=None, name=None):
        """Wrap ``model_path`` in a ``MammalObjectBroker`` and store it by name.

        Args:
            model_path: path/identifier forwarded to ``MammalObjectBroker``.
            task_list (optional list[str]): tasks associated with the model.
            name (optional str): explicit name for the model; when omitted
                the broker uses ``model_path`` as its name.
        Returns:
            str: the name under which the broker was stored.
        """
        broker = MammalObjectBroker(
            model_path=model_path,
            task_list=task_list,
            name=name,
        )
        key = broker.name
        self[key] = broker
        return key
 | 
