| from typing import Dict, Any, Tuple | |
| from torch import nn | |
| from .build_shikra import load_pretrained_shikra | |
# Type alias for the preprocessor collection returned alongside the model.
# NOTE(review): presumably maps preprocessor names to preprocessor objects —
# confirm against build_shikra.load_pretrained_shikra.
PREPROCESSOR = Dict[str, Any]


# TODO: Registry — replace the if-chain below with a name -> loader registry
# once more model types are supported.
def load_pretrained(model_args, training_args) -> Tuple[nn.Module, PREPROCESSOR]:
    """Load a pretrained model and its preprocessor(s) for ``model_args.type``.

    Args:
        model_args: Configuration object; its ``type`` attribute selects which
            loader to use. It is passed unchanged to that loader.
        training_args: Training configuration, passed unchanged to the loader.

    Returns:
        Whatever the selected loader returns. It is annotated here as
        ``(model, preprocessor)``.

    Raises:
        ValueError: If ``model_args.type`` is not a supported model type.
    """
    type_ = model_args.type
    if type_ == 'shikra':
        return load_pretrained_shikra(model_args, training_args)
    # Raise explicitly rather than `assert False`: asserts are stripped under
    # `python -O`, which would make this function silently return None.
    raise ValueError(
        f"unsupported model type: {type_!r}; supported types: 'shikra'"
    )