import copy
from typing import Optional

from PIL import Image

from .single_image_convsation import SingleImageConvDatasetMixin


class SingleImageInteractive(SingleImageConvDatasetMixin):
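    """Interactive, single-image conversation builder.

    Accumulates one image plus human/gpt turns (with optional boxes and points),
    then converts the running conversation into model inputs or Gradio messages.
    """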
    _printed_sample = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image: Optional[Image.Image] = None
        self.roles = ('human', 'gpt')
        self.boxes = []
        self.points = []
        self.raw_conv = []
        self.conversations = []

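    # Attach the image for this session; it may only be set once.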
    def set_image(self, image: Image.Image):
        assert self.image is None, f"image is already set, got another: {image}"
        self.image = image

    def append_message(self, role: str, message: str, *, boxes=None, points=None, boxes_seq=None, points_seq=None):
        """Append a new message, optionally referencing boxes/points via boxes_seq/points_seq."""
        assert role in self.roles

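        # Remap each index in boxes_seq/points_seq from its position in the per-call
        # `boxes`/`points` arguments to an index into the accumulated self.boxes/self.points.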
        def convert_idx(objs_seq, objs_value, get_obj_idx_func):
            if objs_seq is None:
                return None
            ret = []
            for objs_idx in objs_seq:
                new_objs_idx = []
                for idx in objs_idx:
                    new_idx = get_obj_idx_func(objs_value[idx])
                    new_objs_idx.append(new_idx)
                ret.append(tuple(new_objs_idx))
            return tuple(ret)

        boxes_seq = convert_idx(boxes_seq, boxes, self._get_box_idx)
        points_seq = convert_idx(points_seq, points, self._get_point_idx)
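        # Keep exactly one '<image>' placeholder across the whole conversation:
        # add it to the first message if missing, strip it from later messages.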
        if self.image is not None:
            previous_message_has_image_placeholder = any(
                '<image>' in item['value'] for item in self.conversations
            )
            if not previous_message_has_image_placeholder and '<image>' not in message:
                message = '<image> ' + message
            if previous_message_has_image_placeholder and '<image>' in message:
                message = message.replace('<image>', '')
        self.conversations.append(
            {
                'from': role,
                'value': message,
                'boxes_seq': copy.deepcopy(boxes_seq),
                'points_seq': copy.deepcopy(points_seq),
            }
        )

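    # Build the raw sample dict expected by the mixin; if the last turn is from the
    # human, append an empty gpt turn so the model has a slot to generate into.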
    def get_raw_item(self, index=None):
        ret = copy.deepcopy({
            'image': self.image,
            'target': {
                'boxes': self.boxes,
                'points': self.points,
            },
            'conversations': self.conversations,
        })
        assert ret['conversations'][0]['from'] == self.roles[0]
        if ret['conversations'][-1]['from'] == self.roles[0]:
            ret['conversations'].append(
                {
                    'from': self.roles[1],
                    'value': '',
                }
            )
        return ret

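    # Convert the processed sample into a batch of size 1 on the GPU.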
    def to_model_input(self):
        item = self.__getitem__(0)
        ret = {'input_ids': item['input_ids'].unsqueeze(0).cuda()}
        if 'image' in item and item['image'] is not None:
            ret['images'] = item['image'].unsqueeze(0).cuda()
        else:
            ret['images'] = None
        return ret

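    # Return the latest (human, gpt) message pair with image tokens collapsed back
    # into a single '<image>' placeholder for display in the Gradio chatbot.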
    def to_gradio_chatbot_new_messages(self):
        conv = self.__getitem__(0, return_conv=True)
        new_messages = conv.messages[-2:]
        ret_messages = []
        for r, m in new_messages:
            nm = m.replace('<im_patch>', '').replace('<im_end>', '').replace('<im_start>', '<image>')
            ret_messages.append((r, nm))
        return ret_messages

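    # Deduplicate boxes: return the index of an existing identical box, or append it.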
    def _get_box_idx(self, box):
        assert isinstance(box, (tuple, list)), f"{type(box)}"
        assert isinstance(box[0], (int, float)), f"{type(box[0])}"
        assert len(box) == 4
        box = tuple(box)
        if box not in self.boxes:
            self.boxes.append(box)
            return len(self.boxes) - 1
        else:
            return self.boxes.index(box)

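    # Same as _get_box_idx, but for 2-element (x, y) points.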
    def _get_point_idx(self, point):
        assert isinstance(point, (tuple, list))
        assert isinstance(point[0], (int, float))
        assert len(point) == 2
        point = tuple(point)
        if point not in self.points:
            self.points.append(tuple(point))
            return len(self.points) - 1
        else:
            return self.points.index(point)

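    # The interactive "dataset" always exposes exactly one sample: the current session.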
    def __len__(self):
        return 1