# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from peft.tuners.prompt_tuning import PromptEmbedding
from peft.utils import TaskType

from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit

# This code is adapted from the paper https://arxiv.org/abs/2303.02861 and
# constitutes the work done at MIT-IBM Watson Research Lab.

class MultitaskPromptEmbedding(PromptEmbedding):
    def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings):
        super().__init__(config, word_embeddings)

        self.num_tasks = config.num_tasks
        self.num_ranks = config.num_ranks
        self.num_virtual_tokens = config.num_virtual_tokens

        self.num_transformer_submodules = config.num_transformer_submodules
        if self.num_transformer_submodules is None:
            # Seq2seq models get one set of virtual tokens for the encoder and one for the decoder.
            self.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1

        self.token_dim = config.token_dim

        total_virtual_tokens = self.num_virtual_tokens * self.num_transformer_submodules
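
        # The prompt for task t is the shared prompt embedding modulated element-wise by a
        # task-specific matrix prefix_task_cols[t] @ prefix_task_rows[t], which has shape
        # (total_virtual_tokens, token_dim) and rank at most num_ranks.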
        self.prefix_task_cols = torch.nn.Parameter(
            torch.normal(
                mean=0,
                std=0.02,
                size=(self.num_tasks, total_virtual_tokens, self.num_ranks),
            )
        )
        self.prefix_task_rows = torch.nn.Parameter(
            torch.normal(
                mean=0,
                std=0.02,
                size=(self.num_tasks, self.num_ranks, self.token_dim),
            )
        )
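
        # Optionally initialize from a source-task checkpoint: AVERAGE_SOURCE_TASKS averages the
        # task matrices over all source tasks, EXACT_SOURCE_TASK copies the matrices of a single
        # source task, and ONLY_SOURCE_SHARED loads only the shared prompt embedding.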
        if config.prompt_tuning_init in [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        ]:
            if config.prompt_tuning_init_state_dict_path is None:
                raise ValueError(
                    f"prompt_tuning_init_state_dict_path needs to be specified with {config.prompt_tuning_init} "
                    "init method"
                )

            if config.prompt_tuning_init_state_dict_path.endswith(".safetensors"):
                from safetensors.torch import load_file

                state_dict: dict = load_file(config.prompt_tuning_init_state_dict_path)
            else:
                state_dict: dict = torch.load(
                    config.prompt_tuning_init_state_dict_path,
                    map_location=word_embeddings.weight.device,
                )

        if config.prompt_tuning_init in [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
        ]:
            prefix_task_cols_: torch.Tensor = state_dict["prefix_task_cols"]
            prefix_task_rows_: torch.Tensor = state_dict["prefix_task_rows"]

            if config.prompt_tuning_init == MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS:
                prefix_task_cols_ = prefix_task_cols_.mean(0, keepdim=True)
                prefix_task_rows_ = prefix_task_rows_.mean(0, keepdim=True)
            elif config.prompt_tuning_init == MultitaskPromptTuningInit.EXACT_SOURCE_TASK:
                prefix_task_cols_ = prefix_task_cols_[config.prompt_tuning_init_task, ...].unsqueeze(0)
                prefix_task_rows_ = prefix_task_rows_[config.prompt_tuning_init_task, ...].unsqueeze(0)

            state_dict = {
                "embedding.weight": state_dict["prompt_embeddings"],
                "prefix_task_cols": prefix_task_cols_,
                "prefix_task_rows": prefix_task_rows_,
            }

            self.load_state_dict(state_dict, strict=True)
        elif config.prompt_tuning_init == MultitaskPromptTuningInit.ONLY_SOURCE_SHARED:
            state_dict = {
                "embedding.weight": state_dict["prompt_embeddings"],
            }

            self.load_state_dict(state_dict, strict=False)

    def forward(self, indices, task_ids):
        if task_ids is None:
            raise ValueError("task_ids cannot be None")

        # Shared prompt embedding, common to all tasks.
        prompt_embeddings = self.embedding(indices)

        # Build the low-rank task-specific matrix for each example in the batch and use it to
        # modulate the shared prompt element-wise.
        task_cols = torch.index_select(self.prefix_task_cols, 0, task_ids)
        task_rows = torch.index_select(self.prefix_task_rows, 0, task_ids)
        task_prompts = torch.matmul(task_cols, task_rows)

        prompt_embeddings *= task_prompts

        return prompt_embeddings
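

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module). The config values
# below are hypothetical and chosen only to illustrate the tensor shapes; in
# normal use this embedding is created internally by PeftModel rather than
# instantiated directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = MultitaskPromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=8,
        token_dim=16,
        num_transformer_submodules=1,
        num_tasks=3,
        num_ranks=2,
    )
    # A small stand-in for the base model's word embeddings (only consulted for TEXT init).
    word_embeddings = torch.nn.Embedding(100, config.token_dim)
    prompt_encoder = MultitaskPromptEmbedding(config, word_embeddings)

    batch_size = 4
    # One row of virtual-token indices per example, plus one task id per example.
    indices = torch.arange(config.num_virtual_tokens).unsqueeze(0).expand(batch_size, -1)
    task_ids = torch.tensor([0, 1, 2, 0])

    prompts = prompt_encoder(indices, task_ids)
    print(prompts.shape)  # torch.Size([4, 8, 16]) -> (batch, num_virtual_tokens, token_dim)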