Spaces:
Paused
Paused
| from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig | |
| from starvector.data.base import ImageTrainProcessor | |
| from starvector.util import dtype_mapping | |
| from transformers import AutoConfig | |
def load_pretrained_model(model_path, device="cuda", **kwargs):
    """Load a pretrained StarVector checkpoint plus its companion objects.

    Args:
        model_path: Checkpoint identifier or path passed to
            ``StarVectorForCausalLM.from_pretrained``.
        device: Device the loaded model is moved to via ``.to``.
        **kwargs: Forwarded unchanged to ``from_pretrained``.

    Returns:
        A 4-tuple ``(tokenizer, model, image_processor, context_len)``:
        - tokenizer: the ``svg_transformer.tokenizer`` of the inner model;
        - model: the loaded ``StarVectorForCausalLM``, already on ``device``;
        - image_processor: a freshly constructed ``ImageTrainProcessor``;
        - context_len: the inner model's ``query_length`` plus its
          ``max_length``.
    """
    starvector = StarVectorForCausalLM.from_pretrained(model_path, **kwargs)
    starvector = starvector.to(device)

    # The tokenizer and the length attributes live on the wrapped inner model.
    inner = starvector.model
    tokenizer = inner.svg_transformer.tokenizer
    processor = ImageTrainProcessor()
    context_len = inner.query_length + inner.max_length

    return tokenizer, starvector, processor, context_len
def model_builder(config):
    """Construct a ``StarVectorForCausalLM`` from an experiment configuration.

    There are two construction paths:
    - If ``config.model.model_name`` is truthy, the model is loaded with
      ``StarVectorForCausalLM.from_pretrained(model_name, ...)``.
    - Otherwise a new ``StarVectorConfig`` is built. Its architecture sizes
      (attention heads, layers, vocab size, hidden size, and KV heads when
      present) are copied from the ``AutoConfig`` of
      ``config.model.starcoder_model_name``. The model is then instantiated
      directly from that config.

    Both paths forward the same keyword arguments to the model: task,
    training flags, dtype, layer class and cache setting.

    Args:
        config: Experiment config with ``model`` and ``training`` sections.
            ``config.model`` must support ``.get(key, default)``. It is
            presumably an OmegaConf/dict-like node; confirm against callers.

    Returns:
        The constructed ``StarVectorForCausalLM`` instance.
    """
    model_cfg = config.model
    checkpoint = model_cfg.get("model_name", False)

    # Keyword arguments shared by both construction paths.
    shared_kwargs = dict(
        task=model_cfg.task,
        train_image_encoder=config.training.train_image_encoder,
        ignore_mismatched_sizes=True,
        starcoder_model_name=model_cfg.starcoder_model_name,
        train_LLM=config.training.train_LLM,
        torch_dtype=dtype_mapping[config.training.model_precision],
        transformer_layer_cls=model_cfg.get("transformer_layer_cls", False),
        use_cache=model_cfg.use_cache,
    )

    if not checkpoint:
        # No checkpoint given: derive the architecture from the base LM's config.
        base_cfg = AutoConfig.from_pretrained(model_cfg.starcoder_model_name)
        fresh_config = StarVectorConfig(
            max_length_train=model_cfg.max_length,
            image_encoder_type=model_cfg.image_encoder_type,
            use_flash_attn=model_cfg.use_flash_attn,
            adapter_norm=model_cfg.adapter_norm,
            starcoder_model_name=model_cfg.starcoder_model_name,
            torch_dtype=shared_kwargs["torch_dtype"],
            num_attention_heads=base_cfg.num_attention_heads,
            num_hidden_layers=base_cfg.num_hidden_layers,
            vocab_size=base_cfg.vocab_size,
            hidden_size=base_cfg.hidden_size,
            # Not every base config defines KV heads; fall back to None.
            num_kv_heads=getattr(base_cfg, "num_key_value_heads", None),
        )
        return StarVectorForCausalLM(fresh_config, **shared_kwargs)

    return StarVectorForCausalLM.from_pretrained(checkpoint, **shared_kwargs)