BAAI
/

File size: 2,328 Bytes
41fdfca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dabb3aa
 
 
41fdfca
dabb3aa
 
 
 
41fdfca
dabb3aa
41fdfca
dabb3aa
41fdfca
dabb3aa
41fdfca
dabb3aa
41fdfca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# coding=utf-8
# Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logits Processor Helper class for Emu3. """

import torch

class Emu3PrefixConstrainedLogitsHelper:
    """Constrain generation to a well-formed grid of Emu3 visual tokens.

    Each call receives a batch index and the token sequence produced so far,
    and returns the token ids allowed at the next position. The
    ``(batch_id, input_ids)`` signature matches the ``prefix_allowed_tokens_fn``
    hook of Hugging Face ``generate`` (presumably how this helper is used).

    Counting from the first ``img_token`` in the sequence, the enforced layout is::

        img_token,
        (width visual tokens, eol_token) * height,
        eof_token, eoi_token, eos_token, pad_token, pad_token, ...

    Args:
        height: Number of token rows. Accepted forms:
            - a scalar;
            - a tensor/sequence with a single entry, shared by every batch index;
            - a tensor/sequence with one entry per batch index.
        width: Number of visual tokens per row. Same conventions as ``height``.
        img_token: Token id whose first occurrence marks the start of the image body.
        eoi_token: Token id forced two positions after the grid.
        eos_token: Token id forced three positions after the grid.
        eol_token: Token id forced after every row of ``width`` visual tokens.
        eof_token: Token id forced right after the grid.
        pad_token: Token id forced for every position after ``eos_token``.
        visual_tokens: Token ids allowed inside a row; returned as-is.
    """

    def __init__(
        self,
        height,
        width,
        img_token,
        eoi_token,
        eos_token,
        eol_token,
        eof_token,
        pad_token,
        visual_tokens,
    ):
        self.height = height
        self.width = width
        self.img_token = img_token
        self.eoi_token = eoi_token
        self.eos_token = eos_token
        self.eol_token = eol_token
        self.eof_token = eof_token
        self.pad_token = pad_token
        self.visual_tokens = visual_tokens

        # batch_id -> index of the first img_token in that sample.
        # The cache is never cleared, so reuse of an instance across unrelated
        # prompts keeps the old positions.
        # NOTE(review): create a fresh helper per generation call.
        self.offset_cache = {}

    @staticmethod
    def _per_sample(values, batch_id):
        """Return the entry of ``values`` for ``batch_id`` as a Python int.

        Tensors and sequences are flattened first. A single entry is shared
        by every batch index.
        """
        flat = torch.as_tensor(values).reshape(-1)
        return int(flat[batch_id] if flat.shape[0] > 1 else flat[0])

    def __call__(self, batch_id, input_ids):
        """Return the token ids allowed at the next position of ``input_ids``.

        Args:
            batch_id: Index of the sample in the batch. It selects the
                per-sample ``height``/``width`` and keys ``offset_cache``.
            input_ids: Token ids produced so far for this sample (assumed 1-D).

        Returns:
            ``visual_tokens`` when a visual token is expected. Otherwise a
            one-element tuple holding the forced structural token.

        Raises:
            IndexError: If ``img_token`` is absent from ``input_ids`` on the
                first call for ``batch_id``.
        """
        if batch_id not in self.offset_cache:
            # Locate the image start once per sample and reuse it at later steps.
            position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
            self.offset_cache[batch_id] = int(position)

        height = self._per_sample(self.height, batch_id)
        width = self._per_sample(self.width, batch_id)

        # 1-based position of the next token, counted from img_token.
        offset = input_ids.shape[0] - self.offset_cache[batch_id]
        row_len = width + 1  # one row of visual tokens plus its eol_token
        grid_len = row_len * height

        if offset <= grid_len:
            # Inside the grid, every row_len-th token closes a row. The eol rule
            # must not apply past the grid, where it would clash with the tail.
            if offset % row_len == 0:
                return (self.eol_token, )
            return self.visual_tokens
        if offset == grid_len + 1:
            return (self.eof_token, )
        if offset == grid_len + 2:
            return (self.eoi_token, )
        if offset == grid_len + 3:
            return (self.eos_token, )
        return (self.pad_token, )