huz-relay committed
Commit f0a2a0a · 1 Parent(s): 02806d2
Files changed (1):
  1. handler.py +52 -5
handler.py CHANGED
@@ -1,7 +1,9 @@
 from typing import Any, Dict, List
-from transformers import Idefics2Processor, Idefics2Model
+from transformers import Idefics2Processor, Idefics2ForConditionalGeneration
 import torch
 import logging
+from PIL import Image
+import requests
 
 
 class EndpointHandler:
@@ -11,7 +13,7 @@ class EndpointHandler:
         self.logger.addHandler(logging.StreamHandler())
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.processor = Idefics2Processor.from_pretrained(path)
-        self.model = Idefics2Model.from_pretrained(path)
+        self.model = Idefics2ForConditionalGeneration.from_pretrained(path)
         self.model.to(self.device)
         self.logger.info("Initialisation finished!")
 
@@ -23,20 +25,65 @@ class EndpointHandler:
         Return:
             A :obj:`list` | `dict`: will be serialized and returned
         """
-        image = data.pop("inputs", data)
+        """image = data.pop("inputs", data)
         self.logger.info("image")
 
         # process image
         inputs = self.processor(images=image, return_tensors="pt").to(self.device)
         self.logger.info("inputs")
-        generated_ids = self.model.forward(input_ids=inputs)
+        self.logger.info(f"{inputs.input_ids}")
+        generated_ids = self.model.generate(**inputs)
         self.logger.info("generated")
 
         # run prediction
         generated_text = self.processor.batch_decode(
             generated_ids, skip_special_tokens=True
         )
-        self.logger.info("decoded")
+        self.logger.info("decoded")"""
+
+        url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
+
+        image_1 = Image.open(requests.get(url_1, stream=True).raw)
+        image_2 = Image.open(requests.get(url_2, stream=True).raw)
+        images = [image_1, image_2]
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "What’s the difference between these two images?",
+                    },
+                    {"type": "image"},
+                    {"type": "image"},
+                ],
+            }
+        ]
+
+        processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b")
+        model = Idefics2ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics2-8b"
+        )
+        model.to(self.device)
+
+        # at inference time, one needs to pass `add_generation_prompt=True` in order to make sure the model completes the prompt
+        text = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.logger.info(text)
+        # 'User: What’s the difference between these two images?<image><image><end_of_utterance>\nAssistant:'
+
+        inputs = processor(images=images, text=text, return_tensors="pt").to(
+            self.device
+        )
+        self.logger.info("inputs")
+
+        generated_text = model.generate(**inputs, max_new_tokens=500)
+        self.logger.info("generated")
+        generated_text = processor.batch_decode(
+            generated_text, skip_special_tokens=True
+        )[0]
+        self.logger.info(f"Generated text: {generated_text}")
 
         # decode output
         return generated_text
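
For context on the substantive fix: Idefics2Model is the bare backbone without a language-modeling head, so it has no usable text-generation path and the old model.forward(input_ids=inputs) call could not work with the processor's output; Idefics2ForConditionalGeneration adds the head and supports model.generate(**inputs). A minimal local smoke test for the updated handler might look like the sketch below. The checkpoint name and the EndpointHandler(path) / handler(data) signatures are assumptions based on the usual Inference Endpoints custom-handler convention, and the payload content is arbitrary because this revision comments out the payload-driven path and always runs the hard-coded two-image demo.

    # Hypothetical smoke test, not part of this commit. Assumes the standard
    # custom-handler convention: EndpointHandler(path) and handler(data).
    from handler import EndpointHandler

    # "HuggingFaceM4/idefics2-8b" is an assumed checkpoint path for illustration.
    handler = EndpointHandler(path="HuggingFaceM4/idefics2-8b")

    # This revision ignores the request body and runs the hard-coded
    # two-image COCO demo, so any dict-shaped payload works here.
    result = handler({"inputs": "ignored"})
    print(result)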