Upload tokenizer
tokenization_kpr.py  +49 -32  CHANGED
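The diff below replaces the parallel `entity_start_positions` / `entity_lengths` lists with a single `entity_position_ids` field that stores the token positions of each entity explicitly. A minimal sketch of the two representations for one hypothetical entity span (values are illustrative, not taken from the diff):

```python
# Hypothetical entity mention covering tokens 3..5 (end-exclusive bound 6).
start, end = 3, 6

# Old representation: parallel lists of start positions and span lengths.
entity_start_positions = [start]      # [3]
entity_lengths = [end - start]        # [3]

# New representation: one list of token position ids per entity,
# as built by preprocess_text() in the diff.
entity_position_ids = [list(range(start, end))]   # [[3, 4, 5]]
```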
@@ -255,8 +255,7 @@ def preprocess_text(
 ) -> dict[str, list[int]]:
     tokens = []
     entity_ids = []
-    entity_start_positions = []
-    entity_lengths = []
+    entity_position_ids = []
     if title is not None:
         if title_mentions is None:
             title_mentions = []
@@ -265,8 +264,7 @@ def preprocess_text(
         tokens += title_tokens + [tokenizer.sep_token]
         for entity in title_entities:
             entity_ids.append(entity.entity_id)
-            entity_start_positions.append(entity.start)
-            entity_lengths.append(entity.end - entity.start)
+            entity_position_ids.append(list(range(entity.start, entity.end)))

     if mentions is None:
         mentions = []
@@ -276,16 +274,14 @@ def preprocess_text(
     tokens += text_tokens
     for entity in text_entities:
         entity_ids.append(entity.entity_id)
-        entity_start_positions.append(entity.start + entity_offset)
-        entity_lengths.append(entity.end - entity.start)
+        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

     input_ids = tokenizer.convert_tokens_to_ids(tokens)

     return {
         "input_ids": input_ids,
         "entity_ids": entity_ids,
-        "entity_start_positions": entity_start_positions,
-        "entity_lengths": entity_lengths,
+        "entity_position_ids": entity_position_ids,
     }


@@ -349,8 +345,7 @@ class KPRBertTokenizer(BertTokenizer):
         "token_type_ids",
         "attention_mask",
         "entity_ids",
-        "entity_start_positions",
-        "entity_lengths",
+        "entity_position_ids",
     ]

     def __init__(
@@ -379,7 +374,7 @@ class KPRBertTokenizer(BertTokenizer):
                 "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
             )

-    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int]]:
+    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
         mentions = self.entity_linker.detect_mentions(text)
         model_inputs = preprocess_text(
             text=text,
@@ -395,18 +390,26 @@ class KPRBertTokenizer(BertTokenizer):
         # We exclude "return_tensors" from kwargs
         # to avoid issues in passing the data to BatchEncoding outside this method
         prepared_inputs = self.prepare_for_model(
-            model_inputs["input_ids"], **{k: v for k, v in kwargs.items() if k != "return_tensors"}
+            model_inputs["input_ids"],
+            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
         )
         model_inputs.update(prepared_inputs)

         # Account for special tokens
-        if kwargs.get("add_special_tokens"):
+        if kwargs.get("add_special_tokens", True):
             if prepared_inputs["input_ids"][0] != self.cls_token_id:
                 raise ValueError(
                     "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                 )
-            # Shift the entity start positions by 1 to account for the [CLS] token
-            model_inputs["entity_start_positions"] = [pos + 1 for pos in model_inputs["entity_start_positions"]]
+            # Shift the entity position IDs by 1 to account for the [CLS] token
+            model_inputs["entity_position_ids"] = [
+                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
+            ]
+
+        # If there is no entities in the text, we output padding entity for the model
+        if not model_inputs["entity_ids"]:
+            model_inputs["entity_ids"] = [0]  # The padding entity id is 0
+            model_inputs["entity_position_ids"] = [[0]]

         # Count the number of special tokens at the end of the input
         num_special_tokens_at_end = 0
@@ -414,26 +417,25 @@ class KPRBertTokenizer(BertTokenizer):
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.tolist()
         for input_id in input_ids[::-1]:
-            if int(input_id) not in {self.sep_token_id, self.pad_token_id, self.cls_token_id}:
+            if int(input_id) not in {
+                self.sep_token_id,
+                self.pad_token_id,
+                self.cls_token_id,
+            }:
                 break
             num_special_tokens_at_end += 1

         # Remove entities that are not in truncated input
         max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
         entity_indices_to_keep = list()
-        for i, (start_pos, length) in enumerate(
-            zip(model_inputs["entity_start_positions"], model_inputs["entity_lengths"])
-        ):
-            if (start_pos + length) <= max_effective_pos:
+        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
+            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                 entity_indices_to_keep.append(i)
         model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
-        model_inputs["entity_start_positions"] = [
-            model_inputs["entity_start_positions"][i] for i in entity_indices_to_keep
-        ]
-        model_inputs["entity_lengths"] = [model_inputs["entity_lengths"][i] for i in entity_indices_to_keep]
+        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

         if self.entity_embeddings is not None:
-            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]]
+            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
         return model_inputs

     def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
@@ -447,7 +449,9 @@ class KPRBertTokenizer(BertTokenizer):
         if isinstance(text, str):
             processed_inputs = self._preprocess_text(text, **kwargs)
             return BatchEncoding(
-                processed_inputs, tensor_type=kwargs.get("return_tensors", None), prepend_batch_axis=True
+                processed_inputs,
+                tensor_type=kwargs.get("return_tensors", None),
+                prepend_batch_axis=True,
             )

         processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
@@ -463,20 +467,33 @@ class KPRBertTokenizer(BertTokenizer):
             return_attention_mask=kwargs.get("return_attention_mask"),
             verbose=kwargs.get("verbose", True),
         )
-        # Pad entity ids, entity start positions, and entity lengths
+        # Pad entity ids
         max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
         for entity_ids in collated_inputs["entity_ids"]:
             entity_ids += [0] * (max_num_entities - len(entity_ids))
-        for entity_start_positions in collated_inputs["entity_start_positions"]:
-            entity_start_positions += [0] * (max_num_entities - len(entity_start_positions))
-        for entity_lengths in collated_inputs["entity_lengths"]:
-            entity_lengths += [0] * (max_num_entities - len(entity_lengths))
+        # Pad entity position ids
+        flattened_entity_length = [
+            len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
+        ]
+        max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
+        for entity_position_ids_list in collated_inputs["entity_position_ids"]:
+            # pad entity_position_ids to max_entity_token_length
+            for entity_position_ids in entity_position_ids_list:
+                entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
+            # pad to max_num_entities
+            entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
+                max_num_entities - len(entity_position_ids_list)
+            )
+        # Pad entity embeddings
         if "entity_embeds" in collated_inputs:
             for i in range(len(collated_inputs["entity_embeds"])):
                 collated_inputs["entity_embeds"][i] = np.pad(
                     collated_inputs["entity_embeds"][i],
                     pad_width=(
-                        (0, max_num_entities - len(collated_inputs["entity_embeds"][i])),
+                        (
+                            0,
+                            max_num_entities - len(collated_inputs["entity_embeds"][i]),
+                        ),
                         (0, 0),
                     ),
                     mode="constant",
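For batched inputs, the new `__call__` logic pads `entity_ids` to the largest entity count in the batch and pads `entity_position_ids` to a rectangle of shape (max entities, longest entity span). A rough standalone sketch of that padding, using two hypothetical pre-processed examples (names and values are illustrative, not taken from the diff):

```python
# Two hypothetical examples after _preprocess_text():
# example 0 has two entities, example 1 has one.
batch_entity_ids = [[11, 42], [7]]
batch_entity_position_ids = [[[1, 2], [4, 5, 6]], [[2]]]

max_num_entities = max(len(ids) for ids in batch_entity_ids)                    # 2
span_lengths = [len(p) for plist in batch_entity_position_ids for p in plist]
max_entity_token_length = max(span_lengths) if span_lengths else 0              # 3

for entity_ids in batch_entity_ids:
    # Pad entity ids with the padding entity id 0.
    entity_ids += [0] * (max_num_entities - len(entity_ids))
for plist in batch_entity_position_ids:
    for positions in plist:
        # Pad each span to the longest span in the batch.
        positions += [0] * (max_entity_token_length - len(positions))
    # Pad each example up to the largest entity count in the batch.
    plist += [[0] * max_entity_token_length] * (max_num_entities - len(plist))

# batch_entity_ids          -> [[11, 42], [7, 0]]
# batch_entity_position_ids -> [[[1, 2, 0], [4, 5, 6]], [[2, 0, 0], [0, 0, 0]]]
```

In the single-string path, the same fields are returned through `BatchEncoding(processed_inputs, tensor_type=kwargs.get("return_tensors", None), prepend_batch_axis=True)`, so they pick up a leading batch dimension of 1 when tensors are requested.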