赛萌 commited on
Commit
ba4096b
·
1 Parent(s): 28f35d9
Files changed (4) hide show
  1. modeling_qwen.py +10 -153
  2. qwen.tiktoken +0 -0
  3. tokenization_qwen.py +432 -0
  4. visual.py +70 -19
modeling_qwen.py CHANGED
@@ -69,44 +69,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
69
 
70
  apply_rotary_emb_func = None
71
  rms_norm = None
72
- flash_attn_unpadded_func = None
73
-
74
-
75
- def _import_flash_attn():
76
- global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
77
- try:
78
- from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
79
- apply_rotary_emb_func = __apply_rotary_emb_func
80
- except ImportError:
81
- logger.warn(
82
- "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
83
- "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
84
- )
85
 
86
- try:
87
- from flash_attn.ops.rms_norm import rms_norm as __rms_norm
88
- rms_norm = __rms_norm
89
- except ImportError:
90
- logger.warn(
91
- "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
92
- "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
93
- )
94
-
95
- try:
96
- import flash_attn
97
- if not hasattr(flash_attn, '__version__'):
98
- from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
99
- else:
100
- if int(flash_attn.__version__.split(".")[0]) >= 2:
101
- from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
102
- else:
103
- from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
104
- flash_attn_unpadded_func = __flash_attn_unpadded_func
105
- except ImportError:
106
- logger.warn(
107
- "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
108
- "https://github.com/Dao-AILab/flash-attention"
109
- )
110
 
111
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
112
  def _make_causal_mask(
@@ -141,70 +104,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
141
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
142
 
143
 
144
- class FlashSelfAttention(torch.nn.Module):
145
- def __init__(
146
- self,
147
- causal=False,
148
- softmax_scale=None,
149
- attention_dropout=0.0,
150
- ):
151
- super().__init__()
152
- assert flash_attn_unpadded_func is not None, (
153
- "Please install FlashAttention first, " "e.g., with pip install flash-attn"
154
- )
155
- assert (
156
- rearrange is not None
157
- ), "Please install einops first, e.g., with pip install einops"
158
- self.causal = causal
159
- self.softmax_scale = softmax_scale
160
- self.dropout_p = attention_dropout
161
-
162
- def forward(self, q, k, v):
163
- assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
164
- assert all((i.is_cuda for i in (q, k, v)))
165
- batch_size, seqlen_q = q.shape[0], q.shape[1]
166
- seqlen_k = k.shape[1]
167
- q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
168
- cu_seqlens_q = torch.arange(
169
- 0,
170
- (batch_size + 1) * seqlen_q,
171
- step=seqlen_q,
172
- dtype=torch.int32,
173
- device=q.device,
174
- )
175
-
176
- if self.training:
177
- assert seqlen_k == seqlen_q
178
-
179
- is_causal = self.causal
180
- cu_seqlens_k = cu_seqlens_q
181
- else:
182
- is_causal = seqlen_q == seqlen_k
183
- cu_seqlens_k = torch.arange(
184
- 0,
185
- (batch_size + 1) * seqlen_k,
186
- step=seqlen_k,
187
- dtype=torch.int32,
188
- device=q.device,
189
- )
190
- self.dropout_p = 0
191
- output = flash_attn_unpadded_func(
192
- q,
193
- k,
194
- v,
195
- cu_seqlens_q,
196
- cu_seqlens_k,
197
- seqlen_q,
198
- seqlen_k,
199
- self.dropout_p,
200
- softmax_scale=self.softmax_scale,
201
- causal=is_causal,
202
- )
203
-
204
- output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
205
- return output
206
-
207
-
208
  class QWenAttention(nn.Module):
209
  def __init__(self, config):
210
  super().__init__()
@@ -225,7 +124,6 @@ class QWenAttention(nn.Module):
225
  self.num_heads = config.num_attention_heads
226
  self.head_dim = self.hidden_size // self.num_heads
227
 
228
- self.use_flash_attn = config.use_flash_attn
229
  self.scale_attn_weights = True
230
 
231
  self.projection_size = config.kv_channels * config.num_attention_heads
@@ -242,15 +140,6 @@ class QWenAttention(nn.Module):
242
  )
243
 
244
  self.is_fp32 = not (config.bf16 or config.fp16)
245
- if (
246
- self.use_flash_attn
247
- and flash_attn_unpadded_func is not None
248
- and not self.is_fp32
249
- ):
250
- self.core_attention_flash = FlashSelfAttention(
251
- causal=True, attention_dropout=config.attn_dropout_prob
252
- )
253
-
254
  self.bf16 = config.bf16
255
 
256
  if config.rotary_pct == 1.0:
@@ -453,40 +342,20 @@ class QWenAttention(nn.Module):
453
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
454
  query = query * logn_tensor.expand_as(query)
455
 
456
- if (
457
- self.use_flash_attn
458
- and flash_attn_unpadded_func is not None
459
- and not self.is_fp32
460
- and query.is_cuda
461
- ):
462
- q, k, v = query, key, value
463
- context_layer = self.core_attention_flash(q, k, v)
464
-
465
- context_layer = rearrange(
466
- context_layer, "b s h d -> b s (h d)"
467
- ).contiguous()
468
- else:
469
- query = query.permute(0, 2, 1, 3)
470
- key = key.permute(0, 2, 1, 3)
471
- value = value.permute(0, 2, 1, 3)
472
- attn_output, attn_weight = self._attn(
473
- query, key, value, attention_mask, head_mask
474
- )
475
- context_layer = self._merge_heads(
476
- attn_output, self.num_heads, self.head_dim
477
- )
478
 
479
  attn_output = self.c_proj(context_layer)
480
  outputs = (attn_output, present)
481
  if output_attentions:
482
- if (
483
- self.use_flash_attn
484
- and flash_attn_unpadded_func is not None
485
- and not self.is_fp32
486
- ):
487
- raise ValueError("Cannot output attentions while using flash-attn")
488
- else:
489
- outputs += (attn_weight,)
490
 
491
  return outputs
492
 
@@ -882,18 +751,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
882
  logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
883
  elif SUPPORT_FP16:
884
  logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
885
-
886
- if config.use_flash_attn == "auto":
887
- if config.bf16 or config.fp16:
888
- logger.warn("Try importing flash-attention for faster inference...")
889
- config.use_flash_attn = True
890
- else:
891
- config.use_flash_attn = False
892
- if config.use_flash_attn and config.fp32:
893
- logger.warn("Flash attention will be disabled because it does NOT support fp32.")
894
-
895
- if config.use_flash_attn:
896
- _import_flash_attn()
897
 
898
  self.transformer = QWenModel(config)
899
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
69
 
70
  apply_rotary_emb_func = None
71
  rms_norm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
75
  def _make_causal_mask(
 
104
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  class QWenAttention(nn.Module):
108
  def __init__(self, config):
109
  super().__init__()
 
124
  self.num_heads = config.num_attention_heads
125
  self.head_dim = self.hidden_size // self.num_heads
126
 
 
127
  self.scale_attn_weights = True
128
 
129
  self.projection_size = config.kv_channels * config.num_attention_heads
 
140
  )
141
 
142
  self.is_fp32 = not (config.bf16 or config.fp16)
 
 
 
 
 
 
 
 
 
143
  self.bf16 = config.bf16
144
 
145
  if config.rotary_pct == 1.0:
 
342
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
343
  query = query * logn_tensor.expand_as(query)
344
 
345
+ query = query.permute(0, 2, 1, 3)
346
+ key = key.permute(0, 2, 1, 3)
347
+ value = value.permute(0, 2, 1, 3)
348
+ attn_output, attn_weight = self._attn(
349
+ query, key, value, attention_mask, head_mask
350
+ )
351
+ context_layer = self._merge_heads(
352
+ attn_output, self.num_heads, self.head_dim
353
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
  attn_output = self.c_proj(context_layer)
356
  outputs = (attn_output, present)
357
  if output_attentions:
358
+ outputs += (attn_weight,)
 
 
 
 
 
 
 
359
 
360
  return outputs
361
 
 
751
  logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
752
  elif SUPPORT_FP16:
753
  logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
 
 
 
 
 
 
 
 
 
 
 
 
754
 
755
  self.transformer = QWenModel(config)
756
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
tokenization_qwen.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import requests
12
+ import unicodedata
13
+ from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
14
+
15
+ import tiktoken
16
+ import numpy as np
17
+ from PIL import Image
18
+ from PIL import ImageFont
19
+ from PIL import ImageDraw
20
+ from transformers import PreTrainedTokenizer, AddedToken
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
26
+
27
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
28
+ ENDOFTEXT = "<|endoftext|>"
29
+ IMSTART = "<|im_start|>"
30
+ IMEND = "<|im_end|>"
31
+ # as the default behavior is changed to allow special tokens in
32
+ # regular texts, the surface forms of special tokens need to be
33
+ # as different as possible to minimize the impact
34
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
35
+ SPECIAL_TOKENS = (
36
+ ENDOFTEXT,
37
+ IMSTART,
38
+ IMEND,
39
+ ) + EXTRAS
40
+ IMG_TOKEN_SPAN = 256
41
+
42
+
43
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
44
+ with open(tiktoken_bpe_file, "rb") as f:
45
+ contents = f.read()
46
+ return {
47
+ base64.b64decode(token): int(rank)
48
+ for token, rank in (line.split() for line in contents.splitlines() if line)
49
+ }
50
+
51
+ def _list_find(
52
+ input_list: List[Any],
53
+ candidates: Tuple[Any],
54
+ start: int = 0,
55
+ ):
56
+ for i in range(start, len(input_list)):
57
+ if input_list[i] in candidates:
58
+ return i
59
+ return -1
60
+
61
+ def _replace_closed_tag(
62
+ input_tokens: List[Any],
63
+ start_tags: Union[Any, Tuple[Any]],
64
+ end_tags: Union[Any, Tuple[Any]],
65
+ inclusive_replace_func: Callable,
66
+ exclusive_replace_func: Callable = lambda x: x,
67
+ ):
68
+ if isinstance(start_tags, (str, int)):
69
+ start_tags = (start_tags,)
70
+ if isinstance(end_tags, (str, int)):
71
+ end_tags = (end_tags,)
72
+ assert len(start_tags) == len(end_tags)
73
+
74
+ output_tokens = []
75
+ end = 0
76
+ while True:
77
+ start = _list_find(input_tokens, start_tags, end)
78
+ if start == -1:
79
+ break
80
+ output_tokens.extend(exclusive_replace_func(input_tokens[end : start]))
81
+ tag_idx = start_tags.index(input_tokens[start])
82
+ end = _list_find(input_tokens, (end_tags[tag_idx],), start)
83
+ if end == -1:
84
+ raise ValueError("Unclosed image token")
85
+ output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1]))
86
+ end += 1
87
+ output_tokens.extend(exclusive_replace_func(input_tokens[end : ]))
88
+ return output_tokens
89
+
90
+ class QWenTokenizer(PreTrainedTokenizer):
91
+ """QWen tokenizer."""
92
+
93
+ vocab_files_names = VOCAB_FILES_NAMES
94
+
95
+ def __init__(
96
+ self,
97
+ vocab_file,
98
+ errors="replace",
99
+ image_start_tag='<img>',
100
+ image_end_tag='</img>',
101
+ image_pad_tag='<imgpad>',
102
+ ref_start_tag='<ref>',
103
+ ref_end_tag='</ref>',
104
+ box_start_tag='<box>',
105
+ box_end_tag='</box>',
106
+ quad_start_tag='<quad>',
107
+ quad_end_tag='</quad>',
108
+ **kwargs,
109
+ ):
110
+ super().__init__(**kwargs)
111
+ self.image_start_tag = image_start_tag
112
+ self.image_end_tag = image_end_tag
113
+ self.image_pad_tag = image_pad_tag
114
+ self.ref_start_tag = ref_start_tag
115
+ self.ref_end_tag = ref_end_tag
116
+ self.box_start_tag = box_start_tag
117
+ self.box_end_tag = box_end_tag
118
+ self.quad_start_tag = quad_start_tag
119
+ self.quad_end_tag = quad_end_tag
120
+ self.IMAGE_ST = (
121
+ ref_start_tag, ref_end_tag,
122
+ box_start_tag, box_end_tag,
123
+ quad_start_tag, quad_end_tag,
124
+ image_start_tag, image_end_tag,
125
+ image_pad_tag
126
+ )
127
+
128
+ self.errors = errors # how to handle errors in decoding
129
+
130
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
131
+ self.special_tokens = {
132
+ token: index
133
+ for index, token in enumerate(
134
+ SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
135
+ )
136
+ }
137
+ self.img_start_id = self.special_tokens[self.image_start_tag]
138
+ self.img_end_id = self.special_tokens[self.image_end_tag]
139
+ self.img_pad_id = self.special_tokens[self.image_pad_tag]
140
+ self.ref_start_id = self.special_tokens[self.ref_start_tag]
141
+ self.ref_end_id = self.special_tokens[self.ref_end_tag]
142
+ self.box_start_id = self.special_tokens[self.box_start_tag]
143
+ self.box_end_id = self.special_tokens[self.box_end_tag]
144
+ self.quad_start_id = self.special_tokens[self.quad_start_tag]
145
+ self.quad_end_id = self.special_tokens[self.quad_end_tag]
146
+
147
+ enc = tiktoken.Encoding(
148
+ "Qwen",
149
+ pat_str=PAT_STR,
150
+ mergeable_ranks=self.mergeable_ranks,
151
+ special_tokens=self.special_tokens,
152
+ )
153
+ assert (
154
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
155
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
156
+
157
+ self.decoder = {
158
+ v: k for k, v in self.mergeable_ranks.items()
159
+ } # type: dict[int, bytes|str]
160
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
161
+
162
+ self.tokenizer = enc # type: tiktoken.Encoding
163
+
164
+ self.eod_id = self.tokenizer.eot_token
165
+ self.im_start_id = self.special_tokens[IMSTART]
166
+ self.im_end_id = self.special_tokens[IMEND]
167
+
168
+ def __len__(self) -> int:
169
+ return self.tokenizer.n_vocab
170
+
171
+ def get_vocab(self) -> Dict[bytes, int]:
172
+ return self.mergeable_ranks
173
+
174
+ def convert_tokens_to_ids(
175
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
176
+ ) -> List[int]:
177
+ ids = []
178
+ if isinstance(tokens, (str, bytes)):
179
+ if tokens in self.special_tokens:
180
+ return self.special_tokens[tokens]
181
+ else:
182
+ return self.mergeable_ranks.get(tokens)
183
+ for token in tokens:
184
+ if token in self.special_tokens:
185
+ ids.append(self.special_tokens[token])
186
+ else:
187
+ ids.append(self.mergeable_ranks.get(token))
188
+ return ids
189
+
190
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
191
+ if not special_tokens and new_tokens:
192
+ raise ValueError('Adding regular tokens is not supported')
193
+ for token in new_tokens:
194
+ surface_form = token.content if isinstance(token, AddedToken) else token
195
+ if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST:
196
+ raise ValueError('Adding unknown special tokens is not supported')
197
+ return 0
198
+
199
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
200
+ """
201
+ Save only the vocabulary of the tokenizer (vocabulary).
202
+
203
+ Returns:
204
+ `Tuple(str)`: Paths to the files saved.
205
+ """
206
+ file_path = os.path.join(save_directory, "qwen.tiktoken")
207
+ with open(file_path, "w", encoding="utf8") as w:
208
+ for k, v in self.mergeable_ranks.items():
209
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
210
+ w.write(line)
211
+ return (file_path,)
212
+
213
+ def tokenize(
214
+ self,
215
+ text: str,
216
+ allowed_special: Union[Set, str] = "all",
217
+ disallowed_special: Union[Collection, str] = (),
218
+ **kwargs,
219
+ ) -> List[Union[bytes, str]]:
220
+ """
221
+ Converts a string in a sequence of tokens.
222
+
223
+ Args:
224
+ text (`str`):
225
+ The sequence to be encoded.
226
+ allowed_special (`Literal["all"]` or `set`):
227
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
228
+ Default to "all".
229
+ disallowed_special (`Literal["all"]` or `Collection`):
230
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
231
+ Default to an empty tuple.
232
+
233
+ kwargs (additional keyword arguments, *optional*):
234
+ Will be passed to the underlying model specific encode method.
235
+
236
+ Returns:
237
+ `List[bytes|str]`: The list of tokens.
238
+ """
239
+ tokens = []
240
+ text = unicodedata.normalize("NFC", text)
241
+
242
+ # this implementation takes a detour: text -> token id -> token surface forms
243
+ for t in self.tokenizer.encode(
244
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
245
+ ):
246
+ tokens.append(self.decoder[t])
247
+
248
+ def _encode_imgurl(img_tokens):
249
+ assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
250
+ img_tokens = img_tokens[1:-1]
251
+ img_url = b''.join(img_tokens)
252
+ out_img_tokens = list(map(self.decoder.get, img_url))
253
+ if len(out_img_tokens) > IMG_TOKEN_SPAN:
254
+ raise ValueError("The content in {}..{} is too long".format(
255
+ self.image_start_tag, self.image_end_tag))
256
+ out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
257
+ out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
258
+ return out_img_tokens
259
+
260
+ return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
261
+
262
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
263
+ """
264
+ Converts a sequence of tokens in a single string.
265
+ """
266
+ text = ""
267
+ temp = b""
268
+ for t in tokens:
269
+ if isinstance(t, str):
270
+ if temp:
271
+ text += temp.decode("utf-8", errors=self.errors)
272
+ temp = b""
273
+ text += t
274
+ elif isinstance(t, bytes):
275
+ temp += t
276
+ else:
277
+ raise TypeError("token should only be of type types or str")
278
+ if temp:
279
+ text += temp.decode("utf-8", errors=self.errors)
280
+ return text
281
+
282
+ @property
283
+ def vocab_size(self):
284
+ return self.tokenizer.n_vocab
285
+
286
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
287
+ """Converts an id to a token, special tokens included"""
288
+ if index in self.decoder:
289
+ return self.decoder[index]
290
+ raise ValueError("unknown ids")
291
+
292
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
293
+ """Converts a token to an id using the vocab, special tokens included"""
294
+ if token in self.special_tokens:
295
+ return self.special_tokens[token]
296
+ if token in self.mergeable_ranks:
297
+ return self.mergeable_ranks[token]
298
+ raise ValueError("unknown token")
299
+
300
+ def _tokenize(self, text: str, **kwargs):
301
+ """
302
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
303
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
304
+
305
+ Do NOT take care of added tokens.
306
+ """
307
+ raise NotImplementedError
308
+
309
+ def _decode(
310
+ self,
311
+ token_ids: Union[int, List[int]],
312
+ skip_special_tokens: bool = False,
313
+ errors: str = None,
314
+ **kwargs,
315
+ ) -> str:
316
+ if isinstance(token_ids, int):
317
+ token_ids = [token_ids]
318
+
319
+ def _decode_imgurl(img_token_ids):
320
+ assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
321
+ img_token_ids = img_token_ids[1:-1]
322
+ img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)]
323
+ img_url = bytes(img_token_ids).decode('utf-8')
324
+ return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]
325
+
326
+ token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
327
+
328
+ if skip_special_tokens:
329
+ token_ids = [i for i in token_ids if i < self.eod_id]
330
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
331
+
332
+ def to_list_format(self, text: str):
333
+ text = unicodedata.normalize("NFC", text)
334
+ token_ids = self.tokenizer.encode(
335
+ text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))
336
+
337
+ def _encode_vl_info(tokens):
338
+ if len(tokens) == 0:
339
+ return []
340
+ if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
341
+ key = 'image'
342
+ elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
343
+ key = 'ref'
344
+ elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
345
+ key = 'box'
346
+ elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
347
+ key = 'quad'
348
+ else:
349
+ _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
350
+ return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
351
+ val = b''.join(map(self.decoder.get, tokens[1:-1])).decode('utf-8')
352
+ return [{key: val}]
353
+
354
+ return _replace_closed_tag(
355
+ token_ids,
356
+ (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
357
+ (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
358
+ _encode_vl_info,
359
+ _encode_vl_info,
360
+ )
361
+
362
+ def from_list_format(self, list_format: List[Dict]):
363
+ text = ''
364
+ for ele in list_format:
365
+ if 'image' in ele:
366
+ text += self.image_start_tag + ele['image'] + self.image_end_tag
367
+ elif 'text' in ele:
368
+ text += ele['text']
369
+ elif 'box' in ele:
370
+ if 'ref' in ele:
371
+ text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
372
+ for box in ele['box']:
373
+ text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
374
+ else:
375
+ raise ValueError("Unsupport element: " + str(ele))
376
+ return text
377
+
378
+ def _fetch_latest_picture(self, response, history):
379
+ if history is None:
380
+ history = []
381
+ _history = history + [(response, None)]
382
+ for q, r in _history[::-1]:
383
+ for ele in self.to_list_format(q)[::-1]:
384
+ if 'image' in ele:
385
+ return ele['image']
386
+ return None
387
+
388
+ def _fetch_all_box_with_ref(self, text):
389
+ list_format = self.to_list_format(text)
390
+ output = []
391
+ for i, ele in enumerate(list_format):
392
+ if 'box' in ele:
393
+ bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
394
+ assert len(bbox) == 4
395
+ output.append({'box': bbox})
396
+
397
+ ref_idx = i - 1
398
+ while ref_idx >= 0 and 'box' in list_format[ref_idx]:
399
+ ref_idx -= 1
400
+ if ref_idx >= 0 and 'ref' in list_format[ref_idx]:
401
+ output[-1]['ref'] = list_format[ref_idx]['ref'].strip()
402
+ return output
403
+
404
+ def draw_bbox_on_latest_picture(
405
+ self,
406
+ response,
407
+ history=None,
408
+ ) -> Optional[Image.Image]:
409
+ image = self._fetch_latest_picture(response, history)
410
+ if image is None:
411
+ return None
412
+ if image.startswith("http://") or image.startswith("https://"):
413
+ image = Image.open(requests.get(image, stream=True).raw)
414
+ else:
415
+ image = Image.open(image)
416
+ h, w = image.height, image.width
417
+ image = image.convert("RGB")
418
+
419
+ boxes = self._fetch_all_box_with_ref(response)
420
+ if not boxes:
421
+ return None
422
+ fnt = ImageFont.truetype("SimSun.ttf", 50)
423
+ draw = ImageDraw.Draw(image)
424
+ for box in boxes:
425
+ x1, y1, x2, y2 = box['box']
426
+ x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
427
+ draw.rectangle((x1, y1, x2, y2), outline='red', width=4)
428
+ if 'ref' in box:
429
+ draw.text((x1, y1), box['ref'], fill='yellow', font=fnt)
430
+ return image
431
+
432
+
visual.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from collections import OrderedDict
2
  import math
3
  import requests
@@ -5,11 +10,11 @@ from io import BytesIO
5
  from functools import partial
6
  from PIL import Image
7
  from typing import Callable, Optional, Sequence, Tuple, List
 
8
 
9
  import torch
10
  from torch import nn
11
  from torch.nn import functional as F
12
- from torch.utils.checkpoint import checkpoint
13
  from torch.nn.init import trunc_normal_
14
  from torchvision import transforms
15
  from torchvision.transforms import InterpolationMode
@@ -33,8 +38,64 @@ def get_abs_pos(abs_pos, tgt_size):
33
  else:
34
  return abs_pos
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  class Resampler(nn.Module):
 
 
 
 
 
 
38
  def __init__(
39
  self,
40
  grid_size,
@@ -48,7 +109,9 @@ class Resampler(nn.Module):
48
  self.embed_dim = embed_dim
49
  self.num_heads = num_heads
50
 
51
- self.pos_embed = nn.Parameter(torch.randn(embed_dim, grid_size)).requires_grad_(False)
 
 
52
 
53
  self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
54
  trunc_normal_(self.query, std=.02)
@@ -234,7 +297,7 @@ class VisualAttentionBlock(nn.Module):
234
  return x
235
 
236
 
237
- class Transformer(nn.Module):
238
  def __init__(
239
  self,
240
  width: int,
@@ -247,7 +310,6 @@ class Transformer(nn.Module):
247
  super().__init__()
248
  self.width = width
249
  self.layers = layers
250
- self.grad_checkpointing = False
251
 
252
  self.resblocks = nn.ModuleList([
253
  VisualAttentionBlock(
@@ -263,11 +325,7 @@ class Transformer(nn.Module):
263
 
264
  def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
265
  for r in self.resblocks:
266
- if self.grad_checkpointing and not torch.jit.is_scripting():
267
- # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
268
- x = checkpoint(r, x, None, None, attn_mask)
269
- else:
270
- x = r(x, attn_mask=attn_mask)
271
  return x
272
 
273
 
@@ -306,13 +364,13 @@ class VisionTransformer(nn.Module):
306
 
307
  # class embeddings and positional embeddings
308
  scale = width ** -0.5
309
- self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
310
 
311
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
312
  act_layer = nn.GELU
313
 
314
  self.ln_pre = norm_layer(width)
315
- self.transformer = Transformer(
316
  width,
317
  layers,
318
  heads,
@@ -331,10 +389,6 @@ class VisionTransformer(nn.Module):
331
  self.ln_post = norm_layer(output_dim)
332
  self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
333
 
334
- @torch.jit.ignore
335
- def set_grad_checkpointing(self, enable=True):
336
- self.transformer.grad_checkpointing = enable
337
-
338
  def forward(self, x: torch.Tensor):
339
  x = x.to(
340
  dtype=self.transformer.get_cast_dtype(),
@@ -353,8 +407,7 @@ class VisionTransformer(nn.Module):
353
  x = self.transformer(x)
354
  x = x.permute(1, 0, 2) # LND -> NLD
355
 
356
- if self.attn_pool:
357
- x = self.attn_pool(x)
358
  x = self.ln_post(x)
359
  x = x @ self.proj
360
 
@@ -365,8 +418,6 @@ class VisionTransformer(nn.Module):
365
  for image_path in image_paths:
366
  if image_path.startswith("http://") or image_path.startswith("https://"):
367
  image = Image.open(requests.get(image_path, stream=True).raw)
368
- elif image_path.startswith("oss://"):
369
- raise NotImplementedError
370
  else:
371
  image = Image.open(image_path)
372
  image = image.convert("RGB")
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
  from collections import OrderedDict
7
  import math
8
  import requests
 
10
  from functools import partial
11
  from PIL import Image
12
  from typing import Callable, Optional, Sequence, Tuple, List
13
+ import numpy as np
14
 
15
  import torch
16
  from torch import nn
17
  from torch.nn import functional as F
 
18
  from torch.nn.init import trunc_normal_
19
  from torchvision import transforms
20
  from torchvision.transforms import InterpolationMode
 
38
  else:
39
  return abs_pos
40
 
41
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
42
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
43
+ """
44
+ grid_size: int of the grid height and width
45
+ return:
46
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
47
+ """
48
+ grid_h = np.arange(grid_size, dtype=np.float32)
49
+ grid_w = np.arange(grid_size, dtype=np.float32)
50
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
51
+ grid = np.stack(grid, axis=0)
52
+
53
+ grid = grid.reshape([2, 1, grid_size, grid_size])
54
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
55
+ if cls_token:
56
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
57
+ return pos_embed
58
+
59
+
60
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
61
+ assert embed_dim % 2 == 0
62
+
63
+ # use half of dimensions to encode grid_h
64
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
65
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
66
+
67
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
68
+ return emb
69
+
70
+
71
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
72
+ """
73
+ embed_dim: output dimension for each position
74
+ pos: a list of positions to be encoded: size (M,)
75
+ out: (M, D)
76
+ """
77
+ assert embed_dim % 2 == 0
78
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
79
+ omega /= embed_dim / 2.
80
+ omega = 1. / 10000**omega # (D/2,)
81
+
82
+ pos = pos.reshape(-1) # (M,)
83
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
84
+
85
+ emb_sin = np.sin(out) # (M, D/2)
86
+ emb_cos = np.cos(out) # (M, D/2)
87
+
88
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
89
+ return emb
90
+
91
 
92
  class Resampler(nn.Module):
93
+ """
94
+ A 2D perceiver-resampler network with one cross attention layers by
95
+ (grid_size**2) learnable queries and 2d sincos pos_emb
96
+ Outputs:
97
+ A tensor with the shape of (grid_size**2, embed_dim)
98
+ """
99
  def __init__(
100
  self,
101
  grid_size,
 
109
  self.embed_dim = embed_dim
110
  self.num_heads = num_heads
111
 
112
+ self.pos_embed = nn.Parameter(
113
+ torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
114
+ ).requires_grad_(False)
115
 
116
  self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
117
  trunc_normal_(self.query, std=.02)
 
297
  return x
298
 
299
 
300
+ class TransformerBlock(nn.Module):
301
  def __init__(
302
  self,
303
  width: int,
 
310
  super().__init__()
311
  self.width = width
312
  self.layers = layers
 
313
 
314
  self.resblocks = nn.ModuleList([
315
  VisualAttentionBlock(
 
325
 
326
  def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
327
  for r in self.resblocks:
328
+ x = r(x, attn_mask=attn_mask)
 
 
 
 
329
  return x
330
 
331
 
 
364
 
365
  # class embeddings and positional embeddings
366
  scale = width ** -0.5
367
+ self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
368
 
369
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
370
  act_layer = nn.GELU
371
 
372
  self.ln_pre = norm_layer(width)
373
+ self.transformer = TransformerBlock(
374
  width,
375
  layers,
376
  heads,
 
389
  self.ln_post = norm_layer(output_dim)
390
  self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
391
 
 
 
 
 
392
  def forward(self, x: torch.Tensor):
393
  x = x.to(
394
  dtype=self.transformer.get_cast_dtype(),
 
407
  x = self.transformer(x)
408
  x = x.permute(1, 0, 2) # LND -> NLD
409
 
410
+ x = self.attn_pool(x)
 
411
  x = self.ln_post(x)
412
  x = x @ self.proj
413
 
 
418
  for image_path in image_paths:
419
  if image_path.startswith("http://") or image_path.startswith("https://"):
420
  image = Image.open(requests.get(image_path, stream=True).raw)
 
 
421
  else:
422
  image = Image.open(image_path)
423
  image = image.convert("RGB")