Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	hide evaluation, update temperature
Browse files- src/synthetic_dataset_generator/app.py +3 -3
- src/synthetic_dataset_generator/apps/base.py +0 -44
- src/synthetic_dataset_generator/apps/sft.py +13 -8
- src/synthetic_dataset_generator/apps/textcat.py +10 -4
- src/synthetic_dataset_generator/pipelines/sft.py +7 -7
- src/synthetic_dataset_generator/pipelines/textcat.py +6 -5
    	
        src/synthetic_dataset_generator/app.py
    CHANGED
    
    | @@ -1,5 +1,5 @@ | |
| 1 | 
             
            from synthetic_dataset_generator._tabbedinterface import TabbedInterface
         | 
| 2 | 
            -
            from synthetic_dataset_generator.apps.eval import app as eval_app
         | 
| 3 | 
             
            from synthetic_dataset_generator.apps.readme import app as readme_app
         | 
| 4 | 
             
            from synthetic_dataset_generator.apps.sft import app as sft_app
         | 
| 5 | 
             
            from synthetic_dataset_generator.apps.textcat import app as textcat_app
         | 
| @@ -23,8 +23,8 @@ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-prima | |
| 23 | 
             
            image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
         | 
| 24 |  | 
| 25 | 
             
            demo = TabbedInterface(
         | 
| 26 | 
            -
                [textcat_app, sft_app,  | 
| 27 | 
            -
                ["Text Classification", "Supervised Fine-Tuning", " | 
| 28 | 
             
                css=css,
         | 
| 29 | 
             
                title=image,
         | 
| 30 | 
             
                head="Synthetic Data Generator",
         | 
|  | |
| 1 | 
             
            from synthetic_dataset_generator._tabbedinterface import TabbedInterface
         | 
| 2 | 
            +
            # from synthetic_dataset_generator.apps.eval import app as eval_app
         | 
| 3 | 
             
            from synthetic_dataset_generator.apps.readme import app as readme_app
         | 
| 4 | 
             
            from synthetic_dataset_generator.apps.sft import app as sft_app
         | 
| 5 | 
             
            from synthetic_dataset_generator.apps.textcat import app as textcat_app
         | 
|  | |
| 23 | 
             
            image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
         | 
| 24 |  | 
| 25 | 
             
            demo = TabbedInterface(
         | 
| 26 | 
            +
                [textcat_app, sft_app, readme_app],
         | 
| 27 | 
            +
                ["Text Classification", "Supervised Fine-Tuning", "README"],
         | 
| 28 | 
             
                css=css,
         | 
| 29 | 
             
                title=image,
         | 
| 30 | 
             
                head="Synthetic Data Generator",
         | 
    	
        src/synthetic_dataset_generator/apps/base.py
    CHANGED
    
    | @@ -67,50 +67,6 @@ def push_pipeline_code_to_hub( | |
| 67 | 
             
                progress(1.0, desc="Pipeline code uploaded")
         | 
| 68 |  | 
| 69 |  | 
| 70 | 
            -
            def push_dataset_to_hub(
         | 
| 71 | 
            -
                dataframe: pd.DataFrame,
         | 
| 72 | 
            -
                private: bool = True,
         | 
| 73 | 
            -
                org_name: str = None,
         | 
| 74 | 
            -
                repo_name: str = None,
         | 
| 75 | 
            -
                oauth_token: Union[OAuthToken, None] = None,
         | 
| 76 | 
            -
                progress=gr.Progress(),
         | 
| 77 | 
            -
                labels: List[str] = None,
         | 
| 78 | 
            -
                num_labels: int = None,
         | 
| 79 | 
            -
                task: str = TEXTCAT_TASK,
         | 
| 80 | 
            -
            ) -> pd.DataFrame:
         | 
| 81 | 
            -
                progress(0.1, desc="Setting up dataset")
         | 
| 82 | 
            -
                repo_id = validate_push_to_hub(org_name, repo_name)
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                if task == TEXTCAT_TASK:
         | 
| 85 | 
            -
                    if num_labels == 1:
         | 
| 86 | 
            -
                        dataframe["label"] = dataframe["label"].replace("", None)
         | 
| 87 | 
            -
                        features = Features(
         | 
| 88 | 
            -
                            {"text": Value("string"), "label": ClassLabel(names=labels)}
         | 
| 89 | 
            -
                        )
         | 
| 90 | 
            -
                    else:
         | 
| 91 | 
            -
                        features = Features(
         | 
| 92 | 
            -
                            {
         | 
| 93 | 
            -
                                "text": Value("string"),
         | 
| 94 | 
            -
                                "labels": Sequence(feature=ClassLabel(names=labels)),
         | 
| 95 | 
            -
                            }
         | 
| 96 | 
            -
                        )
         | 
| 97 | 
            -
                    distiset = Distiset(
         | 
| 98 | 
            -
                        {"default": Dataset.from_pandas(dataframe, features=features)}
         | 
| 99 | 
            -
                    )
         | 
| 100 | 
            -
                else:
         | 
| 101 | 
            -
                    distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
         | 
| 102 | 
            -
                progress(0.2, desc="Pushing dataset to hub")
         | 
| 103 | 
            -
                distiset.push_to_hub(
         | 
| 104 | 
            -
                    repo_id=repo_id,
         | 
| 105 | 
            -
                    private=private,
         | 
| 106 | 
            -
                    include_script=False,
         | 
| 107 | 
            -
                    token=oauth_token.token,
         | 
| 108 | 
            -
                    create_pr=False,
         | 
| 109 | 
            -
                )
         | 
| 110 | 
            -
                progress(1.0, desc="Dataset pushed to hub")
         | 
| 111 | 
            -
                return dataframe
         | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
             
            def validate_push_to_hub(org_name, repo_name):
         | 
| 115 | 
             
                repo_id = (
         | 
| 116 | 
             
                    f"{org_name}/{repo_name}"
         | 
|  | |
| 67 | 
             
                progress(1.0, desc="Pipeline code uploaded")
         | 
| 68 |  | 
| 69 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 70 | 
             
            def validate_push_to_hub(org_name, repo_name):
         | 
| 71 | 
             
                repo_id = (
         | 
| 72 | 
             
                    f"{org_name}/{repo_name}"
         | 
    	
        src/synthetic_dataset_generator/apps/sft.py
    CHANGED
    
    | @@ -15,7 +15,7 @@ from synthetic_dataset_generator.apps.base import ( | |
| 15 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 16 | 
             
                validate_push_to_hub,
         | 
| 17 | 
             
            )
         | 
| 18 | 
            -
            from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE
         | 
| 19 | 
             
            from synthetic_dataset_generator.pipelines.embeddings import (
         | 
| 20 | 
             
                get_embeddings,
         | 
| 21 | 
             
                get_sentence_embedding_dimensions,
         | 
| @@ -49,10 +49,10 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame: | |
| 49 | 
             
                return dataframe
         | 
| 50 |  | 
| 51 |  | 
| 52 | 
            -
            def generate_system_prompt(dataset_description,  | 
| 53 | 
             
                progress(0.0, desc="Generating system prompt")
         | 
| 54 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
| 55 | 
            -
                generate_description = get_prompt_generator( | 
| 56 | 
             
                progress(0.7, desc="Generating system prompt")
         | 
| 57 | 
             
                result = next(
         | 
| 58 | 
             
                    generate_description.process(
         | 
| @@ -92,12 +92,13 @@ def generate_dataset( | |
| 92 | 
             
                system_prompt: str,
         | 
| 93 | 
             
                num_turns: int = 1,
         | 
| 94 | 
             
                num_rows: int = 10,
         | 
|  | |
| 95 | 
             
                is_sample: bool = False,
         | 
| 96 | 
             
                progress=gr.Progress(),
         | 
| 97 | 
             
            ) -> pd.DataFrame:
         | 
| 98 | 
             
                progress(0.0, desc="(1/2) Generating instructions")
         | 
| 99 | 
            -
                magpie_generator = get_magpie_generator(system_prompt, num_turns, is_sample)
         | 
| 100 | 
            -
                response_generator = get_response_generator(system_prompt, num_turns, is_sample)
         | 
| 101 | 
             
                total_steps: int = num_rows * 2
         | 
| 102 | 
             
                batch_size = DEFAULT_BATCH_SIZE
         | 
| 103 |  | 
| @@ -216,6 +217,7 @@ def push_dataset( | |
| 216 | 
             
                num_turns: int = 1,
         | 
| 217 | 
             
                num_rows: int = 10,
         | 
| 218 | 
             
                private: bool = False,
         | 
|  | |
| 219 | 
             
                oauth_token: Union[gr.OAuthToken, None] = None,
         | 
| 220 | 
             
                progress=gr.Progress(),
         | 
| 221 | 
             
            ) -> pd.DataFrame:
         | 
| @@ -223,6 +225,7 @@ def push_dataset( | |
| 223 | 
             
                    system_prompt=system_prompt,
         | 
| 224 | 
             
                    num_turns=num_turns,
         | 
| 225 | 
             
                    num_rows=num_rows,
         | 
|  | |
| 226 | 
             
                )
         | 
| 227 | 
             
                push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
         | 
| 228 | 
             
                try:
         | 
| @@ -439,7 +442,7 @@ with gr.Blocks() as app: | |
| 439 | 
             
                                    label="Temperature",
         | 
| 440 | 
             
                                    minimum=0.1,
         | 
| 441 | 
             
                                    maximum=1,
         | 
| 442 | 
            -
                                    value=0. | 
| 443 | 
             
                                    step=0.1,
         | 
| 444 | 
             
                                    interactive=True,
         | 
| 445 | 
             
                                )
         | 
| @@ -463,6 +466,7 @@ with gr.Blocks() as app: | |
| 463 | 
             
                                        system_prompt=system_prompt.value,
         | 
| 464 | 
             
                                        num_turns=num_turns.value,
         | 
| 465 | 
             
                                        num_rows=num_rows.value,
         | 
|  | |
| 466 | 
             
                                    )
         | 
| 467 | 
             
                                    pipeline_code = gr.Code(
         | 
| 468 | 
             
                                        value=code,
         | 
| @@ -472,7 +476,7 @@ with gr.Blocks() as app: | |
| 472 |  | 
| 473 | 
             
                    load_btn.click(
         | 
| 474 | 
             
                        fn=generate_system_prompt,
         | 
| 475 | 
            -
                        inputs=[dataset_description | 
| 476 | 
             
                        outputs=[system_prompt],
         | 
| 477 | 
             
                        show_progress=True,
         | 
| 478 | 
             
                    ).then(
         | 
| @@ -516,6 +520,7 @@ with gr.Blocks() as app: | |
| 516 | 
             
                            num_turns,
         | 
| 517 | 
             
                            num_rows,
         | 
| 518 | 
             
                            private,
         | 
|  | |
| 519 | 
             
                        ],
         | 
| 520 | 
             
                        outputs=[success_message],
         | 
| 521 | 
             
                        show_progress=True,
         | 
| @@ -525,7 +530,7 @@ with gr.Blocks() as app: | |
| 525 | 
             
                        outputs=[success_message],
         | 
| 526 | 
             
                    ).success(
         | 
| 527 | 
             
                        fn=generate_pipeline_code,
         | 
| 528 | 
            -
                        inputs=[system_prompt, num_turns, num_rows],
         | 
| 529 | 
             
                        outputs=[pipeline_code],
         | 
| 530 | 
             
                    ).success(
         | 
| 531 | 
             
                        fn=show_pipeline_code_visibility,
         | 
|  | |
| 15 | 
             
                validate_argilla_user_workspace_dataset,
         | 
| 16 | 
             
                validate_push_to_hub,
         | 
| 17 | 
             
            )
         | 
| 18 | 
            +
            from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE, MODEL
         | 
| 19 | 
             
            from synthetic_dataset_generator.pipelines.embeddings import (
         | 
| 20 | 
             
                get_embeddings,
         | 
| 21 | 
             
                get_sentence_embedding_dimensions,
         | 
|  | |
| 49 | 
             
                return dataframe
         | 
| 50 |  | 
| 51 |  | 
| 52 | 
            +
            def generate_system_prompt(dataset_description, progress=gr.Progress()):
         | 
| 53 | 
             
                progress(0.0, desc="Generating system prompt")
         | 
| 54 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
| 55 | 
            +
                generate_description = get_prompt_generator()
         | 
| 56 | 
             
                progress(0.7, desc="Generating system prompt")
         | 
| 57 | 
             
                result = next(
         | 
| 58 | 
             
                    generate_description.process(
         | 
|  | |
| 92 | 
             
                system_prompt: str,
         | 
| 93 | 
             
                num_turns: int = 1,
         | 
| 94 | 
             
                num_rows: int = 10,
         | 
| 95 | 
            +
                temperature: float = 0.9,
         | 
| 96 | 
             
                is_sample: bool = False,
         | 
| 97 | 
             
                progress=gr.Progress(),
         | 
| 98 | 
             
            ) -> pd.DataFrame:
         | 
| 99 | 
             
                progress(0.0, desc="(1/2) Generating instructions")
         | 
| 100 | 
            +
                magpie_generator = get_magpie_generator(system_prompt, num_turns, temperature, is_sample)
         | 
| 101 | 
            +
                response_generator = get_response_generator(system_prompt, num_turns, temperature, is_sample)
         | 
| 102 | 
             
                total_steps: int = num_rows * 2
         | 
| 103 | 
             
                batch_size = DEFAULT_BATCH_SIZE
         | 
| 104 |  | 
|  | |
| 217 | 
             
                num_turns: int = 1,
         | 
| 218 | 
             
                num_rows: int = 10,
         | 
| 219 | 
             
                private: bool = False,
         | 
| 220 | 
            +
                temperature: float = 0.9,
         | 
| 221 | 
             
                oauth_token: Union[gr.OAuthToken, None] = None,
         | 
| 222 | 
             
                progress=gr.Progress(),
         | 
| 223 | 
             
            ) -> pd.DataFrame:
         | 
|  | |
| 225 | 
             
                    system_prompt=system_prompt,
         | 
| 226 | 
             
                    num_turns=num_turns,
         | 
| 227 | 
             
                    num_rows=num_rows,
         | 
| 228 | 
            +
                    temperature=temperature,
         | 
| 229 | 
             
                )
         | 
| 230 | 
             
                push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
         | 
| 231 | 
             
                try:
         | 
|  | |
| 442 | 
             
                                    label="Temperature",
         | 
| 443 | 
             
                                    minimum=0.1,
         | 
| 444 | 
             
                                    maximum=1,
         | 
| 445 | 
            +
                                    value=0.9,
         | 
| 446 | 
             
                                    step=0.1,
         | 
| 447 | 
             
                                    interactive=True,
         | 
| 448 | 
             
                                )
         | 
|  | |
| 466 | 
             
                                        system_prompt=system_prompt.value,
         | 
| 467 | 
             
                                        num_turns=num_turns.value,
         | 
| 468 | 
             
                                        num_rows=num_rows.value,
         | 
| 469 | 
            +
                                        temperature=temperature.value,
         | 
| 470 | 
             
                                    )
         | 
| 471 | 
             
                                    pipeline_code = gr.Code(
         | 
| 472 | 
             
                                        value=code,
         | 
|  | |
| 476 |  | 
| 477 | 
             
                    load_btn.click(
         | 
| 478 | 
             
                        fn=generate_system_prompt,
         | 
| 479 | 
            +
                        inputs=[dataset_description],
         | 
| 480 | 
             
                        outputs=[system_prompt],
         | 
| 481 | 
             
                        show_progress=True,
         | 
| 482 | 
             
                    ).then(
         | 
|  | |
| 520 | 
             
                            num_turns,
         | 
| 521 | 
             
                            num_rows,
         | 
| 522 | 
             
                            private,
         | 
| 523 | 
            +
                            temperature
         | 
| 524 | 
             
                        ],
         | 
| 525 | 
             
                        outputs=[success_message],
         | 
| 526 | 
             
                        show_progress=True,
         | 
|  | |
| 530 | 
             
                        outputs=[success_message],
         | 
| 531 | 
             
                    ).success(
         | 
| 532 | 
             
                        fn=generate_pipeline_code,
         | 
| 533 | 
            +
                        inputs=[system_prompt, num_turns, num_rows, temperature],
         | 
| 534 | 
             
                        outputs=[pipeline_code],
         | 
| 535 | 
             
                    ).success(
         | 
| 536 | 
             
                        fn=show_pipeline_code_visibility,
         | 
    	
        src/synthetic_dataset_generator/apps/textcat.py
    CHANGED
    
    | @@ -45,10 +45,10 @@ def _get_dataframe(): | |
| 45 | 
             
                )
         | 
| 46 |  | 
| 47 |  | 
| 48 | 
            -
            def generate_system_prompt(dataset_description,  | 
| 49 | 
             
                progress(0.0, desc="Generating text classification task")
         | 
| 50 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
| 51 | 
            -
                generate_description = get_prompt_generator( | 
| 52 | 
             
                progress(0.7, desc="Generating text classification task")
         | 
| 53 | 
             
                result = next(
         | 
| 54 | 
             
                    generate_description.process(
         | 
| @@ -89,13 +89,14 @@ def generate_dataset( | |
| 89 | 
             
                labels: List[str] = None,
         | 
| 90 | 
             
                num_labels: int = 1,
         | 
| 91 | 
             
                num_rows: int = 10,
         | 
|  | |
| 92 | 
             
                is_sample: bool = False,
         | 
| 93 | 
             
                progress=gr.Progress(),
         | 
| 94 | 
             
            ) -> pd.DataFrame:
         | 
| 95 | 
             
                progress(0.0, desc="(1/2) Generating text classification data")
         | 
| 96 | 
             
                labels = get_preprocess_labels(labels)
         | 
| 97 | 
             
                textcat_generator = get_textcat_generator(
         | 
| 98 | 
            -
                    difficulty=difficulty, clarity=clarity, is_sample=is_sample
         | 
| 99 | 
             
                )
         | 
| 100 | 
             
                labeller_generator = get_labeller_generator(
         | 
| 101 | 
             
                    system_prompt=f"{system_prompt} {', '.join(labels)}",
         | 
| @@ -204,6 +205,7 @@ def push_dataset( | |
| 204 | 
             
                num_rows: int = 10,
         | 
| 205 | 
             
                labels: List[str] = None,
         | 
| 206 | 
             
                private: bool = False,
         | 
|  | |
| 207 | 
             
                oauth_token: Union[gr.OAuthToken, None] = None,
         | 
| 208 | 
             
                progress=gr.Progress(),
         | 
| 209 | 
             
            ) -> pd.DataFrame:
         | 
| @@ -214,6 +216,7 @@ def push_dataset( | |
| 214 | 
             
                    num_labels=num_labels,
         | 
| 215 | 
             
                    labels=labels,
         | 
| 216 | 
             
                    num_rows=num_rows,
         | 
|  | |
| 217 | 
             
                )
         | 
| 218 | 
             
                push_dataset_to_hub(
         | 
| 219 | 
             
                    dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
         | 
| @@ -471,6 +474,7 @@ with gr.Blocks() as app: | |
| 471 | 
             
                                    labels=labels.value,
         | 
| 472 | 
             
                                    num_labels=num_labels.value,
         | 
| 473 | 
             
                                    num_rows=num_rows.value,
         | 
|  | |
| 474 | 
             
                                )
         | 
| 475 | 
             
                                pipeline_code = gr.Code(
         | 
| 476 | 
             
                                    value=code,
         | 
| @@ -480,7 +484,7 @@ with gr.Blocks() as app: | |
| 480 |  | 
| 481 | 
             
                load_btn.click(
         | 
| 482 | 
             
                    fn=generate_system_prompt,
         | 
| 483 | 
            -
                    inputs=[dataset_description | 
| 484 | 
             
                    outputs=[system_prompt, labels],
         | 
| 485 | 
             
                    show_progress=True,
         | 
| 486 | 
             
                ).then(
         | 
| @@ -537,6 +541,7 @@ with gr.Blocks() as app: | |
| 537 | 
             
                        num_rows,
         | 
| 538 | 
             
                        labels,
         | 
| 539 | 
             
                        private,
         | 
|  | |
| 540 | 
             
                    ],
         | 
| 541 | 
             
                    outputs=[success_message],
         | 
| 542 | 
             
                    show_progress=True,
         | 
| @@ -553,6 +558,7 @@ with gr.Blocks() as app: | |
| 553 | 
             
                        labels,
         | 
| 554 | 
             
                        num_labels,
         | 
| 555 | 
             
                        num_rows,
         | 
|  | |
| 556 | 
             
                    ],
         | 
| 557 | 
             
                    outputs=[pipeline_code],
         | 
| 558 | 
             
                ).success(
         | 
|  | |
| 45 | 
             
                )
         | 
| 46 |  | 
| 47 |  | 
| 48 | 
            +
            def generate_system_prompt(dataset_description, progress=gr.Progress()):
         | 
| 49 | 
             
                progress(0.0, desc="Generating text classification task")
         | 
| 50 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
| 51 | 
            +
                generate_description = get_prompt_generator()
         | 
| 52 | 
             
                progress(0.7, desc="Generating text classification task")
         | 
| 53 | 
             
                result = next(
         | 
| 54 | 
             
                    generate_description.process(
         | 
|  | |
| 89 | 
             
                labels: List[str] = None,
         | 
| 90 | 
             
                num_labels: int = 1,
         | 
| 91 | 
             
                num_rows: int = 10,
         | 
| 92 | 
            +
                temperature: float = 0.9,
         | 
| 93 | 
             
                is_sample: bool = False,
         | 
| 94 | 
             
                progress=gr.Progress(),
         | 
| 95 | 
             
            ) -> pd.DataFrame:
         | 
| 96 | 
             
                progress(0.0, desc="(1/2) Generating text classification data")
         | 
| 97 | 
             
                labels = get_preprocess_labels(labels)
         | 
| 98 | 
             
                textcat_generator = get_textcat_generator(
         | 
| 99 | 
            +
                    difficulty=difficulty, clarity=clarity, temperature=temperature, is_sample=is_sample
         | 
| 100 | 
             
                )
         | 
| 101 | 
             
                labeller_generator = get_labeller_generator(
         | 
| 102 | 
             
                    system_prompt=f"{system_prompt} {', '.join(labels)}",
         | 
|  | |
| 205 | 
             
                num_rows: int = 10,
         | 
| 206 | 
             
                labels: List[str] = None,
         | 
| 207 | 
             
                private: bool = False,
         | 
| 208 | 
            +
                temperature: float = 0.8,
         | 
| 209 | 
             
                oauth_token: Union[gr.OAuthToken, None] = None,
         | 
| 210 | 
             
                progress=gr.Progress(),
         | 
| 211 | 
             
            ) -> pd.DataFrame:
         | 
|  | |
| 216 | 
             
                    num_labels=num_labels,
         | 
| 217 | 
             
                    labels=labels,
         | 
| 218 | 
             
                    num_rows=num_rows,
         | 
| 219 | 
            +
                    temperature=temperature,
         | 
| 220 | 
             
                )
         | 
| 221 | 
             
                push_dataset_to_hub(
         | 
| 222 | 
             
                    dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
         | 
|  | |
| 474 | 
             
                                    labels=labels.value,
         | 
| 475 | 
             
                                    num_labels=num_labels.value,
         | 
| 476 | 
             
                                    num_rows=num_rows.value,
         | 
| 477 | 
            +
                                    temperature=temperature.value,
         | 
| 478 | 
             
                                )
         | 
| 479 | 
             
                                pipeline_code = gr.Code(
         | 
| 480 | 
             
                                    value=code,
         | 
|  | |
| 484 |  | 
| 485 | 
             
                load_btn.click(
         | 
| 486 | 
             
                    fn=generate_system_prompt,
         | 
| 487 | 
            +
                    inputs=[dataset_description],
         | 
| 488 | 
             
                    outputs=[system_prompt, labels],
         | 
| 489 | 
             
                    show_progress=True,
         | 
| 490 | 
             
                ).then(
         | 
|  | |
| 541 | 
             
                        num_rows,
         | 
| 542 | 
             
                        labels,
         | 
| 543 | 
             
                        private,
         | 
| 544 | 
            +
                        temperature
         | 
| 545 | 
             
                    ],
         | 
| 546 | 
             
                    outputs=[success_message],
         | 
| 547 | 
             
                    show_progress=True,
         | 
|  | |
| 558 | 
             
                        labels,
         | 
| 559 | 
             
                        num_labels,
         | 
| 560 | 
             
                        num_rows,
         | 
| 561 | 
            +
                        temperature
         | 
| 562 | 
             
                    ],
         | 
| 563 | 
             
                    outputs=[pipeline_code],
         | 
| 564 | 
             
                ).success(
         | 
    	
        src/synthetic_dataset_generator/pipelines/sft.py
    CHANGED
    
    | @@ -140,7 +140,7 @@ def _get_output_mappings(num_turns): | |
| 140 | 
             
                    return {"conversation": "messages"}
         | 
| 141 |  | 
| 142 |  | 
| 143 | 
            -
            def get_prompt_generator( | 
| 144 | 
             
                prompt_generator = TextGeneration(
         | 
| 145 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 146 | 
             
                        api_key=_get_next_api_key(),
         | 
| @@ -148,7 +148,7 @@ def get_prompt_generator(temperature): | |
| 148 | 
             
                        tokenizer_id=MODEL,
         | 
| 149 | 
             
                        base_url=BASE_URL,
         | 
| 150 | 
             
                        generation_kwargs={
         | 
| 151 | 
            -
                            "temperature":  | 
| 152 | 
             
                            "max_new_tokens": 2048,
         | 
| 153 | 
             
                            "do_sample": True,
         | 
| 154 | 
             
                        },
         | 
| @@ -160,7 +160,7 @@ def get_prompt_generator(temperature): | |
| 160 | 
             
                return prompt_generator
         | 
| 161 |  | 
| 162 |  | 
| 163 | 
            -
            def get_magpie_generator(system_prompt, num_turns, is_sample):
         | 
| 164 | 
             
                input_mappings = _get_output_mappings(num_turns)
         | 
| 165 | 
             
                output_mappings = input_mappings.copy()
         | 
| 166 | 
             
                if num_turns == 1:
         | 
| @@ -172,7 +172,7 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): | |
| 172 | 
             
                            api_key=_get_next_api_key(),
         | 
| 173 | 
             
                            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 174 | 
             
                            generation_kwargs={
         | 
| 175 | 
            -
                                "temperature":  | 
| 176 | 
             
                                "do_sample": True,
         | 
| 177 | 
             
                                "max_new_tokens": 256 if is_sample else 512,
         | 
| 178 | 
             
                                "stop_sequences": _STOP_SEQUENCES,
         | 
| @@ -192,7 +192,7 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): | |
| 192 | 
             
                            api_key=_get_next_api_key(),
         | 
| 193 | 
             
                            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 194 | 
             
                            generation_kwargs={
         | 
| 195 | 
            -
                                "temperature":  | 
| 196 | 
             
                                "do_sample": True,
         | 
| 197 | 
             
                                "max_new_tokens": 256 if is_sample else 1024,
         | 
| 198 | 
             
                                "stop_sequences": _STOP_SEQUENCES,
         | 
| @@ -243,7 +243,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): | |
| 243 | 
             
                return response_generator
         | 
| 244 |  | 
| 245 |  | 
| 246 | 
            -
            def generate_pipeline_code(system_prompt, num_turns, num_rows):
         | 
| 247 | 
             
                input_mappings = _get_output_mappings(num_turns)
         | 
| 248 | 
             
                code = f"""
         | 
| 249 | 
             
            # Requirements: `pip install distilabel[hf-inference-endpoints]`
         | 
| @@ -266,7 +266,7 @@ with Pipeline(name="sft") as pipeline: | |
| 266 | 
             
                        base_url=BASE_URL,
         | 
| 267 | 
             
                        magpie_pre_query_template="llama3",
         | 
| 268 | 
             
                        generation_kwargs={{
         | 
| 269 | 
            -
                            "temperature":  | 
| 270 | 
             
                            "do_sample": True,
         | 
| 271 | 
             
                            "max_new_tokens": 2048,
         | 
| 272 | 
             
                            "stop_sequences": {_STOP_SEQUENCES}
         | 
|  | |
| 140 | 
             
                    return {"conversation": "messages"}
         | 
| 141 |  | 
| 142 |  | 
| 143 | 
            +
            def get_prompt_generator():
         | 
| 144 | 
             
                prompt_generator = TextGeneration(
         | 
| 145 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 146 | 
             
                        api_key=_get_next_api_key(),
         | 
|  | |
| 148 | 
             
                        tokenizer_id=MODEL,
         | 
| 149 | 
             
                        base_url=BASE_URL,
         | 
| 150 | 
             
                        generation_kwargs={
         | 
| 151 | 
            +
                            "temperature": 0.8,
         | 
| 152 | 
             
                            "max_new_tokens": 2048,
         | 
| 153 | 
             
                            "do_sample": True,
         | 
| 154 | 
             
                        },
         | 
|  | |
| 160 | 
             
                return prompt_generator
         | 
| 161 |  | 
| 162 |  | 
| 163 | 
            +
            def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
         | 
| 164 | 
             
                input_mappings = _get_output_mappings(num_turns)
         | 
| 165 | 
             
                output_mappings = input_mappings.copy()
         | 
| 166 | 
             
                if num_turns == 1:
         | 
|  | |
| 172 | 
             
                            api_key=_get_next_api_key(),
         | 
| 173 | 
             
                            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 174 | 
             
                            generation_kwargs={
         | 
| 175 | 
            +
                                "temperature": temperature,
         | 
| 176 | 
             
                                "do_sample": True,
         | 
| 177 | 
             
                                "max_new_tokens": 256 if is_sample else 512,
         | 
| 178 | 
             
                                "stop_sequences": _STOP_SEQUENCES,
         | 
|  | |
| 192 | 
             
                            api_key=_get_next_api_key(),
         | 
| 193 | 
             
                            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
         | 
| 194 | 
             
                            generation_kwargs={
         | 
| 195 | 
            +
                                "temperature": temperature,
         | 
| 196 | 
             
                                "do_sample": True,
         | 
| 197 | 
             
                                "max_new_tokens": 256 if is_sample else 1024,
         | 
| 198 | 
             
                                "stop_sequences": _STOP_SEQUENCES,
         | 
|  | |
| 243 | 
             
                return response_generator
         | 
| 244 |  | 
| 245 |  | 
| 246 | 
            +
            def generate_pipeline_code(system_prompt, num_turns, num_rows, temperature):
         | 
| 247 | 
             
                input_mappings = _get_output_mappings(num_turns)
         | 
| 248 | 
             
                code = f"""
         | 
| 249 | 
             
            # Requirements: `pip install distilabel[hf-inference-endpoints]`
         | 
|  | |
| 266 | 
             
                        base_url=BASE_URL,
         | 
| 267 | 
             
                        magpie_pre_query_template="llama3",
         | 
| 268 | 
             
                        generation_kwargs={{
         | 
| 269 | 
            +
                            "temperature": {temperature},
         | 
| 270 | 
             
                            "do_sample": True,
         | 
| 271 | 
             
                            "max_new_tokens": 2048,
         | 
| 272 | 
             
                            "stop_sequences": {_STOP_SEQUENCES}
         | 
    	
        src/synthetic_dataset_generator/pipelines/textcat.py
    CHANGED
    
    | @@ -66,7 +66,7 @@ class TextClassificationTask(BaseModel): | |
| 66 | 
             
                )
         | 
| 67 |  | 
| 68 |  | 
| 69 | 
            -
            def get_prompt_generator( | 
| 70 | 
             
                prompt_generator = TextGeneration(
         | 
| 71 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 72 | 
             
                        api_key=_get_next_api_key(),
         | 
| @@ -74,7 +74,7 @@ def get_prompt_generator(temperature): | |
| 74 | 
             
                        base_url=BASE_URL,
         | 
| 75 | 
             
                        structured_output={"format": "json", "schema": TextClassificationTask},
         | 
| 76 | 
             
                        generation_kwargs={
         | 
| 77 | 
            -
                            "temperature":  | 
| 78 | 
             
                            "max_new_tokens": 2048,
         | 
| 79 | 
             
                            "do_sample": True,
         | 
| 80 | 
             
                        },
         | 
| @@ -86,14 +86,14 @@ def get_prompt_generator(temperature): | |
| 86 | 
             
                return prompt_generator
         | 
| 87 |  | 
| 88 |  | 
| 89 | 
            -
            def get_textcat_generator(difficulty, clarity, is_sample):
         | 
| 90 | 
             
                textcat_generator = GenerateTextClassificationData(
         | 
| 91 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 92 | 
             
                        model_id=MODEL,
         | 
| 93 | 
             
                        base_url=BASE_URL,
         | 
| 94 | 
             
                        api_key=_get_next_api_key(),
         | 
| 95 | 
             
                        generation_kwargs={
         | 
| 96 | 
            -
                            "temperature":  | 
| 97 | 
             
                            "max_new_tokens": 256 if is_sample else 2048,
         | 
| 98 | 
             
                            "do_sample": True,
         | 
| 99 | 
             
                            "top_k": 50,
         | 
| @@ -135,6 +135,7 @@ def generate_pipeline_code( | |
| 135 | 
             
                labels: List[str] = None,
         | 
| 136 | 
             
                num_labels: int = 1,
         | 
| 137 | 
             
                num_rows: int = 10,
         | 
|  | |
| 138 | 
             
            ) -> str:
         | 
| 139 | 
             
                labels = get_preprocess_labels(labels)
         | 
| 140 | 
             
                base_code = f"""
         | 
| @@ -163,7 +164,7 @@ with Pipeline(name="textcat") as pipeline: | |
| 163 | 
             
                        base_url=BASE_URL,
         | 
| 164 | 
             
                        api_key=os.environ["API_KEY"],
         | 
| 165 | 
             
                        generation_kwargs={{
         | 
| 166 | 
            -
                            "temperature":  | 
| 167 | 
             
                            "max_new_tokens": 2048,
         | 
| 168 | 
             
                            "do_sample": True,
         | 
| 169 | 
             
                            "top_k": 50,
         | 
|  | |
| 66 | 
             
                )
         | 
| 67 |  | 
| 68 |  | 
| 69 | 
            +
            def get_prompt_generator():
         | 
| 70 | 
             
                prompt_generator = TextGeneration(
         | 
| 71 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 72 | 
             
                        api_key=_get_next_api_key(),
         | 
|  | |
| 74 | 
             
                        base_url=BASE_URL,
         | 
| 75 | 
             
                        structured_output={"format": "json", "schema": TextClassificationTask},
         | 
| 76 | 
             
                        generation_kwargs={
         | 
| 77 | 
            +
                            "temperature": 0.8,
         | 
| 78 | 
             
                            "max_new_tokens": 2048,
         | 
| 79 | 
             
                            "do_sample": True,
         | 
| 80 | 
             
                        },
         | 
|  | |
| 86 | 
             
                return prompt_generator
         | 
| 87 |  | 
| 88 |  | 
| 89 | 
            +
            def get_textcat_generator(difficulty, clarity, temperature, is_sample):
         | 
| 90 | 
             
                textcat_generator = GenerateTextClassificationData(
         | 
| 91 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 92 | 
             
                        model_id=MODEL,
         | 
| 93 | 
             
                        base_url=BASE_URL,
         | 
| 94 | 
             
                        api_key=_get_next_api_key(),
         | 
| 95 | 
             
                        generation_kwargs={
         | 
| 96 | 
            +
                            "temperature": temperature,
         | 
| 97 | 
             
                            "max_new_tokens": 256 if is_sample else 2048,
         | 
| 98 | 
             
                            "do_sample": True,
         | 
| 99 | 
             
                            "top_k": 50,
         | 
|  | |
| 135 | 
             
                labels: List[str] = None,
         | 
| 136 | 
             
                num_labels: int = 1,
         | 
| 137 | 
             
                num_rows: int = 10,
         | 
| 138 | 
            +
                temperature: float = 0.9,
         | 
| 139 | 
             
            ) -> str:
         | 
| 140 | 
             
                labels = get_preprocess_labels(labels)
         | 
| 141 | 
             
                base_code = f"""
         | 
|  | |
| 164 | 
             
                        base_url=BASE_URL,
         | 
| 165 | 
             
                        api_key=os.environ["API_KEY"],
         | 
| 166 | 
             
                        generation_kwargs={{
         | 
| 167 | 
            +
                            "temperature": {temperature},
         | 
| 168 | 
             
                            "max_new_tokens": 2048,
         | 
| 169 | 
             
                            "do_sample": True,
         | 
| 170 | 
             
                            "top_k": 50,
         | 
 
			

