File size: 1,857 Bytes
cb7d23f
bcdb617
cb7d23f
bcdb617
cb7d23f
 
dc5bfb6
cb7d23f
 
 
 
 
 
bcdb617
 
cb7d23f
bcdb617
 
cb7d23f
 
258b0b5
cb7d23f
d118df2
cb7d23f
 
 
 
 
 
 
 
dd98ef8
cb7d23f
 
dd98ef8
cb7d23f
 
 
 
 
bcdb617
cb7d23f
cdf47f6
bcdb617
 
 
e96ebfa
 
 
cb7d23f
258b0b5
 
cb7d23f
e96ebfa
 
cb7d23f
258b0b5
cb7d23f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Necessary imports
import os
import sys
from typing import Any, Tuple

import torch
from dotenv import load_dotenv
from transformers import AutoModel, AutoTokenizer, AutoProcessor

# Local imports
from src.logger import logging
from src.exception import CustomExceptionHandling


# Pull key/value pairs from a local .env file into os.environ so that the
# lookup below can see them.
load_dotenv()

# Token forwarded as `token=` to the from_pretrained calls in this module;
# evaluates to None when ACCESS_TOKEN is not set in the environment.
access_token = os.getenv("ACCESS_TOKEN")


def load_model_tokenizer_and_processor(
    model_name: str,
    device: str,
    *,
    torch_dtype: torch.dtype = torch.bfloat16,
    attn_implementation: str = "sdpa",
) -> Tuple[Any, Any, Any]:
    """
    Load the model, tokenizer and processor.

    All three are fetched with ``trust_remote_code=True`` and authenticated
    with the module-level ``access_token`` (read from the ``ACCESS_TOKEN``
    environment variable).

    Args:
        - model_name (str): The name of the model to load.
        - device (str): The device to load the model onto.
        - torch_dtype (torch.dtype, optional): Dtype requested for the model
          weights. Defaults to ``torch.bfloat16`` (the previous hard-coded
          value); pass another dtype (e.g. ``torch.float16``/``torch.float32``)
          when the target device lacks bfloat16 support.
        - attn_implementation (str, optional): Attention backend requested
          from ``from_pretrained``. Defaults to ``"sdpa"``.

    Returns:
        - tuple: ``(model, tokenizer, processor)``, where ``model`` has been
          switched to eval mode and moved to ``device``.

    Raises:
        - CustomExceptionHandling: Wraps any exception raised while loading
          the model, tokenizer or processor, or while moving the model to
          ``device``.
    """
    try:
        # NOTE: trust_remote_code=True executes Python code shipped with the
        # model repository -- only pass model names from trusted sources.
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
            token=access_token,
        )
        # Put the model in inference mode, then place it on the target device.
        model = model.eval().to(device=device)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, token=access_token
        )
        processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, token=access_token
        )

        # Log the successful loading of the model, tokenizer and processor
        logging.info("Model, tokenizer and processor loaded successfully.")

        # Return the model, tokenizer and processor
        return model, tokenizer, processor

    # Handle exceptions that may occur during model, tokenizer and processor loading
    except Exception as e:
        # Re-raise as the project's custom exception, chaining the original cause
        raise CustomExceptionHandling(e, sys) from e