Spaces:
Build error
Build error
| import re | |
| from m4.models.vbloom.configuration_vbloom import VBloomConfig | |
| from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM | |
| from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config | |
| from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel | |
| from m4.models.vgpt_neo.configuration_vgpt_neo import VGPTNeoConfig | |
| from m4.models.vgpt_neo.modeling_vgpt_neo import VGPTNeoForCausalLM | |
| from m4.models.vllama.configuration_vllama import VLlamaConfig | |
| from m4.models.vllama.modeling_vllama import VLlamaForCausalLM | |
| from m4.models.vopt.configuration_vopt import VOPTConfig | |
| from m4.models.vopt.modeling_vopt import VOPTForCausalLM | |
| from m4.models.vt5.configuration_vt5 import VT5Config | |
| from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration | |
# Registry of backbone LM families. Each key is a regex that is searched
# (re.search, so it may match anywhere) in the lowercased model name or path;
# each value is the [config_class, model_class] pair for that family.
# Insertion order is significant: model_name_to_classes returns the classes of
# the FIRST key that matches, so broad patterns such as "opt" or "t5" can also
# match unrelated path components (e.g. a checkpoint stored under /opt/).
model_name2classes = dict(
    [
        (r"bloom|bigscience-small-testing", [VBloomConfig, VBloomForCausalLM]),
        (r"gpt-neo|gptneo", [VGPTNeoConfig, VGPTNeoForCausalLM]),
        (r"gpt2", [VGPT2Config, VGPT2LMHeadModel]),
        (r"opt", [VOPTConfig, VOPTForCausalLM]),
        (r"t5", [VT5Config, VT5ForConditionalGeneration]),
        (r"llama", [VLlamaConfig, VLlamaForCausalLM]),
    ]
)
def model_name_to_classes(model_name_or_path, *, name2classes=None):
    """Return the ``[config_class, model_class]`` pair for a backbone LM.

    The model name or path is lowercased and searched with :func:`re.search`
    against each regex key of the registry, in insertion order; the value of
    the first matching key is returned. Because ``re.search`` matches anywhere
    in the string, broad keys can also match unrelated path components, so
    the order of the registry matters.

    Args:
        model_name_or_path: Model identifier or local path, either a ``str``
            or any ``os.PathLike`` (e.g. ``pathlib.Path``).
        name2classes: Optional mapping of regex pattern -> classes to search
            instead of the module-level ``model_name2classes`` registry.

    Returns:
        The value stored under the first matching regex (in the default
        registry, a ``[config_class, model_class]`` list).

    Raises:
        ValueError: If no regex in the registry matches.
    """
    import os

    registry = model_name2classes if name2classes is None else name2classes
    # os.fspath accepts both plain strings and os.PathLike objects.
    model_name_lowcase = os.fspath(model_name_or_path).lower()
    for rx, classes in registry.items():
        if re.search(rx, model_name_lowcase):
            return classes
    # Reached only when no pattern matched (the loop returns on the first hit).
    raise ValueError(
        f"Unknown type of backbone LM. Got {model_name_or_path}, supported regexes:"
        f" {list(registry.keys())}."
    )