File size: 1,965 Bytes
fd4b932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
import time

os.sys.path += ['expman']
import expman

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import trange


def main(args):
    is_run_dir = expman.is_exp_dir(args.model)
    if is_run_dir:
        exp = expman.from_dir(args.model)
        for model_file in ('best_savedmodel', 'best_model.h5', 'last_model.h5'):
            model_path = exp.path_to(model_file)
            if os.path.exists(model_path):
                break
    elif tf.saved_model.contains_saved_model(args.model):
        model_path = args.model
    else:
        print('Cannot find suitable model snapshot in {}'.format(args.model))
        exit(1)

    model = tf.keras.models.load_model(model_path, compile=False, custom_objects={'tf': tf})
    data = np.empty((1, args.rh, args.rw, 1), dtype=np.float32)

    # warm-up
    model.predict(data)

    start = time.time()
    for _ in trange(args.n):
        model.predict(data)
    end = time.time()
    elapsed = end - start

    throughput = elapsed / args.n
    fps = args.n / elapsed
    print(f'Total: {elapsed:g}s ({throughput * 1000} ms/img, {fps} fps)')

    timings = pd.Series({'elapsed': elapsed, 'throughput': throughput, 'fps': fps})

    if is_run_dir and not args.output:
        timings_path = exp.path_to('timings.csv')
        timings.to_csv(timings_path)

    if args.output:
        timings.to_csv(args.output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict on test video')
    parser.add_argument('model', help='path to model or run dir')
    parser.add_argument('-n', type=int, default=100, help='number of predictions')
    parser.add_argument('-rh', type=int, default=128, help='RoI height (-1 for full height)')
    parser.add_argument('-rw', type=int, default=128, help='RoI width (-1 for full width)')
    parser.add_argument('-o', '--output', help='CSV output file')

    args = parser.parse_args()
    main(args)