Jackmin108 committed
Commit 7af97e7 • 1 Parent(s): 2646361

feat: add lora instructions for retrieval

Signed-off-by: Meow <[email protected]>
- configuration_xlm_roberta.py +2 -2
- modeling_lora.py +8 -6
configuration_xlm_roberta.py
CHANGED
@@ -27,7 +27,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache: bool = True,
         classifier_dropout: Optional[float] = None,
         lora_adaptations: Optional[List[str]] = None,
-
+        task_instructions: Optional[Dict[str, str]] = None,
         lora_rank: int = 4,
         lora_dropout_p: float = 0.0,
         lora_alpha: int = 1,
@@ -103,7 +103,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
-        self.
+        self.task_instructions = task_instructions
         self.lora_rank = lora_rank
         self.lora_dropout_p = lora_dropout_p
         self.lora_alpha = lora_alpha
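The new `task_instructions` field is meant to mirror `lora_adaptations`: one instruction string per adapter name (the validation added in modeling_lora.py below enforces this). A minimal sketch of building such a config follows; the adapter names and instruction texts are illustrative assumptions, not values from this commit, and all other constructor arguments are left at their defaults:

from configuration_xlm_roberta import XLMRobertaFlashConfig

# Hypothetical adapter names and instruction strings; they are chosen only to
# line up with the `task_type in ['query', 'passage']` branch added to
# XLMRobertaLoRA.encode in modeling_lora.py below.
config = XLMRobertaFlashConfig(
    lora_adaptations=["query", "passage"],
    task_instructions={
        "query": "Represent the query for retrieving relevant documents:",
        "passage": "Represent the document for retrieval:",
    },
    lora_rank=4,
    lora_dropout_p=0.0,
    lora_alpha=1,
)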
modeling_lora.py
CHANGED
@@ -258,15 +258,15 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             raise ValueError(
                 f"`lora_adaptations` must be a list and contain at least one element"
             )
-        self.
+        self._task_instructions = config.task_instructions
         if (
-            not isinstance(self.
-            or len(self.
-            or not all([v in self._lora_adaptations for v in self.
+            not isinstance(self._task_instructions, dict)
+            or len(self._task_instructions) != len(self._lora_adaptations)
+            or not all([v in self._lora_adaptations for v in self._task_instructions.keys()])
         ):
             raise ValueError(
-                f"`
-                f"as `lora_adaptations` with all keys in `
+                f"`task_instructions` must be a dict and contain the same number of elements "
+                f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
             )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
@@ -393,6 +393,8 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         adapter_mask = torch.full(
             (num_examples,), task_id, dtype=torch.int32, device=self.device
         )
+        if task_type in ['query', 'passage']:
+            sentences = [self._task_instructions[task_type] + ' ' + sentence for sentence in sentences]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
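Taken together, the two hunks make retrieval encoding instruction-aware: when encode is called with a retrieval task type, the matching entry from config.task_instructions is prepended to every sentence before it is forwarded to roberta.encode along with the adapter mask. A standalone sketch of that prefixing behaviour, using hypothetical instruction strings (in the real model they come from the config):

# Standalone sketch; `task_instructions` stands in for config.task_instructions.
task_instructions = {
    "query": "Represent the query for retrieving relevant documents:",
    "passage": "Represent the document for retrieval:",
}

def prepend_instruction(sentences, task_type):
    # Mirrors the branch added to XLMRobertaLoRA.encode: only the retrieval
    # task types ('query' / 'passage') get an instruction prefix; any other
    # task type passes the sentences through unchanged.
    if task_type in ["query", "passage"]:
        return [task_instructions[task_type] + " " + s for s in sentences]
    return sentences

print(prepend_instruction(["how do I fine-tune LoRA adapters?"], "query"))
# ['Represent the query for retrieving relevant documents: how do I fine-tune LoRA adapters?']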