Sentence Similarity
Transformers
Safetensors
multilingual
nllb-llm2vec
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
Sentence Similarity
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
Commit
·
c90eb91
1
Parent(s):
b0221f6
feat: support AutoModelForSequenceClassification
Browse files- modeling_nllbllm2vec.py +97 -2
modeling_nllbllm2vec.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
-
from typing import Any, Dict, List, Optional, Tuple, cast
|
2 |
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from transformers.models.auto import AutoModel
|
7 |
-
from transformers.modeling_outputs import
|
|
|
|
|
|
|
8 |
from transformers.modeling_utils import PreTrainedModel
|
9 |
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
|
|
|
10 |
|
11 |
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
|
12 |
from .modeling_llama_encoder import LlamaEncoderModel
|
@@ -479,3 +483,94 @@ def repl():
|
|
479 |
|
480 |
with open("./model.safetensors.index.json", "r") as f:
|
481 |
print(json.load(f))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Tuple, cast, Union
|
2 |
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from transformers.models.auto import AutoModel
|
7 |
+
from transformers.modeling_outputs import (
|
8 |
+
BaseModelOutputWithPooling,
|
9 |
+
SequenceClassifierOutputWithPast,
|
10 |
+
)
|
11 |
from transformers.modeling_utils import PreTrainedModel
|
12 |
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
|
13 |
+
from transformers.cache_utils import Cache
|
14 |
|
15 |
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
|
16 |
from .modeling_llama_encoder import LlamaEncoderModel
|
|
|
483 |
|
484 |
with open("./model.safetensors.index.json", "r") as f:
|
485 |
print(json.load(f))
|
486 |
+
|
487 |
+
|
488 |
+
class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
|
489 |
+
def __init__(self, config):
|
490 |
+
super().__init__(config)
|
491 |
+
self.num_labels = config.num_labels
|
492 |
+
self.model = NLLBLLM2Vec(config)
|
493 |
+
self.score = nn.Linear(
|
494 |
+
config.llm2vec_config.hidden_size, self.num_labels, bias=False
|
495 |
+
)
|
496 |
+
|
497 |
+
# Initialize weights and apply final processing
|
498 |
+
self.post_init()
|
499 |
+
|
500 |
+
def get_input_embeddings(self):
|
501 |
+
return self.model.nllb.embed_tokens
|
502 |
+
|
503 |
+
def set_input_embeddings(self, value):
|
504 |
+
self.model.nllb.embed_tokens = value
|
505 |
+
|
506 |
+
def forward(
|
507 |
+
self,
|
508 |
+
input_ids: Optional[torch.LongTensor] = None,
|
509 |
+
attention_mask: Optional[torch.Tensor] = None,
|
510 |
+
position_ids: Optional[torch.LongTensor] = None,
|
511 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
512 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
513 |
+
labels: Optional[torch.LongTensor] = None,
|
514 |
+
use_cache: Optional[bool] = None,
|
515 |
+
output_attentions: Optional[bool] = None,
|
516 |
+
output_hidden_states: Optional[bool] = None,
|
517 |
+
return_dict: Optional[bool] = None,
|
518 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
519 |
+
r"""
|
520 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
521 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
522 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
523 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
524 |
+
"""
|
525 |
+
return_dict = (
|
526 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
527 |
+
)
|
528 |
+
|
529 |
+
transformer_outputs = self.model(
|
530 |
+
input_ids,
|
531 |
+
attention_mask=attention_mask,
|
532 |
+
position_ids=position_ids,
|
533 |
+
past_key_values=past_key_values,
|
534 |
+
inputs_embeds=inputs_embeds,
|
535 |
+
use_cache=use_cache,
|
536 |
+
output_attentions=output_attentions,
|
537 |
+
output_hidden_states=output_hidden_states,
|
538 |
+
return_dict=return_dict,
|
539 |
+
)
|
540 |
+
hidden_states = transformer_outputs.pooler_output
|
541 |
+
pooled_logits = self.score(hidden_states)
|
542 |
+
|
543 |
+
loss = None
|
544 |
+
if labels is not None:
|
545 |
+
if self.config.problem_type is None:
|
546 |
+
if self.num_labels == 1:
|
547 |
+
self.config.problem_type = "regression"
|
548 |
+
elif self.num_labels > 1 and (
|
549 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
550 |
+
):
|
551 |
+
self.config.problem_type = "single_label_classification"
|
552 |
+
else:
|
553 |
+
self.config.problem_type = "multi_label_classification"
|
554 |
+
|
555 |
+
if self.config.problem_type == "regression":
|
556 |
+
if self.num_labels == 1:
|
557 |
+
loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
|
558 |
+
else:
|
559 |
+
loss = F.mse_loss(pooled_logits, labels)
|
560 |
+
elif self.config.problem_type == "single_label_classification":
|
561 |
+
loss = F.cross_entropy(
|
562 |
+
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
563 |
+
)
|
564 |
+
elif self.config.problem_type == "multi_label_classification":
|
565 |
+
loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
|
566 |
+
if not return_dict:
|
567 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
568 |
+
return ((loss,) + output) if loss is not None else output
|
569 |
+
|
570 |
+
return SequenceClassifierOutputWithPast(
|
571 |
+
loss=loss,
|
572 |
+
logits=pooled_logits,
|
573 |
+
past_key_values=transformer_outputs.past_key_values,
|
574 |
+
hidden_states=transformer_outputs.hidden_states,
|
575 |
+
attentions=transformer_outputs.attentions,
|
576 |
+
)
|