fix new dataset prompt tokenizers
Browse files
    	
        src/axolotl/datasets.py
    CHANGED
    
    | @@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset): | |
| 106 | 
             
                                        }
         | 
| 107 | 
             
                                    else:
         | 
| 108 | 
             
                                        logging.warning(
         | 
| 109 | 
            -
                                            "dropping batch due to tensor size mismatch"
         | 
| 110 | 
             
                                        )
         | 
| 111 | 
             
                                buffer = {"input_ids": [], "attention_mask": [], "labels": []}
         | 
| 112 | 
             
                                buffer_len = 0
         | 
|  | |
| 106 | 
             
                                        }
         | 
| 107 | 
             
                                    else:
         | 
| 108 | 
             
                                        logging.warning(
         | 
| 109 | 
            +
                                            f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
         | 
| 110 | 
             
                                        )
         | 
| 111 | 
             
                                buffer = {"input_ids": [], "attention_mask": [], "labels": []}
         | 
| 112 | 
             
                                buffer_len = 0
         | 
    	
        src/axolotl/prompt_strategies/__init__.py
    CHANGED
    
    | @@ -1,11 +1,13 @@ | |
| 1 | 
             
            import importlib
         | 
| 2 | 
            -
            from functools import cache
         | 
| 3 |  | 
| 4 | 
            -
            @cache
         | 
| 5 | 
             
            def load(strategy, tokenizer, cfg):
         | 
| 6 | 
             
                try:
         | 
| 7 | 
            -
                     | 
| 8 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
| 9 | 
             
                    return fn(tokenizer, cfg)
         | 
| 10 | 
             
                except:
         | 
| 11 | 
             
                    pass
         | 
|  | |
| 1 | 
             
            import importlib
         | 
|  | |
| 2 |  | 
|  | |
| 3 | 
             
            def load(strategy, tokenizer, cfg):
         | 
| 4 | 
             
                try:
         | 
| 5 | 
            +
                    load_fn = "load"
         | 
| 6 | 
            +
                    if strategy.split(".")[-1].startswith("load_"):
         | 
| 7 | 
            +
                        load_fn = strategy.split(".")[-1]
         | 
| 8 | 
            +
                        strategy = ".".join(strategy.split(".")[:-1])
         | 
| 9 | 
            +
                    m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
         | 
| 10 | 
            +
                    fn = getattr(m, load_fn)
         | 
| 11 | 
             
                    return fn(tokenizer, cfg)
         | 
| 12 | 
             
                except:
         | 
| 13 | 
             
                    pass
         | 
    	
        src/axolotl/prompt_strategies/creative_acr.py
    ADDED
    
    | @@ -0,0 +1,137 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Union, Generator
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
            from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         | 
| 8 | 
            +
                def parse_instruction_fields(self, prompt) -> (str, str, str):
         | 
| 9 | 
            +
                    question = prompt["instruction"]
         | 
| 10 | 
            +
                    answer = prompt["revision"]  # don't use prompt[answer], that's data we don't want in the dataset
         | 
| 11 | 
            +
                    return (
         | 
| 12 | 
            +
                        question,
         | 
| 13 | 
            +
                        "",
         | 
| 14 | 
            +
                        answer,
         | 
| 15 | 
            +
                    )
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy that asks the model to critique a Q/A pair.

    The instruction is the ``user_prompt`` template filled with the question
    and answer; the target completion is the YAML-dumped scores + critiques
    taken from the source record.
    """

    # Runtime prompt template; {question} and {answer} are filled in
    # parse_instruction_fields.
    user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.

Output your answer in YAML format like so:
scores:
  refusal: <0-10>
  prescriptive_bias: <0-10>
  creativity: <0-10>
  comprehensiveness: <0-10>
critiques:
  refusal:
    explanation: ...
    improvements: ...
  prescriptive_bias:
    explanation: ...
    improvements: ...
  creativity:
    explanation: ...
    improvements: ...
  comprehensiveness:
    explanation: ...
    improvements: ...

Question: {question}
Answer: {answer}
"""

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        # NOTE(review): ``(str, str, str)`` is not a valid typing annotation
        # (should be Tuple[str, str, str]); harmless at runtime.
        # Re-serialize the record's scores/critiques dicts to YAML so the
        # completion matches the format the user_prompt asks for.
        # assumes prompt["scores"] / prompt["critiques"] are YAML-safe dicts
        # — TODO confirm against the dataset schema.
        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
        evaluation = scores + critiques
        question = prompt["instruction"]
        answer = prompt["answer"]
        return (
            self.user_prompt.format(question=question, answer=answer),
            "",
            evaluation,
        )
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         | 
| 64 | 
            +
                user_prompt = """Definitions:
         | 
| 65 | 
            +
            refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
         | 
| 66 | 
            +
            prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
         | 
| 67 | 
            +
            creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
         | 
| 68 | 
            +
            comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            Question: {question}
         | 
| 73 | 
            +
            Answer: {answer}
         | 
| 74 | 
            +
            Evaluation:
         | 
| 75 | 
            +
            {evaluation}
         | 
| 76 | 
            +
            """
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def parse_instruction_fields(self, prompt) -> (str, str, str):
         | 
| 79 | 
            +
                    scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
         | 
| 80 | 
            +
                    critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
         | 
| 81 | 
            +
                    evaluation = scores + critiques
         | 
| 82 | 
            +
                    question = prompt["instruction"]
         | 
| 83 | 
            +
                    answer = prompt["answer"]
         | 
| 84 | 
            +
                    return (
         | 
| 85 | 
            +
                        self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
         | 
| 86 | 
            +
                        "",
         | 
| 87 | 
            +
                        prompt["revision"],
         | 
| 88 | 
            +
                    )
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            class CreativePrompterBase:
         | 
| 92 | 
            +
                system_prompt = ""
         | 
| 93 | 
            +
                prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def build_prompt(
         | 
| 96 | 
            +
                    self,
         | 
| 97 | 
            +
                    instruction: str,
         | 
| 98 | 
            +
                    input: Union[None, str] = None,
         | 
| 99 | 
            +
                    output: Union[None, str] = None,
         | 
| 100 | 
            +
                ) -> Generator[str, None, None]:
         | 
| 101 | 
            +
                    if self.system_prompt:
         | 
| 102 | 
            +
                        res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
         | 
| 103 | 
            +
                    else:
         | 
| 104 | 
            +
                        res = f"USER: {instruction}\nASSISTANT:"
         | 
| 105 | 
            +
                    if output:
         | 
| 106 | 
            +
                        res = f"{res}{output}"
         | 
| 107 | 
            +
                    yield res
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
class CreativeAnswerPrompter(CreativePrompterBase):
    """Prompter for the answering stage: adds a system prompt asking for a
    comprehensive, in-depth, creative answer."""

    system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
class CreativeCritiquePrompter(CreativePrompterBase):
    """Prompter for the critique stage; no system prompt — the critique
    instructions live in the tokenizing strategy's user_prompt."""

    system_prompt = ""
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
class CreativeRevisePrompter(CreativePrompterBase):
    """Prompter for the revision stage; no system prompt — the revision
    instructions live in the tokenizing strategy's user_prompt."""

    system_prompt = ""
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 122 | 
            +
            def load_answer(tokenizer, cfg):
         | 
| 123 | 
            +
                return CreativeAnsweringPromptTokenizingStrategy(
         | 
| 124 | 
            +
                    CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
         | 
| 125 | 
            +
                )
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            def load_critique(tokenizer, cfg):
         | 
| 129 | 
            +
                return CreativeCritiquePromptTokenizingStrategy(
         | 
| 130 | 
            +
                    CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
         | 
| 131 | 
            +
                )
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def load_revise(tokenizer, cfg):
         | 
| 135 | 
            +
                return CreativeRevisePromptTokenizingStrategy(
         | 
| 136 | 
            +
                    CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
         | 
| 137 | 
            +
                )
         | 
    	
        src/axolotl/prompt_strategies/pygmalion.py
    CHANGED
    
    | @@ -41,9 +41,9 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): | |
| 41 | 
             
                        elif role == "bot":
         | 
| 42 | 
             
                            prefix = "<|model|>"
         | 
| 43 | 
             
                            res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
         | 
| 44 | 
            -
                            res["input_ids"] = [*self.bot_prefix_token_ids, *res["input_ids"]]
         | 
| 45 | 
             
                            # mask out the prefix token, rest is not masked out from labels
         | 
| 46 | 
            -
                             | 
|  | |
| 47 | 
             
                        else:
         | 
| 48 | 
             
                            logging.warning(f"unknown role in conversation: {role}")
         | 
| 49 | 
             
                            res = defaultdict(lambda: [])
         | 
|  | |
| 41 | 
             
                        elif role == "bot":
         | 
| 42 | 
             
                            prefix = "<|model|>"
         | 
| 43 | 
             
                            res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
         | 
|  | |
| 44 | 
             
                            # mask out the prefix token, rest is not masked out from labels
         | 
| 45 | 
            +
                            # make sure we create the labels first, otherwise we get incorrect lengths
         | 
| 46 | 
            +
                            labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
         | 
| 47 | 
             
                        else:
         | 
| 48 | 
             
                            logging.warning(f"unknown role in conversation: {role}")
         | 
| 49 | 
             
                            res = defaultdict(lambda: [])
         | 
    	
        src/axolotl/utils/data.py
    CHANGED
    
    | @@ -75,7 +75,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa | |
| 75 | 
             
                        ds = None
         | 
| 76 | 
             
                        ds_from_hub = False
         | 
| 77 | 
             
                        try:
         | 
| 78 | 
            -
                            load_dataset(d.path, streaming=True)
         | 
| 79 | 
             
                            ds_from_hub = True
         | 
| 80 | 
             
                        except FileNotFoundError:
         | 
| 81 | 
             
                            pass
         | 
| @@ -83,18 +83,18 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa | |
| 83 | 
             
                        # prefer local dataset, even if hub exists
         | 
| 84 | 
             
                        if Path(d.path).exists():
         | 
| 85 | 
             
                            ds: IterableDataset = load_dataset(
         | 
| 86 | 
            -
                                "json", data_files=d.path, streaming= | 
| 87 | 
             
                            )
         | 
| 88 | 
             
                        elif ds_from_hub:
         | 
| 89 | 
             
                            if d.data_files:
         | 
| 90 | 
            -
                                ds = load_dataset(d.path, streaming= | 
| 91 | 
             
                            else:
         | 
| 92 | 
            -
                                ds = load_dataset(d.path, streaming=True)
         | 
| 93 | 
             
                        else:
         | 
| 94 | 
             
                            fp = hf_hub_download(
         | 
| 95 | 
             
                                repo_id=d.path, repo_type="dataset", filename=d.data_files
         | 
| 96 | 
             
                            )
         | 
| 97 | 
            -
                            ds = load_dataset("json", data_files=fp, streaming= | 
| 98 | 
             
                        if not ds:
         | 
| 99 | 
             
                            raise Exception("unhandled dataset load")
         | 
| 100 | 
             
                        d_type = d.type
         | 
|  | |
| 75 | 
             
                        ds = None
         | 
| 76 | 
             
                        ds_from_hub = False
         | 
| 77 | 
             
                        try:
         | 
| 78 | 
            +
                            load_dataset(d.path, streaming=True, use_auth_token=True)
         | 
| 79 | 
             
                            ds_from_hub = True
         | 
| 80 | 
             
                        except FileNotFoundError:
         | 
| 81 | 
             
                            pass
         | 
|  | |
| 83 | 
             
                        # prefer local dataset, even if hub exists
         | 
| 84 | 
             
                        if Path(d.path).exists():
         | 
| 85 | 
             
                            ds: IterableDataset = load_dataset(
         | 
| 86 | 
            +
                                "json", data_files=d.path, streaming=False, split=None
         | 
| 87 | 
             
                            )
         | 
| 88 | 
             
                        elif ds_from_hub:
         | 
| 89 | 
             
                            if d.data_files:
         | 
| 90 | 
            +
                                ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
         | 
| 91 | 
             
                            else:
         | 
| 92 | 
            +
                                ds = load_dataset(d.path, streaming=False, use_auth_token=True)
         | 
| 93 | 
             
                        else:
         | 
| 94 | 
             
                            fp = hf_hub_download(
         | 
| 95 | 
             
                                repo_id=d.path, repo_type="dataset", filename=d.data_files
         | 
| 96 | 
             
                            )
         | 
| 97 | 
            +
                            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         | 
| 98 | 
             
                        if not ds:
         | 
| 99 | 
             
                            raise Exception("unhandled dataset load")
         | 
| 100 | 
             
                        d_type = d.type
         | 
