Commit · c35a42b
1 Parent(s): 7af97e7

fix: when sentences is one

Signed-off-by: Meow <[email protected]>

- modeling_lora.py  +18 -9
modeling_lora.py CHANGED

@@ -11,8 +11,11 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig

-from .modeling_xlm_roberta import (
-    XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel)
+from .modeling_xlm_roberta import (
+    XLMRobertaFlashConfig,
+    XLMRobertaModel,
+    XLMRobertaPreTrainedModel,
+)


 def initialized_weights(
@@ -241,6 +244,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     """
     A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
     """
+
     def __init__(
         self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
     ):
@@ -262,7 +266,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         if (
             not isinstance(self._task_instructions, dict)
             or len(self._task_instructions) != len(self._lora_adaptations)
-            or not all([v in self._lora_adaptations for v in self._task_instructions.keys()])
+            or not all(
+                [v in self._lora_adaptations for v in self._task_instructions.keys()]
+            )
         ):
             raise ValueError(
                 f"`task_instructions` must be a dict and contain the same number of elements "
@@ -325,11 +331,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         config = XLMRobertaFlashConfig.from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
-        if config.load_trained_adapters:
+        if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
-        else:
+        else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
@@ -387,14 +393,17 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Alternatively, don't pass the `task_type` argument to disable LoRA."
             )
         adapter_mask = None
+        sentences = list(sentences) if isinstance(sentences, str) else sentences
         if task_type:
             task_id = self._adaptation_map[task_type]
-            num_examples = 1 if isinstance(sentences, str) else len(sentences)
             adapter_mask = torch.full(
-                (num_examples,), task_id, dtype=torch.int32, device=self.device
+                (len(sentences),), task_id, dtype=torch.int32, device=self.device
             )
-            if task_type in ["query", "passage"]:
-                sentences = [self._task_instructions[task_type] + " " + sentence for sentence in sentences]
+            if task_type in ["query", "passage"]:
+                sentences = [
+                    self._task_instructions[task_type] + " " + sentence
+                    for sentence in sentences
+                ]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
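
For context on the encode() change above: when `sentences` arrives as a single string, both the adapter-mask sizing and the instruction prefixing need a list, otherwise a later iteration over `sentences` walks the string character by character. The sketch below reproduces that normalization step in isolation, assuming nothing beyond plain Python and torch; the instruction text, the task name, and the helper `build_inputs` are illustrative stand-ins, not code from this repository.

import torch

# Illustrative stand-ins for self._task_instructions / self._adaptation_map
task_instructions = {"query": "Represent the query for retrieval:"}
adaptation_map = {"query": 0}

def build_inputs(sentences, task_type):
    # Wrap a bare string into a one-element list before any iteration;
    # note that list("abc") would instead split the string into characters.
    if isinstance(sentences, str):
        sentences = [sentences]
    # One adapter id per input sentence, matching the batch size.
    adapter_mask = torch.full(
        (len(sentences),), adaptation_map[task_type], dtype=torch.int32
    )
    if task_type in ["query", "passage"]:
        sentences = [task_instructions[task_type] + " " + s for s in sentences]
    return sentences, adapter_mask

texts, mask = build_inputs("how do LoRA adapters work?", "query")
print(texts)  # ['Represent the query for retrieval: how do LoRA adapters work?']
print(mask)   # tensor([0], dtype=torch.int32)

With the wrapping in place, a single string and a one-element list yield the same mask shape and the same instruction prefixing, which is the situation the commit message ("fix: when sentences is one") points at.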