Erfan11 commited on
Commit
b692cf4
1 Parent(s): e495d23

Create load_model.py

Browse files
Files changed (1) hide show
  1. load_model.py +39 -0
load_model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from transformers import TFBertForSequenceClassification, BertTokenizerFast
4
+
5
+ # Load environment variables from .env file
6
+ load_dotenv()
7
+
8
+ def load_model(model_name):
9
+ try:
10
+ # Load TensorFlow model from Hugging Face
11
+ model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('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'), from_tf=True)
12
+ except OSError:
13
+ raise ValueError("Model loading failed.")
14
+ return model
15
+
16
+ def load_tokenizer(model_name):
17
+ tokenizer = BertTokenizerFast.from_pretrained(model_name, use_auth_token=os.getenv('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'))
18
+ return tokenizer
19
+
20
+ def predict(text, model, tokenizer):
21
+ inputs = tokenizer(text, return_tensors="tf")
22
+ outputs = model(**inputs)
23
+ return outputs
24
+
25
+ def main():
26
+ model_name = os.getenv('Erfan11/Neuracraft')
27
+ if model_name is None:
28
+ raise ValueError("Erfan11/Neuracraft environment variable not set or is None")
29
+
30
+ model = load_model(model_name)
31
+ tokenizer = load_tokenizer(model_name)
32
+
33
+ # Example prediction
34
+ text = "Sample input text"
35
+ result = predict(text, model, tokenizer)
36
+ print(result)
37
+
38
+ if __name__ == "__main__":
39
+ main()