Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| # coding=utf-8 | |
| # coding=utf-8 | |
| # Copyright 2023 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import re | |
| from typing import List, Literal, Optional | |
| from datasets import DatasetDict, concatenate_datasets, load_dataset | |
| from .configs import DataArguments | |
# Jinja chat template used as a fallback. Each `user`, `system` or `assistant` message is rendered as
# `<|role|>`, a newline, the message content and `eos_token`; messages with any other role are skipped.
# After the last message, `<|assistant|>` is appended when `add_generation_prompt` is set.
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)
def apply_chat_template(
    example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
):
    """
    Render the conversation(s) stored in `example` to text using `tokenizer.apply_chat_template`.

    Args:
        example (`dict`):
            A single dataset row. `sft`/`generation` read the `messages` list; `rm`/`dpo` read the `chosen`
            and `rejected` lists. Each message is a dict with `role` and `content` keys.
        tokenizer:
            Object exposing `apply_chat_template(messages, tokenize=False, add_generation_prompt=...)`.
        task (`str`, *optional*, defaults to `"sft"`):
            One of `"sft"`, `"generation"`, `"rm"` or `"dpo"`.
        assistant_prefix (`str`, *optional*):
            Prefix removed (once, if present) from the start of the rendered `dpo` chosen/rejected texts.

    Returns:
        `dict`: The same `example` object with new keys: `text` (`sft`/`generation`), `text_chosen` and
        `text_rejected` (`rm`), or `text_chosen`, `text_rejected` and `text_prompt` (`dpo`).
        For `sft`/`generation`/`rm` the message lists are modified in place: an empty system message is
        inserted at the front when the conversation does not start with one.

    Raises:
        ValueError: If the keys required by `task` are missing, a `dpo` example has no user message, or
            `task` is not supported.
    """

    def _strip_prefix(s, pattern):
        # Remove a single leading occurrence of `pattern` (same result as an anchored, escaped regex)
        return s.removeprefix(pattern)

    def _ensure_system(messages):
        # Insert, in place, an empty system message when the conversation does not start with one
        if messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": ""})

    def _drop_prompt(messages):
        # Drop an optional leading system message plus the prompt turn that follows it, leaving the response
        start = 1 if messages and messages[0]["role"] == "system" else 0
        return messages[start + 1 :]

    if task in ["sft", "generation"]:
        messages = example["messages"]
        _ensure_system(messages)
        example["text"] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=(task == "generation")
        )
    elif task == "rm":
        if not all(k in example for k in ("chosen", "rejected")):
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
        chosen_messages = example["chosen"]
        rejected_messages = example["rejected"]
        _ensure_system(chosen_messages)
        _ensure_system(rejected_messages)
        example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
        example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
    elif task == "dpo":
        if not all(k in example for k in ("chosen", "rejected")):
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
        chosen = example["chosen"]
        # Compared to reward modeling, the prompt is split off: it is the system message (or an empty one)
        # followed by the first user turn of the chosen conversation.
        first_user = next((msg for msg in chosen if msg["role"] == "user"), None)
        if first_user is None:
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! No `user` message found in `chosen`: {chosen}"
            )
        if chosen[0]["role"] == "system":
            system_message = chosen[0]
        else:
            system_message = {"role": "system", "content": ""}
        prompt_messages = [system_message, first_user]
        # Responses are what remains after the (optional) system message and the prompt turn; this also
        # handles chosen/rejected conversations that start with their own system message.
        chosen_messages = _drop_prompt(chosen)
        rejected_messages = _drop_prompt(example["rejected"])
        example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
        example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        example["text_prompt"] = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )
        example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
        example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo']"
        )
    return example
def get_datasets(
    data_config: DataArguments | dict,
    splits: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions. A `DataArguments` instance contributes its
            `dataset_mixer` attribute; a `dict` is used directly as the mixer (dataset name -> fraction).
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the mixed data.

    Returns
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: If `data_config` is neither a `DataArguments` nor a `dict`.
    """
    # `None` sentinel instead of a shared mutable list default
    if splits is None:
        splits = ["train", "test"]

    if isinstance(data_config, DataArguments):
        # Structure of the config to read the datasets and their mix
        # dataset_mixer:
        #   - 'dataset1': 0.5
        #   - 'dataset2': 0.3
        #   - 'dataset3': 0.2
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #         "dataset1": 0.5,
        #         "dataset2": 0.3,
        #         "dataset3": 0.2,
        #     }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    return mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
def mix_datasets(
    dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True, seed: int = 42
) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary mapping dataset names to their training proportions. Each train split of a dataset is
            truncated to its first `int(frac * len(split))` rows. Test splits are never subsampled.
        splits (`List[str]`, *optional*, defaults to `None`):
            Dataset splits to load and mix. Each name must contain `train` or `test`, and every split must
            exist in all datasets. `None` means `["train", "test"]`.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the mixed train and test data.
        seed (`int`, *optional*, defaults to `42`):
            Seed passed to `shuffle` when `shuffle` is `True`.

    Returns:
        [`DatasetDict`]: Contains a `train` entry if any train split was requested and a `test` entry if any
        test split was requested.

    Raises:
        ValueError: If a split name contains neither `train` nor `test`, a fraction is negative, or nothing
            was loaded (empty `dataset_mixer` or `splits`).
    """
    if splits is None:
        splits = ["train", "test"]

    # Validate cheaply before downloading anything.
    for split in splits:
        if "train" not in split and "test" not in split:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")

    raw_datasets = DatasetDict()
    # Each train split is stored together with its dataset's fraction, so requesting several train splits
    # per dataset keeps datasets and fractions aligned.
    raw_train_datasets = []  # list of (dataset, frac) pairs
    raw_val_datasets = []
    for ds, frac in dataset_mixer.items():
        for split in splits:
            dataset = load_dataset(ds, split=split)
            if "train" in split:
                raw_train_datasets.append((dataset, frac))
            else:
                raw_val_datasets.append(dataset)

    if raw_train_datasets:
        train_subsets = [
            dataset.select(range(int(frac * len(dataset)))) for dataset, frac in raw_train_datasets
        ]
        raw_datasets["train"] = concatenate_datasets(train_subsets)
        if shuffle:
            raw_datasets["train"] = raw_datasets["train"].shuffle(seed=seed)

    # No subsampling for test datasets to enable fair comparison across models
    if raw_val_datasets:
        raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
        if shuffle:
            raw_datasets["test"] = raw_datasets["test"].shuffle(seed=seed)

    if len(raw_datasets) == 0:
        # Report `splits` rather than a loop variable, which is unbound when nothing was iterated
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )
    return raw_datasets
