Commit 3703946
Parent(s): 851aaca

refactor: stuff

Signed-off-by: jupyterjazz <[email protected]>

- modeling_lora.py (+37 -26)
modeling_lora.py
CHANGED
@@ -14,6 +14,9 @@ from transformers import PretrainedConfig
 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
 
 
+LORA_NO_UPDATE = '__lora_no_update__'
+
+
 def initialized_weights(
     shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
@@ -214,7 +217,17 @@ class XLMRobertaLoRA(XLMRobertaModel):
     ):
         super().__init__(config)
 
-        self.
+        self._lora_adaptations = config.lora_adaptations
+        if (
+            not isinstance(self._lora_adaptations, list)
+            or len(self._lora_adaptations) < 1
+        ):
+            raise ValueError(
+                f'`lora_adaptations` must be a list and contain at least one element'
+            )
+        self._adaptation_map = {
+            name: idx for idx, name in enumerate(self._lora_adaptations)
+        }
         self._rank = config.lora_rank
         self._dropout_p = config.lora_dropout_p
         self._alpha = config.lora_alpha
@@ -294,14 +307,20 @@ class XLMRobertaLoRA(XLMRobertaModel):
         return self._task_idx
 
     @current_task.setter
-    def current_task(self,
+    def current_task(self, task_name: Union[None, str]):
         """Set the LoRA that is to be used.
         The LoRA is specified by `task_idx`, which may be an integer >= 0,
         indexing the available LoRAs. If it is None, no LoRA is used.
-        :param
+        :param task_name: Which LoRA to use
         :return:
         """
-
+        if task_name and task_name not in self._lora_adaptations:
+            raise ValueError(
+                f"Unsupported task '{task_name}'. "
+                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                f"Alternatively, set `task` to `None` if you want to disable LoRA."
+            )
+        task_idx = self._adaptation_map[task_name] if task_name else None
         if self._task_idx != task_idx:
             # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
@@ -309,9 +328,9 @@ class XLMRobertaLoRA(XLMRobertaModel):
             partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
         )
 
-    def forward(self, *args,
-        if
-            self.current_task =
+    def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
+        if task != LORA_NO_UPDATE:
+            self.current_task = task
         return super().forward(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -331,35 +350,27 @@ class XLMRobertaLoRA(XLMRobertaModel):
     def encode(
         self,
         *args,
-        task:
+        task: Union[str, None] = LORA_NO_UPDATE,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
 
-        task(`str`, *optional*, defaults to
-            Specifies the task for which the encoding is intended. This
-
-
-
-            model
+        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
+            Specifies the task for which the encoding is intended. This parameter controls the
+            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
+            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
+            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
+            adapters are disabled, and the model reverts to its original, general-purpose weights.
+            If `task` is set to a specific LoRA adaptation, that adaptation is activated.
         """
-
-
-        if task:
-            if task in self.config.lora_adaptations:
-                lora_adapter_num = self.config.lora_adaptations.index(task)
-            else:
-                raise ValueError(
-                    f"Unsupported task '{task}'. "
-                    f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                )
-        else:
+        if task != LORA_NO_UPDATE:
+            if not task:
                 warnings.warn(
                     f"Task-specific embeddings are disabled. To enable, specify the `task` "
                     f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
                     category=UserWarning,
                 )
-
+            self.current_task = task
 
         return super().encode(*args, **kwargs)
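
For reference, a minimal usage sketch of the `task` handling introduced in this commit. This is illustrative only: the checkpoint path and the task names below are placeholders, and the valid names come from `config.lora_adaptations` of the checkpoint that ships this modeling_lora.py.

from transformers import AutoModel

# Hypothetical checkpoint that provides this custom modeling code.
model = AutoModel.from_pretrained("org/checkpoint-with-lora", trust_remote_code=True)

# Selecting a task activates the matching LoRA adapter; the selection
# persists until it is changed again.
query_emb = model.encode(["how do I install the library?"], task="retrieval.query")

# Omitting `task` (i.e. leaving the LORA_NO_UPDATE default) keeps whichever
# adapter is currently active.
more_emb = model.encode(["another query"])

# Passing task=None disables all adapters (a UserWarning is emitted) and the
# model falls back to its general-purpose weights; an unknown task name now
# raises a ValueError from the current_task setter.
plain_emb = model.encode(["plain sentence"], task=None)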