import logging

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel

logger = logging.getLogger(__name__)


class CustomModel(PreTrainedModel):
    """Classification model: a Hugging Face encoder plus a linear head.

    The encoder is built dynamically from a ``transformers`` configuration
    via ``AutoModel.from_config``, so the same class works with any
    architecture whose config provides ``hidden_size`` and ``num_labels``.
    """

    # AutoConfig lets the concrete config class be resolved from the
    # checkpoint's config.json at load time rather than being hard-coded.
    config_class = AutoConfig

    def __init__(self, config):
        """Build the encoder and classification head from *config*.

        Args:
            config: A ``transformers`` configuration object; must expose
                ``hidden_size`` and ``num_labels``.
        """
        super().__init__(config)
        # Built from the configuration only -- weights are randomly
        # initialised here, not loaded from a checkpoint.
        self.encoder = AutoModel.from_config(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits.

        Args:
            input_ids: Token-id tensor, assumed shape (batch, seq_len).
            attention_mask: Optional mask of the same shape (1 = real
                token, 0 = padding), forwarded to the encoder.

        Returns:
            Logits tensor of shape (batch, num_labels).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the first token's hidden state (e.g. the [CLS]
        # token for BERT-style encoders).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(pooled_output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Instantiate the model from a checkpoint's *configuration*.

        NOTE(review): unlike the base-class implementation, this override
        loads only the configuration -- the encoder's weights remain
        randomly initialised unless a state_dict is loaded separately.
        Confirm this is intentional; otherwise delegate to
        ``super().from_pretrained``.

        Args:
            pretrained_model_name_or_path: Hub model id or local directory,
                forwarded (with *args/**kwargs) to the config loader.

        Returns:
            A ``CustomModel`` instance, or ``None`` when loading fails
            (errors are logged, not raised -- the original best-effort
            contract is preserved for existing callers).
        """
        try:
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )
            return cls(config)
        except Exception:
            # Broad catch and None return kept for backward compatibility,
            # but log the full traceback instead of printing to stdout.
            logger.exception(
                "Failed to load model from %s", pretrained_model_name_or_path
            )
            return None