peaceful-pirate commited on
Commit
59042ec
·
verified ·
1 Parent(s): 3e04ac1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import spu.utils.distributed as ppd
3
+
4
+ from time import time
5
+ from datasets import load_dataset
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoImageProcessor,
9
+ FlaxResNetForImageClassification,
10
+ )
11
+
12
+ parser = argparse.ArgumentParser(description='distributed driver.')
13
+ parser.add_argument("-c", "--config", default="3pc.json")
14
+ args = parser.parse_args()
15
+
16
+ with open(args.config, 'r') as file:
17
+ conf = json.load(file)
18
+
19
+ ppd.init(conf["nodes"], conf["devices"])
20
+
21
+ dataset = load_dataset("huggingface/cats-image")
22
+ image = dataset["test"]["image"][0]
23
+
24
+ processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
25
+ model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")
26
+
27
+ inputs = processor(image, return_tensors="jax")["pixel_values"]
28
+
29
+
30
+ def run_on_spu(inputs, model):
31
+ start = time()
32
+ inputs = ppd.device("P1")(lambda x: x)(inputs)
33
+ params = ppd.device("P2")(lambda x: x)(model.params)
34
+ outputs = ppd.device("SPU")(inference)(inputs, params)
35
+ outputs = ppd.get(outputs)
36
+ outputs = outputs['logits']
37
+ predicted_class_idx = jax.numpy.argmax(outputs, axis=-1)
38
+ print(f"Elapsed time:{time() - start}")
39
+ print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
40
+
41
+
42
+ def run_on_cpu(inputs, model):
43
+ start = time()
44
+ outputs = inference(inputs, model.params)
45
+ outputs = outputs['logits']
46
+ predicted_class_idx = jax.numpy.argmax(outputs, axis=-1)
47
+ print(f"Elapsed time:{time() - start}")
48
+ print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
49
+
50
+
51
+ if __name__ == "__main__":
52
+ print("Run on CPU\n------\n")
53
+ run_on_cpu(inputs, model)
54
+ print("Run on SPU\n------\n")
55
+ run_on_spu(inputs, model)