sushmanth commited on
Commit
2b89217
·
1 Parent(s): 814b60e

Upload functions.py

Browse files
Files changed (1) hide show
  1. functions.py +47 -0
functions.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import onnxruntime as rt
4
+
5
+ model_path = 'model/model.onnx'
6
+ class_path = 'model/birds_name_mapping.json'
7
+
8
+ normalise_means = [0.4914, 0.4822, 0.4465]
9
+ normalise_stds = [0.2023, 0.1994, 0.2010]
10
+
11
+ def normalise_image(image):
12
+ image = image.copy()
13
+ for i in range(3):
14
+ image[:, i, :, :] = (image[:, i, :, :] - normalise_means[i]) / normalise_stds[i]
15
+ return image
16
+
17
+ def load_class_names():
18
+ with open(class_path, 'r') as f:
19
+ class_names = json.load(f)
20
+ return class_names
21
+
22
+ def predict(inp_image):
23
+
24
+ class_names = load_class_names()
25
+
26
+ image = inp_image
27
+ image = image.transpose((2, 0, 1))
28
+
29
+ image = image / 255.0
30
+ image = np.expand_dims(image, axis=0)
31
+ image = normalise_image(image)
32
+ image = image.astype(np.float32)
33
+
34
+ sess = rt.InferenceSession(model_path)
35
+
36
+ input_name = sess.get_inputs()[0].name
37
+ output_name = sess.get_outputs()[0].name
38
+
39
+ output = sess.run([output_name], {input_name: image})[0]
40
+ prob = np.exp(output) / np.sum(np.exp(output), axis=1, keepdims=True)
41
+
42
+ top5 = np.argsort(prob[0])[-5:][::-1]
43
+
44
+ class_probs = {class_names[str(i)]: float(prob[0][i]) for i in top5}
45
+ print(class_probs)
46
+
47
+ return class_probs