fdschmidt93 committed
Commit c90eb91 · Parent(s): b0221f6

feat: support AutoModelForSequenceClassification

Files changed (1):
  modeling_nllbllm2vec.py  +97 -2
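With the new head in place, the checkpoint becomes loadable through the Auto API. A minimal usage sketch, assuming the repo's config.json maps AutoModelForSequenceClassification to the new class via auto_map (the repo id and num_labels below are illustrative, not taken from this commit):

from transformers import AutoModelForSequenceClassification

# Hypothetical repo id; trust_remote_code=True lets transformers import the
# custom NLLBLLM2VecForSequenceClassification from modeling_nllbllm2vec.py.
model = AutoModelForSequenceClassification.from_pretrained(
    "fdschmidt93/NLLB-LLM2Vec",
    num_labels=3,
    trust_remote_code=True,
)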
modeling_nllbllm2vec.py CHANGED
@@ -1,12 +1,16 @@
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers.models.auto import AutoModel
-from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPooling,
+    SequenceClassifierOutputWithPast,
+)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
+from transformers.cache_utils import Cache

 from .configuration_nllbllm2vec import NLLBLLM2VecConfig
 from .modeling_llama_encoder import LlamaEncoderModel
@@ -479,3 +483,94 @@ def repl():

     with open("./model.safetensors.index.json", "r") as f:
         print(json.load(f))
+
+
+class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = NLLBLLM2Vec(config)
+        self.score = nn.Linear(
+            config.llm2vec_config.hidden_size, self.num_labels, bias=False
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.nllb.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.nllb.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs.pooler_output
+        pooled_logits = self.score(hidden_states)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                if self.num_labels == 1:
+                    loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = F.mse_loss(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss = F.cross_entropy(
+                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
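The loss selection in forward() mirrors the stock transformers classification heads: with problem_type unset, num_labels == 1 selects regression (MSE), integer labels with num_labels > 1 select single-label classification (cross-entropy), and anything else falls through to multi-label classification (BCE with logits). A self-contained toy check of the single-label branch, using plain tensors rather than the model (shapes are illustrative):

import torch
import torch.nn.functional as F

num_labels = 3
pooled_logits = torch.randn(4, num_labels)  # (batch_size, num_labels) scores
labels = torch.tensor([0, 2, 1, 1])         # torch.long -> single-label path

# Same computation forward() performs for long/int labels with num_labels > 1.
loss = F.cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1))
print(loss.item())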