Spaces:
Runtime error
Runtime error
# Necessary imports | |
import sys | |
from typing import Any | |
import torch | |
from transformers import AutoModel, AutoTokenizer | |
# Local imports | |
from src.logger import logging | |
from src.exception import CustomExceptionHandling | |
def load_model_and_tokenizer(model_name: str, device: str) -> Any:
    """
    Load a pretrained model together with its tokenizer.

    The model is created with ``AutoModel.from_pretrained`` using
    ``trust_remote_code=True``, the ``"sdpa"`` attention implementation and
    ``torch.bfloat16`` weights, then moved to ``device``. The tokenizer is
    created with ``AutoTokenizer.from_pretrained`` (also with
    ``trust_remote_code=True``), and finally the model is switched to
    evaluation mode via ``model.eval()``.

    Args:
    - model_name (str): Identifier handed to both ``from_pretrained`` calls.
    - device (str): Target passed to ``model.to(device=...)``
      (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
    - A ``(model, tokenizer)`` tuple: the loaded model in eval mode and the
      matching tokenizer.

    Raises:
    - CustomExceptionHandling: Wraps any ``Exception`` raised while loading,
      moving the model, or logging; the original error is chained as the
      cause.
    """
    # Options forwarded to AutoModel.from_pretrained.
    # NOTE(review): trust_remote_code=True presumably lets code shipped with
    # the model repository run locally - only use with trusted model names.
    model_options = {
        "trust_remote_code": True,
        "attn_implementation": "sdpa",
        "torch_dtype": torch.bfloat16,
    }

    try:
        # Build the model and place it on the requested device in one step.
        model = AutoModel.from_pretrained(model_name, **model_options).to(
            device=device
        )

        # The tokenizer is loaded from the same model identifier.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )

        # Switch the model to evaluation mode.
        model.eval()

        # Record that both artifacts are ready.
        logging.info("Model and tokenizer loaded successfully.")
    except Exception as err:
        # Re-raise through the project's custom exception, keeping the cause.
        raise CustomExceptionHandling(err, sys) from err
    else:
        # Hand both objects back to the caller.
        return model, tokenizer