Image Classification
unity-sentis
ONNX
File size: 3,409 Bytes
05abce4
c0392be
 
05abce4
 
a92657c
 
 
 
 
 
 
 
 
 
 
 
 
 
c0392be
 
05abce4
 
 
c0392be
 
 
 
 
 
 
 
05abce4
c0392be
 
 
 
05abce4
c0392be
 
 
05abce4
 
 
c0392be
 
 
 
 
 
 
05abce4
 
 
 
 
 
 
 
 
 
 
 
 
 
c0392be
 
05abce4
c0392be
 
 
 
 
 
 
 
05abce4
c0392be
 
 
 
 
05abce4
 
 
 
c0392be
 
05abce4
 
c0392be
 
 
05abce4
c0392be
 
 
 
 
 
05abce4
c0392be
05abce4
c0392be
 
 
 
05abce4
 
 
c0392be
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
using System.Collections.Generic;
using Unity.Sentis;
using UnityEngine;
using System.IO;
using FF = Unity.Sentis.Functional;
/*
 *  MobileNetV2 Inference Script
 *  ============================
 *  
 *  Place this script on the Main Camera
 *  
 *  Drag an image to the inputImage field
 *  
 *  When run, the predicted class of the image is printed to the Console window.
 *  You can modify the script to make it do something more interesting.
 * 
 */


public class RunMobileNet : MonoBehaviour
{
    //Drag the .sentis model asset here:
    public ModelAsset modelAsset;

    //Only referenced by the commented-out StreamingAssets loading path in Start()
    const string modelName = "mobilenet_v2.sentis";

    //The image to classify here:
    public Texture2D inputImage;

    //Link class_desc.txt here (one label per line; line index == predicted class ID):
    public TextAsset labelsAsset;

    //All images are resized to these values to go into the model
    const int imageHeight = 224;
    const int imageWidth = 224;

    const BackendType backend = BackendType.GPUCompute;

    private IWorker engine;
    private string[] labels;

    //Used to normalise the input RGB values as (rgb - shiftRGB) * mulRGB, using the standard
    //ImageNet per-channel means (shift) and reciprocal standard deviations (mul).
    //They are created in Start() rather than in field initialisers: Unity can construct
    //MonoBehaviour instances outside the normal lifecycle (e.g. in the Editor), and tensors
    //allocated there would never reach OnDestroy() and would leak.
    TensorFloat mulRGB;
    TensorFloat shiftRGB;

    void Start()
    {
        //Fail early with a clear message instead of a NullReferenceException when the
        //Inspector references are missing. (Unity's overloaded == handles destroyed objects.)
        if (modelAsset == null || labelsAsset == null)
        {
            Debug.LogError("RunMobileNet: assign both modelAsset and labelsAsset in the Inspector.");
            return;
        }

        //Parse neural net labels. Trimming '\r' keeps labels clean when the file was
        //saved with Windows (CRLF) line endings; line indices still match class IDs.
        labels = labelsAsset.text.Split('\n');
        for (int i = 0; i < labels.Length; i++)
        {
            labels[i] = labels[i].TrimEnd('\r');
        }

        //Normalisation constants; must exist before FF.Compile below calls NormaliseRGB.
        mulRGB = new TensorFloat(new TensorShape(1, 3, 1, 1), new float[] { 1 / 0.229f, 1 / 0.224f, 1 / 0.225f });
        shiftRGB = new TensorFloat(new TensorShape(1, 3, 1, 1), new float[] { 0.485f, 0.456f, 0.406f });

        //Load model from file or asset
        //var model = ModelLoader.Load(Path.Join(Application.streamingAssetsPath, modelName));
        var model = ModelLoader.Load(modelAsset);

        //Wrap the model so it normalises the input RGB values and returns the highest value
        //of output 0 along axis 1 ("output_0") and its index, i.e. the class ID ("output_1").
        //NOTE(review): this treats the model's first output as probabilities; if it emits raw
        //logits, apply FF.Softmax(probability, 1) before ReduceMax — TODO confirm.
        var model2 = FF.Compile(
            input =>
            {
                var probability = model.Forward(NormaliseRGB(input))[0];
                return (FF.ReduceMax(probability, 1), FF.ArgMax(probability, 1));
            },
            model.inputs[0]
        );

        //Setup the engine to run the model
        engine = WorkerFactory.CreateWorker(backend, model2);

        //Execute inference
        ExecuteML();
    }

    /// <summary>
    /// Runs the model on <see cref="inputImage"/> and logs the predicted label together
    /// with its score as a rounded percentage.
    /// </summary>
    public void ExecuteML()
    {
        //Guard against being called before Start() has set up the worker
        if (engine == null)
        {
            Debug.LogError("RunMobileNet: the inference engine has not been created; Start() must complete first.");
            return;
        }
        if (inputImage == null)
        {
            Debug.LogError("RunMobileNet: assign inputImage in the Inspector.");
            return;
        }

        //Preprocess image for input: resized to imageWidth x imageHeight with 3 (RGB) channels
        using var input = TextureConverter.ToTensor(inputImage, imageWidth, imageHeight, 3);

        //Execute neural net
        engine.Execute(input);

        //Read output tensors. PeekOutput returns tensors owned by the worker, so they are not
        //disposed here. Direct casts fail loudly if the output types ever change.
        var probability = (TensorFloat)engine.PeekOutput("output_0");
        var item = (TensorInt)engine.PeekOutput("output_1");

        //Wait for execution to finish and make the data readable on the CPU before indexing
        item.CompleteOperationsAndDownload();
        probability.CompleteOperationsAndDownload();

        //Best output class and its score
        var id = item[0];
        var confidence = probability[0];

        //Guard against a labels file with fewer entries than the model has classes
        string label = id >= 0 && id < labels.Length ? labels[id] : $"<unknown class {id}>";

        //The result is output to the console window (score rounded to the nearest percent)
        int percent = Mathf.FloorToInt(confidence * 100f + 0.5f);
        Debug.Log($"Prediction: {label} {percent}﹪");

        //Clean memory: unload assets that are no longer referenced
        Resources.UnloadUnusedAssets();
    }

    //This shifts and scales the RGB values for input into the model: (image - shift) * mul
    FunctionalTensor NormaliseRGB(FunctionalTensor image)
    {
        return (image - FunctionalTensor.FromTensor(shiftRGB)) * FunctionalTensor.FromTensor(mulRGB);
    }

    //Release the normalisation tensors and the worker. Null-conditional calls cover the case
    //where Start() returned early and these were never created.
    private void OnDestroy()
    {
        mulRGB?.Dispose();
        shiftRGB?.Dispose();
        engine?.Dispose();
    }
}