Erfan11 commited on
Commit
dc480fb
1 Parent(s): 4eaf7ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -15
app.py CHANGED
@@ -1,20 +1,37 @@
1
- from transformers import TFBertForSequenceClassification, BertTokenizer
2
- from flask import Flask, request, jsonify
 
 
3
 
4
- app = Flask(__name__)
 
5
 
6
- # Load model and tokenizer from Hugging Face Hub
7
- model_name = "Erfan11/Neuracraft"
8
- model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token="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")
9
- tokenizer = BertTokenizer.from_pretrained(model_name, use_auth_token="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")
 
 
 
 
10
 
11
- @app.route('/predict', methods=['POST'])
12
- def predict():
13
- data = request.get_json()
14
- inputs = tokenizer(data["text"], return_tensors="tf")
 
 
15
  outputs = model(**inputs)
16
- # Process your model's output as needed
17
- return jsonify(outputs)
 
 
 
 
 
 
 
 
18
 
19
- if __name__ == '__main__':
20
- app.run(debug=True)
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from transformers import TFBertForSequenceClassification, BertTokenizerFast
4
+ import tensorflow as tf
5
 
6
+ # Load environment variables
7
+ load_dotenv()
8
 
9
+ def load_model(model_name):
10
+ try:
11
+ # Try loading the model as a TensorFlow model
12
+ model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('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'))
13
+ except OSError:
14
+ # If loading fails, assume it's a PyTorch model and use from_pt=True
15
+ model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('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'), from_pt=True)
16
+ return model
17
 
18
+ def load_tokenizer(model_name):
19
+ tokenizer = BertTokenizerFast.from_pretrained(model_name, use_auth_token=os.getenv('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'))
20
+ return tokenizer
21
+
22
+ def predict(text, model, tokenizer):
23
+ inputs = tokenizer(text, return_tensors="tf")
24
  outputs = model(**inputs)
25
+ return outputs
26
+
27
+ def main():
28
+ model_name = os.getenv('Erfan11/Neuracraft')
29
+ model = load_model(model_name)
30
+ tokenizer = load_tokenizer(model_name)
31
+ # Example usage
32
+ text = "Sample input text"
33
+ result = predict(text, model, tokenizer)
34
+ print(result)
35
 
36
+ if __name__ == "__main__":
37
+ main()