hf-issue-36653 / modeling.py
yairschiff's picture
Initial commit
e07ddf4 verified
import transformers
from .backbone import MyBackbone
class MyModelConfig(transformers.PretrainedConfig):
    """Configuration object holding the size hyperparameters of ``MyModel``.

    The four integer fields below are stored as plain attributes. ``MyModel``
    reads them by name and passes them to ``MyBackbone``. Every other keyword
    argument is handed, untouched, to ``transformers.PretrainedConfig.__init__``.

    Args:
        num_layers: Value stored as ``self.num_layers`` (default 2).
        input_dim: Value stored as ``self.input_dim`` (default 2).
        hidden_dim: Value stored as ``self.hidden_dim`` (default 128).
        output_dim: Value stored as ``self.output_dim`` (default 2).
        **kwargs: Forwarded to the ``PretrainedConfig`` base initializer.
    """

    model_type = "my_model"

    # Maps Auto* class names to "<module>.<ClassName>" paths. This is presumably
    # consumed by transformers' Auto classes for remote-code loading. Confirm.
    # NOTE(review): this is a class attribute, not an instance attribute, so it
    # may not be emitted by to_dict()/save_pretrained(). Verify against config.json.
    auto_map = {
        "AutoConfig": "modeling.MyModelConfig",
        "AutoModel": "modeling.MyModel",
    }

    def __init__(
        self,
        num_layers: int = 2,
        input_dim: int = 2,
        hidden_dim: int = 128,
        output_dim: int = 2,
        **kwargs
    ):
        # Let the base class consume its own kwargs first, then layer the
        # model-specific fields on top of it.
        super().__init__(**kwargs)
        size_fields = {
            "num_layers": num_layers,
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
        }
        for field_name, field_value in size_fields.items():
            setattr(self, field_name, field_value)
class MyModel(transformers.PreTrainedModel):
    """``PreTrainedModel`` wrapper whose computation is entirely in ``MyBackbone``.

    The backbone is built from the size fields of a ``MyModelConfig``, and
    ``forward`` hands its input straight to it.
    """

    config_class = MyModelConfig

    def __init__(self, config: MyModelConfig):
        """Build the backbone from ``config``.

        Args:
            config: Supplies ``num_layers``, ``input_dim``, ``hidden_dim`` and
                ``output_dim``, each passed to ``MyBackbone`` as a keyword
                argument of the same name.
        """
        super().__init__(config)
        # NOTE(review): PreTrainedModel.__init__ presumably already sets
        # self.config. This line re-binds it to the exact object passed in.
        # It is kept unchanged; confirm that is intended.
        self.config = config
        backbone_kwargs = {
            field: getattr(config, field)
            for field in ("num_layers", "input_dim", "hidden_dim", "output_dim")
        }
        self.backbone = MyBackbone(**backbone_kwargs)

    def forward(self, inputs):
        """Run ``inputs`` through the backbone and return its result unchanged."""
        return self.backbone(inputs)