davidberenstein1957 HF staff committed on
Commit
ec33fc2
2 Parent(s): da59bd9 9feda8c

Merge pull request #4 from argilla-io/feat/choose-models

.python-version DELETED
@@ -1 +0,0 @@
1
- synthetic-data-generator
 
 
README.md CHANGED
@@ -20,25 +20,13 @@ hf_oauth_scopes:
20
 
21
  <h1 align="center">
22
  <br>
23
- Synthetic Data Generator
24
  <br>
25
  </h1>
26
  <h3 align="center">Build datasets using natural language</h2>
27
 
28
  ![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png)
29
 
30
- <p align="center">
31
- <a href="https://pypi.org/project/synthetic-dataset-generator/">
32
- <img alt="CI" src="https://img.shields.io/pypi/v/synthetic-dataset-generator.svg?style=flat-round&logo=pypi&logoColor=white">
33
- </a>
34
- <a href="https://pepy.tech/project/synthetic-dataset-generator">
35
- <img alt="CI" src="https://static.pepy.tech/personalized-badge/synthetic-dataset-generator?period=month&units=international_system&left_color=grey&right_color=blue&left_text=pypi%20downloads/month">
36
- </a>
37
- <a href="https://huggingface.co/spaces/argilla/synthetic-data-generator?duplicate=true">
38
- <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg"/>
39
- </a>
40
- </p>
41
-
42
  <p align="center">
43
  <a href="https://twitter.com/argilla_io">
44
  <img src="https://img.shields.io/badge/twitter-black?logo=x"/>
@@ -78,21 +66,29 @@ You can simply install the package with:
78
  pip install synthetic-dataset-generator
79
  ```
80
 
 
 
 
 
 
 
 
 
81
  ### Environment Variables
82
 
83
- - `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained).
 
 
 
 
 
 
84
 
85
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
86
 
87
  - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
88
  - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
89
 
90
- ## Quickstart
91
-
92
- ```bash
93
- python app.py
94
- ```
95
-
96
  ### Argilla integration
97
 
98
  Argilla is a open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
@@ -104,3 +100,19 @@ Argilla is a open source tool for data curation. It allows you to annotate and r
104
  Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
105
 
106
  Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.
20
 
21
  <h1 align="center">
22
  <br>
23
+ 🧬 Synthetic Data Generator
24
  <br>
25
  </h1>
26
  <h3 align="center">Build datasets using natural language</h2>
27
 
28
  ![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png)
29
 
30
  <p align="center">
31
  <a href="https://twitter.com/argilla_io">
32
  <img src="https://img.shields.io/badge/twitter-black?logo=x"/>
 
66
  pip install synthetic-dataset-generator
67
  ```
68
 
69
+ ### Quickstart
70
+
71
+ ```python
72
+ from synthetic_dataset_generator.app import demo
73
+
74
+ demo.launch()
75
+ ```
76
+
77
  ### Environment Variables
78
 
79
+ - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints.
80
+
81
+ Optionally, you can set the following environment variables to customize the generation process.
82
+
83
+ - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
84
+ - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
85
+ - `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`.
86
 
87
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
88
 
89
  - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
90
  - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
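
Reviewer note: the optional `BASE_URL`, `MODEL`, and `API_KEY` variables described in this README hunk could be collected roughly as follows. This is a minimal sketch, not the code from this commit. `load_generation_settings` is a hypothetical helper, and the fallback from `API_KEY` to `HF_TOKEN` is an assumption made for illustration, not behavior the README documents.

```python
import os
from typing import Dict, Mapping, Optional


def load_generation_settings(
    env: Mapping[str, str] = os.environ,
) -> Dict[str, Optional[str]]:
    """Collect the optional generation settings named in the README.

    Hypothetical helper: the project may read these variables differently.
    """

    def _get(name: str) -> Optional[str]:
        # Treat unset and empty variables the same way.
        value = env.get(name)
        return value if value else None

    return {
        "base_url": _get("BASE_URL"),
        "model": _get("MODEL"),
        # Assumption: reuse HF_TOKEN when no dedicated API_KEY is set.
        "api_key": _get("API_KEY") or _get("HF_TOKEN"),
    }


if __name__ == "__main__":
    settings = load_generation_settings()
    # Print only which settings are present, never the secret values.
    print({name: value is not None for name, value in settings.items()})
```

Passing the environment as a parameter keeps the helper easy to exercise with a plain dictionary instead of real environment variables.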
91
 
 
 
 
 
 
 
92
  ### Argilla integration
93
 
94
  Argilla is a open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
 
100
  Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
101
 
102
  Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.
103
+
104
+ ## Development
105
+
106
+ Install the dependencies:
107
+
108
+ ```bash
109
+ python -m venv .venv
110
+ source .venv/bin/activate
111
+ pip install -e .
112
+ ```
113
+
114
+ Run the app:
115
+
116
+ ```bash
117
+ python app.py
118
+ ```
app.py CHANGED
@@ -1,38 +1,4 @@
1
- from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
2
- from src.distilabel_dataset_generator.apps.eval import app as eval_app
3
- from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
- from src.distilabel_dataset_generator.apps.sft import app as sft_app
5
- from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
6
-
7
- theme = "argilla/argilla-theme"
8
-
9
- css = """
10
- button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
11
- button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)}
12
- button.hf-login {background: var(--neutral-800); color: white}
13
- button.hf-login:hover {background: var(--neutral-700); color: white}
14
- .tabitem { border: 0; padding-inline: 0}
15
- .main_ui_logged_out{opacity: 0.3; pointer-events: none}
16
- .group_padding{padding: .55em}
17
- .gallery-item {background: var(--background-fill-secondary); text-align: left}
18
- .gallery {white-space: wrap}
19
- #space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
20
- #system_prompt_examples {
21
- color: var(--body-text-color) !important;
22
- background-color: var(--block-background-fill) !important;
23
- }
24
- .container {padding-inline: 0 !important}
25
- """
26
-
27
- demo = TabbedInterface(
28
- [textcat_app, sft_app, eval_app, faq_app],
29
- ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
30
- css=css,
31
- title="Synthetic Data Generator",
32
- head="Synthetic Data Generator",
33
- theme=theme,
34
- )
35
-
36
 
37
  if __name__ == "__main__":
38
  demo.launch()
 
1
+ from distilabel_dataset_generator.app import demo
2
 
3
  if __name__ == "__main__":
4
  demo.launch()
pyproject.toml CHANGED
@@ -5,6 +5,18 @@ description = "Build datasets using natural language"
5
  authors = [
6
  {name = "davidberenstein1957", email = "[email protected]"},
7
  ]
8
  dependencies = [
9
  "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
10
  "gradio[oauth]<5.0.0",
@@ -14,14 +26,10 @@ dependencies = [
14
  "gradio-huggingfacehub-search>=0.0.7",
15
  "argilla>=2.4.0",
16
  ]
17
- requires-python = "<3.13,>=3.10"
18
- readme = "README.md"
19
- license = {text = "apache 2"}
20
 
21
  [build-system]
22
  requires = ["pdm-backend"]
23
  build-backend = "pdm.backend"
24
 
25
-
26
  [tool.pdm]
27
  distribution = true
 
5
  authors = [
6
  {name = "davidberenstein1957", email = "[email protected]"},
7
  ]
8
+ tags = [
9
+ "gradio",
10
+ "synthetic-data",
11
+ "huggingface",
12
+ "argilla",
13
+ "generative-ai",
14
+ "ai",
15
+ ]
16
+ requires-python = "<3.13,>=3.10"
17
+ readme = "README.md"
18
+ license = {text = "Apache 2"}
19
+
20
  dependencies = [
21
  "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
22
  "gradio[oauth]<5.0.0",
 
26
  "gradio-huggingfacehub-search>=0.0.7",
27
  "argilla>=2.4.0",
28
  ]
 
 
 
29
 
30
  [build-system]
31
  requires = ["pdm-backend"]
32
  build-backend = "pdm.backend"
33
 
 
34
  [tool.pdm]
35
  distribution = true
src/distilabel_dataset_generator/__init__.py CHANGED
@@ -1,39 +1,64 @@
1
- import os
2
  import warnings
3
- from pathlib import Path
4
- from typing import Optional, Union
5
 
6
- import argilla as rg
7
  import distilabel
8
  import distilabel.distiset
 
9
  from distilabel.utils.card.dataset_card import (
10
  DistilabelDatasetCard,
11
  size_categories_parser,
12
  )
13
- from huggingface_hub import DatasetCardData, HfApi, upload_file
 
 
 
 
14
 
15
- HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
16
- HF_TOKENS = [token for token in HF_TOKENS if token]
17
 
18
- if len(HF_TOKENS) == 0:
19
- raise ValueError(
20
- "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
21
- )
22
 
23
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
24
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
25
- if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
26
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
27
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
 
28
 
29
- if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
30
- warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
31
- argilla_client = None
32
- else:
33
- argilla_client = rg.Argilla(
34
- api_url=ARGILLA_API_URL,
35
- api_key=ARGILLA_API_KEY,
36
- )
 
 
 
 
 
 
37
 
38
 
39
  class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
@@ -138,3 +163,4 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
138
 
139
 
140
  distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
 
 
 
1
  import warnings
2
+ from typing import Optional
 
3
 
 
4
  import distilabel
5
  import distilabel.distiset
6
+ from distilabel.llms import InferenceEndpointsLLM
7
  from distilabel.utils.card.dataset_card import (
8
  DistilabelDatasetCard,
9
  size_categories_parser,
10
  )
11
+ from huggingface_hub import DatasetCardData, HfApi
12
+ from pydantic import (
13
+ ValidationError,
14
+ model_validator,
15
+ )
16
 
 
 
17
 
18
+ class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
19
+ @model_validator(mode="after") # type: ignore
20
+ def only_one_of_model_id_endpoint_name_or_base_url_provided(
21
+ self,
22
+ ) -> "InferenceEndpointsLLM":
23
+ """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
24
+ provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
25
+ favour of the dynamically calculated one.."""
26
+
27
+ if self.base_url and (self.model_id or self.endpoint_name):
28
+ warnings.warn( # type: ignore
29
+ f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
30
+ " or `endpoint_name` is also provided, the `base_url` will either be ignored"
31
+ " or overwritten with the one generated from either of those args, for serverless"
32
+ " or dedicated inference endpoints, respectively."
33
+ )
34
+
35
+ if self.use_magpie_template and self.tokenizer_id is None:
36
+ raise ValueError(
37
+ "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
38
+ " set a `tokenizer_id` and try again."
39
+ )
40
 
41
+ if (
42
+ self.model_id
43
+ and self.tokenizer_id is None
44
+ and self.structured_output is not None
45
+ ):
46
+ self.tokenizer_id = self.model_id
47
 
48
+ if self.base_url and not (self.model_id or self.endpoint_name):
49
+ return self
50
+
51
+ if self.model_id and not self.endpoint_name:
52
+ return self
53
+
54
+ if self.endpoint_name and not self.model_id:
55
+ return self
56
+
57
+ raise ValidationError(
58
+ f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
59
+ f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
60
+ f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
61
+ )
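
Reviewer note: the tail of the validator above decides which of `base_url`, `model_id`, or `endpoint_name` ends up being used. The same decision can be sketched as a standalone function. `resolve_endpoint_source` is a hypothetical name, the warning and tokenizer handling from the hunk are left out, and the sketch raises a plain `ValueError` for simplicity rather than the `ValidationError` used in the commit.

```python
from typing import Optional


def resolve_endpoint_source(
    model_id: Optional[str] = None,
    endpoint_name: Optional[str] = None,
    base_url: Optional[str] = None,
) -> str:
    """Return which argument the endpoint would be derived from.

    Mirrors the order of the final checks in the validator shown above.
    """
    # A bare base_url is accepted when neither identifier is given.
    if base_url and not (model_id or endpoint_name):
        return "base_url"
    # Exactly one identifier: it wins, and any base_url is ignored.
    if model_id and not endpoint_name:
        return "model_id"
    if endpoint_name and not model_id:
        return "endpoint_name"
    # Reached when both identifiers are set, or when nothing is set.
    raise ValueError(
        "Provide exactly one of `model_id` or `endpoint_name`, or only `base_url`."
    )


if __name__ == "__main__":
    print(resolve_endpoint_source(base_url="https://api.openai.com/v1/"))
    print(resolve_endpoint_source(model_id="gpt-4o", base_url="https://example.invalid"))
```

Written this way, the two failure cases become easy to spot: passing both identifiers and passing nothing at all both fall through to the error.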
62
 
63
 
64
  class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
 
163
 
164
 
165
  distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
166
+ distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
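
Reviewer note: replacing a class by assigning to a module attribute, as the line above does, only affects lookups that go through the module afterwards. A name bound earlier with `from module import Name` keeps pointing at the original class. The sketch below illustrates that general Python behavior with a stand-in module. `fake_llms`, `OriginalLLM`, and `PatchedLLM` are hypothetical names, and nothing here touches distilabel.

```python
import types

# Hypothetical stand-in for a library module; this is not distilabel.
fake_llms = types.ModuleType("fake_llms")


class OriginalLLM:
    """Plays the role of the library's class."""


class PatchedLLM(OriginalLLM):
    """Plays the role of the project's subclass."""


fake_llms.LLM = OriginalLLM

# Equivalent to `from fake_llms import LLM` executed before the patch.
early_reference = fake_llms.LLM

# The monkey patch: rebind the module attribute to the subclass.
fake_llms.LLM = PatchedLLM

# Lookups through the module after the patch see the subclass.
late_reference = fake_llms.LLM

if __name__ == "__main__":
    print(early_reference.__name__, late_reference.__name__)
```

The practical consequence is that import order matters: modules that bind the class name before the package's `__init__.py` has run would not see the replacement.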
src/distilabel_dataset_generator/_tabbedinterface.py CHANGED
@@ -68,7 +68,9 @@ class TabbedInterface(Blocks):
68
  with gr.Column(scale=3):
69
  pass
70
  with gr.Column(scale=2):
71
- gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2)
 
 
72
  with Tabs():
73
  for interface, tab_name in zip(interface_list, tab_names, strict=False):
74
  with Tab(label=tab_name):
 
68
  with gr.Column(scale=3):
69
  pass
70
  with gr.Column(scale=2):
71
+ gr.LoginButton(
72
+ value="Sign in", variant="hf-login", size="sm", scale=2
73
+ )
74
  with Tabs():
75
  for interface, tab_name in zip(interface_list, tab_names, strict=False):
76
  with Tab(label=tab_name):
src/distilabel_dataset_generator/app.py ADDED
@@ -0,0 +1,38 @@
1
+ from distilabel_dataset_generator._tabbedinterface import TabbedInterface
2
+ from distilabel_dataset_generator.apps.eval import app as eval_app
3
+ from distilabel_dataset_generator.apps.faq import app as faq_app
4
+ from distilabel_dataset_generator.apps.sft import app as sft_app
5
+ from distilabel_dataset_generator.apps.textcat import app as textcat_app
6
+
7
+ theme = "argilla/argilla-theme"
8
+
9
+ css = """
10
+ button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
11
+ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)}
12
+ button.hf-login {background: var(--neutral-800); color: white}
13
+ button.hf-login:hover {background: var(--neutral-700); color: white}
14
+ .tabitem { border: 0; padding-inline: 0}
15
+ .main_ui_logged_out{opacity: 0.3; pointer-events: none}
16
+ .group_padding{padding: .55em}
17
+ .gallery-item {background: var(--background-fill-secondary); text-align: left}
18
+ .gallery {white-space: wrap}
19
+ #space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
20
+ #system_prompt_examples {
21
+ color: var(--body-text-color) !important;
22
+ background-color: var(--block-background-fill) !important;
23
+ }
24
+ .container {padding-inline: 0 !important}
25
+ """
26
+
27
+ demo = TabbedInterface(
28
+ [textcat_app, sft_app, eval_app, faq_app],
29
+ ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
30
+ css=css,
31
+ title="Synthetic Data Generator",
32
+ head="Synthetic Data Generator",
33
+ theme=theme,
34
+ )
35
+
36
+
37
+ if __name__ == "__main__":
38
+ demo.launch()
src/distilabel_dataset_generator/apps/__init__.py ADDED
File without changes
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -1,6 +1,6 @@
1
  import io
2
  import uuid
3
- from typing import Any, Callable, List, Tuple, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
@@ -10,161 +10,11 @@ from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
12
 
13
- from src.distilabel_dataset_generator.utils import (
14
- _LOGGED_OUT_CSS,
15
  get_argilla_client,
16
- get_login_button,
17
- list_orgs,
18
- swap_visibility,
19
  )
20
 
21
- TEXTCAT_TASK = "text_classification"
22
- SFT_TASK = "supervised_fine_tuning"
23
-
24
-
25
- def get_main_ui(
26
- default_dataset_descriptions: List[str],
27
- default_system_prompts: List[str],
28
- default_datasets: List[pd.DataFrame],
29
- fn_generate_system_prompt: Callable,
30
- fn_generate_dataset: Callable,
31
- task: str,
32
- ):
33
- def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
34
- if system_prompt in default_system_prompts:
35
- index = default_system_prompts.index(system_prompt)
36
- if index < len(default_datasets):
37
- return default_datasets[index]
38
- if task == TEXTCAT_TASK:
39
- result = fn_generate_dataset(
40
- system_prompt=system_prompt,
41
- difficulty="high school",
42
- clarity="clear",
43
- labels=[],
44
- num_labels=1,
45
- num_rows=1,
46
- progress=progress,
47
- is_sample=True,
48
- )
49
- else:
50
- result = fn_generate_dataset(
51
- system_prompt=system_prompt,
52
- num_turns=1,
53
- num_rows=1,
54
- progress=progress,
55
- is_sample=True,
56
- )
57
- return result
58
-
59
- with gr.Blocks(
60
- title="🧬 Synthetic Data Generator",
61
- head="🧬 Synthetic Data Generator",
62
- css=_LOGGED_OUT_CSS,
63
- ) as app:
64
- with gr.Row():
65
- gr.HTML(
66
- """<details style='display: inline-block;'><summary><h2 style='display: inline;'>How does it work?</h2></summary><img src='https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/flow.png' width='100%' style='margin: 0 auto; display: block;'></details>"""
67
- )
68
- with gr.Row():
69
- gr.Markdown(
70
- "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
71
- )
72
- with gr.Row():
73
- gr.Column()
74
- get_login_button()
75
- gr.Column()
76
-
77
- gr.Markdown("## Iterate on a sample dataset")
78
- with gr.Column() as main_ui:
79
- (
80
- dataset_description,
81
- examples,
82
- btn_generate_system_prompt,
83
- system_prompt,
84
- sample_dataset,
85
- btn_generate_sample_dataset,
86
- ) = get_iterate_on_sample_dataset_ui(
87
- default_dataset_descriptions=default_dataset_descriptions,
88
- default_system_prompts=default_system_prompts,
89
- default_datasets=default_datasets,
90
- task=task,
91
- )
92
- gr.Markdown("## Generate full dataset")
93
- gr.Markdown(
94
- "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
95
- )
96
- with gr.Row(variant="panel") as custom_input_ui:
97
- pass
98
-
99
- (
100
- dataset_name,
101
- add_to_existing_dataset,
102
- btn_generate_full_dataset_argilla,
103
- btn_generate_and_push_to_argilla,
104
- btn_push_to_argilla,
105
- org_name,
106
- repo_name,
107
- private,
108
- btn_generate_full_dataset,
109
- btn_generate_and_push_to_hub,
110
- btn_push_to_hub,
111
- final_dataset,
112
- success_message,
113
- ) = get_push_to_ui(default_datasets)
114
-
115
- sample_dataset.change(
116
- fn=lambda x: x,
117
- inputs=[sample_dataset],
118
- outputs=[final_dataset],
119
- )
120
-
121
- btn_generate_system_prompt.click(
122
- fn=fn_generate_system_prompt,
123
- inputs=[dataset_description],
124
- outputs=[system_prompt],
125
- show_progress=True,
126
- ).then(
127
- fn=fn_generate_sample_dataset,
128
- inputs=[system_prompt],
129
- outputs=[sample_dataset],
130
- show_progress=True,
131
- )
132
-
133
- btn_generate_sample_dataset.click(
134
- fn=fn_generate_sample_dataset,
135
- inputs=[system_prompt],
136
- outputs=[sample_dataset],
137
- show_progress=True,
138
- )
139
-
140
- app.load(fn=swap_visibility, outputs=main_ui)
141
- app.load(get_org_dropdown, outputs=[org_name])
142
-
143
- return (
144
- app,
145
- main_ui,
146
- custom_input_ui,
147
- dataset_description,
148
- examples,
149
- btn_generate_system_prompt,
150
- system_prompt,
151
- sample_dataset,
152
- btn_generate_sample_dataset,
153
- dataset_name,
154
- add_to_existing_dataset,
155
- btn_generate_full_dataset_argilla,
156
- btn_generate_and_push_to_argilla,
157
- btn_push_to_argilla,
158
- org_name,
159
- repo_name,
160
- private,
161
- btn_generate_full_dataset,
162
- btn_generate_and_push_to_hub,
163
- btn_push_to_hub,
164
- final_dataset,
165
- success_message,
166
- )
167
-
168
 
169
  def validate_argilla_user_workspace_dataset(
170
  dataset_name: str,
@@ -195,186 +45,6 @@ def validate_argilla_user_workspace_dataset(
195
  return ""
196
 
197
 
198
- def get_org_dropdown(oauth_token: Union[OAuthToken, None]):
199
- orgs = list_orgs(oauth_token)
200
- return gr.Dropdown(
201
- label="Organization",
202
- choices=orgs,
203
- value=orgs[0] if orgs else None,
204
- allow_custom_value=True,
205
- )
206
-
207
-
208
- def get_push_to_ui(default_datasets):
209
- with gr.Column() as push_to_ui:
210
- (
211
- dataset_name,
212
- add_to_existing_dataset,
213
- btn_generate_full_dataset_argilla,
214
- btn_generate_and_push_to_argilla,
215
- btn_push_to_argilla,
216
- ) = get_argilla_tab()
217
- (
218
- org_name,
219
- repo_name,
220
- private,
221
- btn_generate_full_dataset,
222
- btn_generate_and_push_to_hub,
223
- btn_push_to_hub,
224
- ) = get_hf_tab()
225
- final_dataset = get_final_dataset_row(default_datasets)
226
- success_message = get_success_message_row()
227
- return (
228
- dataset_name,
229
- add_to_existing_dataset,
230
- btn_generate_full_dataset_argilla,
231
- btn_generate_and_push_to_argilla,
232
- btn_push_to_argilla,
233
- org_name,
234
- repo_name,
235
- private,
236
- btn_generate_full_dataset,
237
- btn_generate_and_push_to_hub,
238
- btn_push_to_hub,
239
- final_dataset,
240
- success_message,
241
- )
242
-
243
-
244
- def get_iterate_on_sample_dataset_ui(
245
- default_dataset_descriptions: List[str],
246
- default_system_prompts: List[str],
247
- default_datasets: List[pd.DataFrame],
248
- task: str,
249
- ):
250
- with gr.Column():
251
- dataset_description = gr.TextArea(
252
- label="Give a precise description of your desired application. Check the examples for inspiration.",
253
- value=default_dataset_descriptions[0],
254
- lines=2,
255
- )
256
- examples = gr.Examples(
257
- elem_id="system_prompt_examples",
258
- examples=[[example] for example in default_dataset_descriptions],
259
- inputs=[dataset_description],
260
- )
261
- with gr.Row():
262
- gr.Column(scale=1)
263
- btn_generate_system_prompt = gr.Button(
264
- value="Generate system prompt and sample dataset", variant="primary"
265
- )
266
- gr.Column(scale=1)
267
-
268
- system_prompt = gr.TextArea(
269
- label="System prompt for dataset generation. You can tune it and regenerate the sample.",
270
- value=default_system_prompts[0],
271
- lines=2 if task == TEXTCAT_TASK else 5,
272
- )
273
-
274
- with gr.Row():
275
- sample_dataset = gr.Dataframe(
276
- value=default_datasets[0],
277
- label=(
278
- "Sample dataset. Text truncated to 256 tokens."
279
- if task == TEXTCAT_TASK
280
- else "Sample dataset. Prompts and completions truncated to 256 tokens."
281
- ),
282
- interactive=False,
283
- wrap=True,
284
- )
285
-
286
- with gr.Row():
287
- gr.Column(scale=1)
288
- btn_generate_sample_dataset = gr.Button(
289
- value="Generate sample dataset", variant="primary"
290
- )
291
- gr.Column(scale=1)
292
-
293
- return (
294
- dataset_description,
295
- examples,
296
- btn_generate_system_prompt,
297
- system_prompt,
298
- sample_dataset,
299
- btn_generate_sample_dataset,
300
- )
301
-
302
-
303
- def get_argilla_tab() -> Tuple[Any]:
304
- with gr.Tab(label="Argilla"):
305
- if get_argilla_client() is not None:
306
- with gr.Row(variant="panel"):
307
- dataset_name = gr.Textbox(
308
- label="Dataset name",
309
- placeholder="dataset_name",
310
- value="my-distiset",
311
- )
312
- add_to_existing_dataset = gr.Checkbox(
313
- label="Allow adding records to existing dataset",
314
- info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.",
315
- value=False,
316
- interactive=True,
317
- scale=1,
318
- )
319
-
320
- with gr.Row(variant="panel"):
321
- btn_generate_full_dataset_argilla = gr.Button(
322
- value="Generate", variant="primary", scale=2
323
- )
324
- btn_generate_and_push_to_argilla = gr.Button(
325
- value="Generate and Push to Argilla",
326
- variant="primary",
327
- scale=2,
328
- )
329
- btn_push_to_argilla = gr.Button(
330
- value="Push to Argilla", variant="primary", scale=2
331
- )
332
- else:
333
- gr.Markdown(
334
- "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
335
- )
336
- return (
337
- dataset_name,
338
- add_to_existing_dataset,
339
- btn_generate_full_dataset_argilla,
340
- btn_generate_and_push_to_argilla,
341
- btn_push_to_argilla,
342
- )
343
-
344
-
345
- def get_hf_tab() -> Tuple[Any]:
346
- with gr.Tab("Hugging Face Hub"):
347
- with gr.Row(variant="panel"):
348
- org_name = get_org_dropdown()
349
- repo_name = gr.Textbox(
350
- label="Repo name",
351
- placeholder="dataset_name",
352
- value="my-distiset",
353
- )
354
- private = gr.Checkbox(
355
- label="Private dataset",
356
- value=True,
357
- interactive=True,
358
- scale=1,
359
- )
360
- with gr.Row(variant="panel"):
361
- btn_generate_full_dataset = gr.Button(
362
- value="Generate", variant="primary", scale=2
363
- )
364
- btn_generate_and_push_to_hub = gr.Button(
365
- value="Generate and Push to Hub", variant="primary", scale=2
366
- )
367
- btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2)
368
- return (
369
- org_name,
370
- repo_name,
371
- private,
372
- btn_generate_full_dataset,
373
- btn_generate_and_push_to_hub,
374
- btn_push_to_hub,
375
- )
376
-
377
-
378
  def push_pipeline_code_to_hub(
379
  pipeline_code: str,
380
  org_name: str,
@@ -455,24 +125,6 @@ def validate_push_to_hub(org_name, repo_name):
455
  return repo_id
456
 
457
 
458
- def get_final_dataset_row(default_datasets) -> gr.Dataframe:
459
- with gr.Row():
460
- final_dataset = gr.Dataframe(
461
- value=default_datasets[0],
462
- label="Generated dataset",
463
- interactive=False,
464
- wrap=True,
465
- min_width=300,
466
- )
467
- return final_dataset
468
-
469
-
470
- def get_success_message_row() -> gr.Markdown:
471
- with gr.Row():
472
- success_message = gr.Markdown(visible=False)
473
- return success_message
474
-
475
-
476
  def show_success_message(org_name, repo_name) -> gr.Markdown:
477
  client = get_argilla_client()
478
  if client is None:
 
1
  import io
2
  import uuid
3
+ from typing import List, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
 
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
12
 
13
+ from distilabel_dataset_generator.constants import TEXTCAT_TASK
14
+ from distilabel_dataset_generator.utils import (
15
  get_argilla_client,
 
 
 
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def validate_argilla_user_workspace_dataset(
20
  dataset_name: str,
 
45
  return ""
46
 
47
 
48
  def push_pipeline_code_to_hub(
49
  pipeline_code: str,
50
  org_name: str,
 
125
  return repo_id
126
 
127
 
128
  def show_success_message(org_name, repo_name) -> gr.Markdown:
129
  client = get_argilla_client()
130
  if client is None:
src/distilabel_dataset_generator/apps/eval.py CHANGED
@@ -16,25 +16,23 @@ from distilabel.distiset import Distiset
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
  from huggingface_hub import HfApi
18
 
19
- from src.distilabel_dataset_generator.apps.base import (
20
  hide_success_message,
21
  show_success_message,
22
  validate_argilla_user_workspace_dataset,
23
  validate_push_to_hub,
24
  )
25
- from src.distilabel_dataset_generator.pipelines.base import (
26
- DEFAULT_BATCH_SIZE,
27
- )
28
- from src.distilabel_dataset_generator.pipelines.embeddings import (
29
  get_embeddings,
30
  get_sentence_embedding_dimensions,
31
  )
32
- from src.distilabel_dataset_generator.pipelines.eval import (
33
  generate_pipeline_code,
34
  get_custom_evaluator,
35
  get_ultrafeedback_evaluator,
36
  )
37
- from src.distilabel_dataset_generator.utils import (
38
  column_to_list,
39
  extract_column_names,
40
  get_argilla_client,
 
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
  from huggingface_hub import HfApi
18
 
19
+ from distilabel_dataset_generator.apps.base import (
20
  hide_success_message,
21
  show_success_message,
22
  validate_argilla_user_workspace_dataset,
23
  validate_push_to_hub,
24
  )
25
+ from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
26
+ from distilabel_dataset_generator.pipelines.embeddings import (
 
 
27
  get_embeddings,
28
  get_sentence_embedding_dimensions,
29
  )
30
+ from distilabel_dataset_generator.pipelines.eval import (
31
  generate_pipeline_code,
32
  get_custom_evaluator,
33
  get_ultrafeedback_evaluator,
34
  )
35
+ from distilabel_dataset_generator.utils import (
36
  column_to_list,
37
  extract_column_names,
38
  get_argilla_client,
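
Reviewer note: the import hunks in this commit switch from `src.distilabel_dataset_generator...` to `distilabel_dataset_generator...`, which only resolves once the package is installed (for example with `pip install -e .`, as the README's Development section describes). A quick way to check that a module can be located is sketched below. `is_importable` is a hypothetical helper and is not part of this repository.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if Python can locate `module_name`.

    For dotted names, `find_spec` imports the parent packages first.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name is missing.
        return False


if __name__ == "__main__":
    for name in ("distilabel_dataset_generator.app", "src.distilabel_dataset_generator.app"):
        print(name, is_importable(name))
```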
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -9,28 +9,25 @@ from datasets import Dataset
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
- from src.distilabel_dataset_generator.apps.base import (
13
  hide_success_message,
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
18
- from src.distilabel_dataset_generator.pipelines.base import (
19
- DEFAULT_BATCH_SIZE,
20
- )
21
- from src.distilabel_dataset_generator.pipelines.embeddings import (
22
  get_embeddings,
23
  get_sentence_embedding_dimensions,
24
  )
25
- from src.distilabel_dataset_generator.pipelines.sft import (
26
  DEFAULT_DATASET_DESCRIPTIONS,
27
  generate_pipeline_code,
28
  get_magpie_generator,
29
  get_prompt_generator,
30
  get_response_generator,
31
  )
32
- from src.distilabel_dataset_generator.utils import (
33
- _LOGGED_OUT_CSS,
34
  get_argilla_client,
35
  get_org_dropdown,
36
  swap_visibility,
@@ -352,170 +349,177 @@ def hide_pipeline_code_visibility():
352
  ######################
353
 
354
 
355
- with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
356
  with gr.Column() as main_ui:
357
- gr.Markdown(value="## 1. Describe the dataset you want")
358
- with gr.Row():
359
- with gr.Column(scale=2):
360
- dataset_description = gr.Textbox(
361
- label="Dataset description",
362
- placeholder="Give a precise description of your desired dataset.",
363
- )
364
- with gr.Accordion("Temperature", open=False):
365
- temperature = gr.Slider(
366
- minimum=0.1,
367
- maximum=1,
368
- value=0.8,
369
- step=0.1,
370
  interactive=True,
371
- show_label=False,
372
  )
373
- load_btn = gr.Button(
374
- "Create dataset",
375
- variant="primary",
376
- )
377
- with gr.Column(scale=2):
378
- examples = gr.Examples(
379
- examples=DEFAULT_DATASET_DESCRIPTIONS,
380
- inputs=[dataset_description],
381
- cache_examples=False,
382
- label="Examples",
383
- )
384
- with gr.Column(scale=1):
385
- pass
386
-
387
- gr.HTML(value="<hr>")
388
- gr.Markdown(value="## 2. Configure your dataset")
389
- with gr.Row(equal_height=False):
390
- with gr.Column(scale=2):
391
- system_prompt = gr.Textbox(
392
- label="System prompt",
393
- placeholder="You are a helpful assistant.",
394
- )
395
- num_turns = gr.Number(
396
- value=1,
397
- label="Number of turns in the conversation",
398
- minimum=1,
399
- maximum=4,
400
- step=1,
401
- interactive=True,
402
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
403
- )
404
- btn_apply_to_sample_dataset = gr.Button(
405
- "Refresh dataset", variant="secondary"
406
- )
407
- with gr.Column(scale=3):
408
- dataframe = gr.Dataframe(
409
- headers=["prompt", "completion"],
410
- wrap=True,
411
- height=500,
412
- interactive=False,
413
- )
414
-
415
- gr.HTML(value="<hr>")
416
- gr.Markdown(value="## 3. Generate your dataset")
417
- with gr.Row(equal_height=False):
418
- with gr.Column(scale=2):
419
- org_name = get_org_dropdown()
420
- repo_name = gr.Textbox(
421
- label="Repo name",
422
- placeholder="dataset_name",
423
- value=f"my-distiset-{str(uuid.uuid4())[:8]}",
424
- interactive=True,
425
- )
426
- num_rows = gr.Number(
427
- label="Number of rows",
428
- value=10,
429
- interactive=True,
430
- scale=1,
431
- )
432
- private = gr.Checkbox(
433
- label="Private dataset",
434
- value=False,
435
- interactive=True,
436
- scale=1,
437
- )
438
- btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
439
- with gr.Column(scale=3):
440
- success_message = gr.Markdown(visible=True)
441
- with gr.Accordion(
442
- "Do you want to go further? Customize and run with Distilabel",
443
- open=False,
444
- visible=False,
445
- ) as pipeline_code_ui:
446
- code = generate_pipeline_code(
447
- system_prompt=system_prompt.value,
448
- num_turns=num_turns.value,
449
- num_rows=num_rows.value,
450
  )
451
- pipeline_code = gr.Code(
452
- value=code,
453
- language="python",
454
- label="Distilabel Pipeline Code",
 
 
455
  )
456
 
457
- load_btn.click(
458
- fn=generate_system_prompt,
459
- inputs=[dataset_description, temperature],
460
- outputs=[system_prompt],
461
- show_progress=True,
462
- ).then(
463
- fn=generate_sample_dataset,
464
- inputs=[system_prompt, num_turns],
465
- outputs=[dataframe],
466
- show_progress=True,
467
- )
468
 
469
- btn_apply_to_sample_dataset.click(
470
- fn=generate_sample_dataset,
471
- inputs=[system_prompt, num_turns],
472
- outputs=[dataframe],
473
- show_progress=True,
474
- )
475
 
476
- btn_push_to_hub.click(
477
- fn=validate_argilla_user_workspace_dataset,
478
- inputs=[repo_name],
479
- outputs=[success_message],
480
- show_progress=True,
481
- ).then(
482
- fn=validate_push_to_hub,
483
- inputs=[org_name, repo_name],
484
- outputs=[success_message],
485
- show_progress=True,
486
- ).success(
487
- fn=hide_success_message,
488
- outputs=[success_message],
489
- show_progress=True,
490
- ).success(
491
- fn=hide_pipeline_code_visibility,
492
- inputs=[],
493
- outputs=[pipeline_code_ui],
494
- ).success(
495
- fn=push_dataset,
496
- inputs=[
497
- org_name,
498
- repo_name,
499
- system_prompt,
500
- num_turns,
501
- num_rows,
502
- private,
503
- ],
504
- outputs=[success_message],
505
- show_progress=True,
506
- ).success(
507
- fn=show_success_message,
508
- inputs=[org_name, repo_name],
509
- outputs=[success_message],
510
- ).success(
511
- fn=generate_pipeline_code,
512
- inputs=[system_prompt, num_turns, num_rows],
513
- outputs=[pipeline_code],
514
- ).success(
515
- fn=show_pipeline_code_visibility,
516
- inputs=[],
517
- outputs=[pipeline_code_ui],
518
- )
519
 
520
- app.load(fn=swap_visibility, outputs=main_ui)
521
- app.load(fn=get_org_dropdown, outputs=[org_name])
 
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
+ from distilabel_dataset_generator.apps.base import (
13
  hide_success_message,
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
18
+ from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, SFT_AVAILABLE
19
+ from distilabel_dataset_generator.pipelines.embeddings import (
 
 
20
  get_embeddings,
21
  get_sentence_embedding_dimensions,
22
  )
23
+ from distilabel_dataset_generator.pipelines.sft import (
24
  DEFAULT_DATASET_DESCRIPTIONS,
25
  generate_pipeline_code,
26
  get_magpie_generator,
27
  get_prompt_generator,
28
  get_response_generator,
29
  )
30
+ from distilabel_dataset_generator.utils import (
 
31
  get_argilla_client,
32
  get_org_dropdown,
33
  swap_visibility,
 
349
  ######################
350
 
351
 
352
+ with gr.Blocks() as app:
353
  with gr.Column() as main_ui:
354
+ if not SFT_AVAILABLE:
355
+ gr.Markdown(
356
+ value=f"## Supervised Fine-Tuning is not available for the {MODEL} model. Use Hugging Face Llama3 or Qwen2 models."
357
+ )
358
+ else:
359
+ gr.Markdown(value="## 1. Describe the dataset you want")
360
+ with gr.Row():
361
+ with gr.Column(scale=2):
362
+ dataset_description = gr.Textbox(
363
+ label="Dataset description",
364
+ placeholder="Give a precise description of your desired dataset.",
365
+ )
366
+ with gr.Accordion("Temperature", open=False):
367
+ temperature = gr.Slider(
368
+ minimum=0.1,
369
+ maximum=1,
370
+ value=0.8,
371
+ step=0.1,
372
+ interactive=True,
373
+ show_label=False,
374
+ )
375
+ load_btn = gr.Button(
376
+ "Create dataset",
377
+ variant="primary",
378
+ )
379
+ with gr.Column(scale=2):
380
+ examples = gr.Examples(
381
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
382
+ inputs=[dataset_description],
383
+ cache_examples=False,
384
+ label="Examples",
385
+ )
386
+ with gr.Column(scale=1):
387
+ pass
388
+
389
+ gr.HTML(value="<hr>")
390
+ gr.Markdown(value="## 2. Configure your dataset")
391
+ with gr.Row(equal_height=False):
392
+ with gr.Column(scale=2):
393
+ system_prompt = gr.Textbox(
394
+ label="System prompt",
395
+ placeholder="You are a helpful assistant.",
396
+ )
397
+ num_turns = gr.Number(
398
+ value=1,
399
+ label="Number of turns in the conversation",
400
+ minimum=1,
401
+ maximum=4,
402
+ step=1,
403
  interactive=True,
404
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
405
  )
406
+ btn_apply_to_sample_dataset = gr.Button(
407
+ "Refresh dataset", variant="secondary"
408
  )
409
+ with gr.Column(scale=3):
410
+ dataframe = gr.Dataframe(
411
+ headers=["prompt", "completion"],
412
+ wrap=True,
413
+ height=500,
414
+ interactive=False,
415
  )
416
 
417
+ gr.HTML(value="<hr>")
418
+ gr.Markdown(value="## 3. Generate your dataset")
419
+ with gr.Row(equal_height=False):
420
+ with gr.Column(scale=2):
421
+ org_name = get_org_dropdown()
422
+ repo_name = gr.Textbox(
423
+ label="Repo name",
424
+ placeholder="dataset_name",
425
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
426
+ interactive=True,
427
+ )
428
+ num_rows = gr.Number(
429
+ label="Number of rows",
430
+ value=10,
431
+ interactive=True,
432
+ scale=1,
433
+ )
434
+ private = gr.Checkbox(
435
+ label="Private dataset",
436
+ value=False,
437
+ interactive=True,
438
+ scale=1,
439
+ )
440
+ btn_push_to_hub = gr.Button(
441
+ "Push to Hub", variant="primary", scale=2
442
+ )
443
+ with gr.Column(scale=3):
444
+ success_message = gr.Markdown(visible=True)
445
+ with gr.Accordion(
446
+ "Do you want to go further? Customize and run with Distilabel",
447
+ open=False,
448
+ visible=False,
449
+ ) as pipeline_code_ui:
450
+ code = generate_pipeline_code(
451
+ system_prompt=system_prompt.value,
452
+ num_turns=num_turns.value,
453
+ num_rows=num_rows.value,
454
+ )
455
+ pipeline_code = gr.Code(
456
+ value=code,
457
+ language="python",
458
+ label="Distilabel Pipeline Code",
459
+ )
460
+
461
+ load_btn.click(
462
+ fn=generate_system_prompt,
463
+ inputs=[dataset_description, temperature],
464
+ outputs=[system_prompt],
465
+ show_progress=True,
466
+ ).then(
467
+ fn=generate_sample_dataset,
468
+ inputs=[system_prompt, num_turns],
469
+ outputs=[dataframe],
470
+ show_progress=True,
471
+ )
472
 
473
+ btn_apply_to_sample_dataset.click(
474
+ fn=generate_sample_dataset,
475
+ inputs=[system_prompt, num_turns],
476
+ outputs=[dataframe],
477
+ show_progress=True,
478
+ )
479
 
480
+ btn_push_to_hub.click(
481
+ fn=validate_argilla_user_workspace_dataset,
482
+ inputs=[repo_name],
483
+ outputs=[success_message],
484
+ show_progress=True,
485
+ ).then(
486
+ fn=validate_push_to_hub,
487
+ inputs=[org_name, repo_name],
488
+ outputs=[success_message],
489
+ show_progress=True,
490
+ ).success(
491
+ fn=hide_success_message,
492
+ outputs=[success_message],
493
+ show_progress=True,
494
+ ).success(
495
+ fn=hide_pipeline_code_visibility,
496
+ inputs=[],
497
+ outputs=[pipeline_code_ui],
498
+ ).success(
499
+ fn=push_dataset,
500
+ inputs=[
501
+ org_name,
502
+ repo_name,
503
+ system_prompt,
504
+ num_turns,
505
+ num_rows,
506
+ private,
507
+ ],
508
+ outputs=[success_message],
509
+ show_progress=True,
510
+ ).success(
511
+ fn=show_success_message,
512
+ inputs=[org_name, repo_name],
513
+ outputs=[success_message],
514
+ ).success(
515
+ fn=generate_pipeline_code,
516
+ inputs=[system_prompt, num_turns, num_rows],
517
+ outputs=[pipeline_code],
518
+ ).success(
519
+ fn=show_pipeline_code_visibility,
520
+ inputs=[],
521
+ outputs=[pipeline_code_ui],
522
+ )
523
 
524
+ app.load(fn=swap_visibility, outputs=main_ui)
525
+ app.load(fn=get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -9,15 +9,13 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
 
12
  from src.distilabel_dataset_generator.apps.base import (
13
  hide_success_message,
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
18
- from src.distilabel_dataset_generator.pipelines.base import (
19
- DEFAULT_BATCH_SIZE,
20
- )
21
  from src.distilabel_dataset_generator.pipelines.embeddings import (
22
  get_embeddings,
23
  get_sentence_embedding_dimensions,
@@ -30,7 +28,6 @@ from src.distilabel_dataset_generator.pipelines.textcat import (
30
  get_textcat_generator,
31
  )
32
  from src.distilabel_dataset_generator.utils import (
33
- _LOGGED_OUT_CSS,
34
  get_argilla_client,
35
  get_org_dropdown,
36
  get_preprocess_labels,
@@ -334,7 +331,7 @@ def hide_pipeline_code_visibility():
334
  ######################
335
 
336
 
337
- with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
338
  with gr.Column() as main_ui:
339
  gr.Markdown("## 1. Describe the dataset you want")
340
  with gr.Row():
 
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
+ from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
13
  from src.distilabel_dataset_generator.apps.base import (
14
  hide_success_message,
15
  show_success_message,
16
  validate_argilla_user_workspace_dataset,
17
  validate_push_to_hub,
18
  )
 
 
 
19
  from src.distilabel_dataset_generator.pipelines.embeddings import (
20
  get_embeddings,
21
  get_sentence_embedding_dimensions,
 
28
  get_textcat_generator,
29
  )
30
  from src.distilabel_dataset_generator.utils import (
 
31
  get_argilla_client,
32
  get_org_dropdown,
33
  get_preprocess_labels,
 
331
  ######################
332
 
333
 
334
+ with gr.Blocks() as app:
335
  with gr.Column() as main_ui:
336
  gr.Markdown("## 1. Describe the dataset you want")
337
  with gr.Row():
src/distilabel_dataset_generator/constants.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ import warnings
3
+
4
+ import argilla as rg
5
+
6
+ # Tasks
7
+ TEXTCAT_TASK = "text_classification"
8
+ SFT_TASK = "supervised_fine_tuning"
9
+
10
+ # Hugging Face
11
+ HF_TOKEN = os.getenv("HF_TOKEN")
12
+ if HF_TOKEN is None:
13
+ raise ValueError(
14
+ "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
15
+ )
16
+
17
+ # Inference
18
+ DEFAULT_BATCH_SIZE = 5
19
+ MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
20
+ API_KEYS = (
21
+ [os.getenv("HF_TOKEN")]
22
+ + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
23
+ + [os.getenv("API_KEY")]
24
+ )
25
+ API_KEYS = [token for token in API_KEYS if token]
26
+ BASE_URL = os.getenv("BASE_URL", "https://api-inference.huggingface.co/v1/")
27
+
28
+ if BASE_URL != "https://api-inference.huggingface.co/v1/" and len(API_KEYS) == 0:
29
+ raise ValueError(
30
+ "API_KEY is not set. A custom BASE_URL is configured, so set the API_KEY environment variable to a key that has access to that endpoint."
31
+ )
32
+ if "Qwen2" not in MODEL and "Llama-3" not in MODEL:
33
+ SFT_AVAILABLE = False
34
+ warnings.warn(
35
+ "SFT_AVAILABLE is set to False because the model is not a Qwen2 or Llama 3 model."
36
+ )
37
+ MAGPIE_PRE_QUERY_TEMPLATE = None
38
+ else:
39
+ SFT_AVAILABLE = True
40
+ if "Qwen2" in MODEL:
41
+ MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
42
+ else:
43
+ MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
44
+
45
+ # Embeddings
46
+ STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
47
+
48
+ # Argilla
49
+ ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
50
+ ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
51
+ if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
52
+ ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
53
+ ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
54
+
55
+ if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
56
+ warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
57
+ argilla_client = None
58
+ else:
59
+ argilla_client = rg.Argilla(
60
+ api_url=ARGILLA_API_URL,
61
+ api_key=ARGILLA_API_KEY,
62
+ )
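The model check in `constants.py` above decides two things at import time: whether the SFT tab is usable, and which Magpie pre-query template to pass on. A minimal sketch of the same decision as a pure function, which is easier to test than module-level branching. The function name `resolve_magpie_template` is hypothetical and not part of this commit; the substring checks mirror the ones in the diff.

```python
from typing import Optional, Tuple


def resolve_magpie_template(model_id: str) -> Tuple[bool, Optional[str]]:
    """Return (sft_available, magpie_pre_query_template) for a model id.

    Hypothetical helper: mirrors the substring checks in constants.py.
    """
    if "Qwen2" in model_id:
        return True, "qwen2"
    if "Llama-3" in model_id:
        return True, "llama3"
    # Any other model family: SFT is disabled and no template is set.
    return False, None


print(resolve_magpie_template("meta-llama/Meta-Llama-3.1-8B-Instruct"))
print(resolve_magpie_template("Qwen/Qwen2.5-7B-Instruct"))
print(resolve_magpie_template("mistralai/Mistral-7B-Instruct-v0.3"))
```

Because the check is a plain substring match, a model id such as `Qwen/Qwen2.5-7B-Instruct` also resolves to the `qwen2` template.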
src/distilabel_dataset_generator/pipelines/__init__.py ADDED
File without changes
src/distilabel_dataset_generator/pipelines/base.py CHANGED
@@ -1,12 +1,10 @@
1
- from src.distilabel_dataset_generator import HF_TOKENS
2
 
3
- DEFAULT_BATCH_SIZE = 5
4
  TOKEN_INDEX = 0
5
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
 
7
 
8
  def _get_next_api_key():
9
  global TOKEN_INDEX
10
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
11
  TOKEN_INDEX += 1
12
  return api_key
 
1
+ from distilabel_dataset_generator.constants import API_KEYS
2
 
 
3
  TOKEN_INDEX = 0
 
4
 
5
 
6
  def _get_next_api_key():
7
  global TOKEN_INDEX
8
+ api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
9
  TOKEN_INDEX += 1
10
  return api_key
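`_get_next_api_key` rotates through `API_KEYS` with a global counter and a modulo, which raises `ZeroDivisionError` if the list is ever empty. A minimal sketch of the same round-robin idea with an explicit guard, assuming the keys come from the environment variables named in `constants.py`. The helpers `collect_api_keys` and `key_rotation` are hypothetical names, and the environment is passed in as a mapping so the example runs without real tokens.

```python
import itertools
from typing import Iterator, List, Mapping


def collect_api_keys(env: Mapping[str, str]) -> List[str]:
    """Gather HF_TOKEN, HF_TOKEN_1..HF_TOKEN_9 and API_KEY, skipping unset or empty values."""
    names = ["HF_TOKEN"] + [f"HF_TOKEN_{i}" for i in range(1, 10)] + ["API_KEY"]
    return [env[name] for name in names if env.get(name)]


def key_rotation(keys: List[str]) -> Iterator[str]:
    """Return an endless round-robin iterator over the keys (hypothetical helper)."""
    if not keys:
        raise ValueError("No API keys configured.")
    return itertools.cycle(keys)


keys = collect_api_keys({"HF_TOKEN": "hf_a", "HF_TOKEN_2": "hf_b", "API_KEY": ""})
rotation = key_rotation(keys)
print([next(rotation) for _ in range(3)])  # ['hf_a', 'hf_b', 'hf_a']
```

In the application itself the mapping would be `os.environ`; the iterator replaces the global `TOKEN_INDEX` counter.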
src/distilabel_dataset_generator/pipelines/embeddings.py CHANGED
@@ -3,8 +3,9 @@ from typing import List
3
  from sentence_transformers import SentenceTransformer
4
  from sentence_transformers.models import StaticEmbedding
5
 
6
- # Initialize a StaticEmbedding module
7
- static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
 
8
  model = SentenceTransformer(modules=[static_embedding])
9
 
10
 
 
3
  from sentence_transformers import SentenceTransformer
4
  from sentence_transformers.models import StaticEmbedding
5
 
6
+ from distilabel_dataset_generator.constants import STATIC_EMBEDDING_MODEL
7
+
8
+ static_embedding = StaticEmbedding.from_model2vec(STATIC_EMBEDDING_MODEL)
9
  model = SentenceTransformer(modules=[static_embedding])
10
 
11
 
src/distilabel_dataset_generator/pipelines/eval.py CHANGED
@@ -5,18 +5,16 @@ from distilabel.steps.tasks import (
5
  UltraFeedback,
6
  )
7
 
8
- from src.distilabel_dataset_generator.pipelines.base import (
9
- MODEL,
10
- _get_next_api_key,
11
- )
12
- from src.distilabel_dataset_generator.utils import extract_column_names
13
 
14
 
15
  def get_ultrafeedback_evaluator(aspect, is_sample):
16
  ultrafeedback_evaluator = UltraFeedback(
17
  llm=InferenceEndpointsLLM(
18
  model_id=MODEL,
19
- tokenizer_id=MODEL,
20
  api_key=_get_next_api_key(),
21
  generation_kwargs={
22
  "temperature": 0,
@@ -33,7 +31,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
33
  custom_evaluator = TextGeneration(
34
  llm=InferenceEndpointsLLM(
35
  model_id=MODEL,
36
- tokenizer_id=MODEL,
37
  api_key=_get_next_api_key(),
38
  structured_output={"format": "json", "schema": structured_output},
39
  generation_kwargs={
@@ -62,7 +60,8 @@ from distilabel.steps.tasks import UltraFeedback
62
  from distilabel.llms import InferenceEndpointsLLM
63
 
64
  MODEL = "{MODEL}"
65
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
 
66
 
67
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
68
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
@@ -76,8 +75,8 @@ with Pipeline(name="ultrafeedback") as pipeline:
76
  ultrafeedback_evaluator = UltraFeedback(
77
  llm=InferenceEndpointsLLM(
78
  model_id=MODEL,
79
- tokenizer_id=MODEL,
80
- api_key=os.environ["HF_TOKEN"],
81
  generation_kwargs={{
82
  "temperature": 0,
83
  "max_new_tokens": 2048,
@@ -101,7 +100,8 @@ from distilabel.steps.tasks import UltraFeedback
101
  from distilabel.llms import InferenceEndpointsLLM
102
 
103
  MODEL = "{MODEL}"
104
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
 
105
 
106
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
107
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
@@ -119,8 +119,8 @@ with Pipeline(name="ultrafeedback") as pipeline:
119
  aspect=aspect,
120
  llm=InferenceEndpointsLLM(
121
  model_id=MODEL,
122
- tokenizer_id=MODEL,
123
- api_key=os.environ["HF_TOKEN"],
124
  generation_kwargs={{
125
  "temperature": 0,
126
  "max_new_tokens": 2048,
@@ -157,6 +157,7 @@ from distilabel.steps.tasks import TextGeneration
157
  from distilabel.llms import InferenceEndpointsLLM
158
 
159
  MODEL = "{MODEL}"
 
160
  CUSTOM_TEMPLATE = "{prompt_template}"
161
  os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
162
 
@@ -171,7 +172,7 @@ with Pipeline(name="custom-evaluation") as pipeline:
171
  custom_evaluator = TextGeneration(
172
  llm=InferenceEndpointsLLM(
173
  model_id=MODEL,
174
- tokenizer_id=MODEL,
175
  api_key=os.environ["HF_TOKEN"],
176
  structured_output={{"format": "json", "schema": {structured_output}}},
177
  generation_kwargs={{
 
5
  UltraFeedback,
6
  )
7
 
8
+ from distilabel_dataset_generator.constants import BASE_URL, MODEL
9
+ from distilabel_dataset_generator.pipelines.base import _get_next_api_key
10
+ from distilabel_dataset_generator.utils import extract_column_names
 
 
11
 
12
 
13
  def get_ultrafeedback_evaluator(aspect, is_sample):
14
  ultrafeedback_evaluator = UltraFeedback(
15
  llm=InferenceEndpointsLLM(
16
  model_id=MODEL,
17
+ base_url=BASE_URL,
18
  api_key=_get_next_api_key(),
19
  generation_kwargs={
20
  "temperature": 0,
 
31
  custom_evaluator = TextGeneration(
32
  llm=InferenceEndpointsLLM(
33
  model_id=MODEL,
34
+ base_url=BASE_URL,
35
  api_key=_get_next_api_key(),
36
  structured_output={"format": "json", "schema": structured_output},
37
  generation_kwargs={
 
60
  from distilabel.llms import InferenceEndpointsLLM
61
 
62
  MODEL = "{MODEL}"
63
+ BASE_URL = "{BASE_URL}"
64
+ os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
65
 
66
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
67
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
 
75
  ultrafeedback_evaluator = UltraFeedback(
76
  llm=InferenceEndpointsLLM(
77
  model_id=MODEL,
78
+ base_url=BASE_URL,
79
+ api_key=os.environ["API_KEY"],
80
  generation_kwargs={{
81
  "temperature": 0,
82
  "max_new_tokens": 2048,
 
100
  from distilabel.llms import InferenceEndpointsLLM
101
 
102
  MODEL = "{MODEL}"
103
+ BASE_URL = "{BASE_URL}"
104
+ os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
105
 
106
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
107
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
 
119
  aspect=aspect,
120
  llm=InferenceEndpointsLLM(
121
  model_id=MODEL,
122
+ base_url=BASE_URL,
123
+ api_key=os.environ["API_KEY"],
124
  generation_kwargs={{
125
  "temperature": 0,
126
  "max_new_tokens": 2048,
 
157
  from distilabel.llms import InferenceEndpointsLLM
158
 
159
  MODEL = "{MODEL}"
160
+ BASE_URL = "{BASE_URL}"
161
  CUSTOM_TEMPLATE = "{prompt_template}"
162
  os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
163
 
 
172
  custom_evaluator = TextGeneration(
173
  llm=InferenceEndpointsLLM(
174
  model_id=MODEL,
175
+ base_url=BASE_URL,
176
  api_key=os.environ["HF_TOKEN"],
177
  structured_output={{"format": "json", "schema": {structured_output}}},
178
  generation_kwargs={{
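The generated pipeline code in these templates mixes single braces (`"{MODEL}"`, `"{BASE_URL}"`) with doubled braces (`generation_kwargs={{ ... }}`). That pattern is consistent with Python f-strings, where single braces interpolate a value and doubled braces emit a literal brace. A minimal sketch of that mechanism, assuming the templates are f-strings; `render_pipeline_snippet` is a hypothetical stand-in for the real code generators.

```python
def render_pipeline_snippet(model: str, base_url: str) -> str:
    """Render a small code snippet; {{ and }} become literal braces in the output."""
    return f'''MODEL = "{model}"
BASE_URL = "{base_url}"
generation_kwargs = {{
    "temperature": 0,
    "max_new_tokens": 2048,
}}
'''


snippet = render_pipeline_snippet(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "https://api-inference.huggingface.co/v1/",
)
print(snippet)
```

The rendered text is itself valid Python, so the dictionary literal survives with single braces while the model id and base URL are filled in.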
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,10 +1,12 @@
1
  from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
4
- from src.distilabel_dataset_generator.pipelines.base import (
 
 
5
  MODEL,
6
- _get_next_api_key,
7
  )
 
8
 
9
  INFORMATION_SEEKING_PROMPT = (
10
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -144,6 +146,7 @@ def get_prompt_generator(temperature):
144
  api_key=_get_next_api_key(),
145
  model_id=MODEL,
146
  tokenizer_id=MODEL,
 
147
  generation_kwargs={
148
  "temperature": temperature,
149
  "max_new_tokens": 2048,
@@ -165,8 +168,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
165
  llm=InferenceEndpointsLLM(
166
  model_id=MODEL,
167
  tokenizer_id=MODEL,
 
168
  api_key=_get_next_api_key(),
169
- magpie_pre_query_template="llama3",
170
  generation_kwargs={
171
  "temperature": 0.9,
172
  "do_sample": True,
@@ -184,8 +188,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
184
  llm=InferenceEndpointsLLM(
185
  model_id=MODEL,
186
  tokenizer_id=MODEL,
 
187
  api_key=_get_next_api_key(),
188
- magpie_pre_query_template="llama3",
189
  generation_kwargs={
190
  "temperature": 0.9,
191
  "do_sample": True,
@@ -208,6 +213,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
208
  llm=InferenceEndpointsLLM(
209
  model_id=MODEL,
210
  tokenizer_id=MODEL,
 
211
  api_key=_get_next_api_key(),
212
  generation_kwargs={
213
  "temperature": 0.8,
@@ -223,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
223
  llm=InferenceEndpointsLLM(
224
  model_id=MODEL,
225
  tokenizer_id=MODEL,
 
226
  api_key=_get_next_api_key(),
227
  generation_kwargs={
228
  "temperature": 0.8,
@@ -247,14 +254,16 @@ from distilabel.steps.tasks import MagpieGenerator
247
  from distilabel.llms import InferenceEndpointsLLM
248
 
249
  MODEL = "{MODEL}"
 
250
  SYSTEM_PROMPT = "{system_prompt}"
251
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
252
 
253
  with Pipeline(name="sft") as pipeline:
254
  magpie = MagpieGenerator(
255
  llm=InferenceEndpointsLLM(
256
  model_id=MODEL,
257
  tokenizer_id=MODEL,
 
258
  magpie_pre_query_template="llama3",
259
  generation_kwargs={{
260
  "temperature": 0.9,
@@ -262,7 +271,7 @@ with Pipeline(name="sft") as pipeline:
262
  "max_new_tokens": 2048,
263
  "stop_sequences": {_STOP_SEQUENCES}
264
  }},
265
- api_key=os.environ["HF_TOKEN"],
266
  ),
267
  n_turns={num_turns},
268
  num_rows={num_rows},
 
1
  from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
4
+ from distilabel_dataset_generator.constants import (
5
+ BASE_URL,
6
+ MAGPIE_PRE_QUERY_TEMPLATE,
7
  MODEL,
 
8
  )
9
+ from distilabel_dataset_generator.pipelines.base import _get_next_api_key
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
146
  api_key=_get_next_api_key(),
147
  model_id=MODEL,
148
  tokenizer_id=MODEL,
149
+ base_url=BASE_URL,
150
  generation_kwargs={
151
  "temperature": temperature,
152
  "max_new_tokens": 2048,
 
168
  llm=InferenceEndpointsLLM(
169
  model_id=MODEL,
170
  tokenizer_id=MODEL,
171
+ base_url=BASE_URL,
172
  api_key=_get_next_api_key(),
173
+ magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
174
  generation_kwargs={
175
  "temperature": 0.9,
176
  "do_sample": True,
 
188
  llm=InferenceEndpointsLLM(
189
  model_id=MODEL,
190
  tokenizer_id=MODEL,
191
+ base_url=BASE_URL,
192
  api_key=_get_next_api_key(),
193
+ magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
194
  generation_kwargs={
195
  "temperature": 0.9,
196
  "do_sample": True,
 
213
  llm=InferenceEndpointsLLM(
214
  model_id=MODEL,
215
  tokenizer_id=MODEL,
216
+ base_url=BASE_URL,
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
  "temperature": 0.8,
 
229
  llm=InferenceEndpointsLLM(
230
  model_id=MODEL,
231
  tokenizer_id=MODEL,
232
+ base_url=BASE_URL,
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
  "temperature": 0.8,
 
254
  from distilabel.llms import InferenceEndpointsLLM
255
 
256
  MODEL = "{MODEL}"
257
+ BASE_URL = "{BASE_URL}"
258
  SYSTEM_PROMPT = "{system_prompt}"
259
+ os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
260
 
261
  with Pipeline(name="sft") as pipeline:
262
  magpie = MagpieGenerator(
263
  llm=InferenceEndpointsLLM(
264
  model_id=MODEL,
265
  tokenizer_id=MODEL,
266
+ base_url=BASE_URL,
267
  magpie_pre_query_template="llama3",
268
  generation_kwargs={{
269
  "temperature": 0.9,
 
271
  "max_new_tokens": 2048,
272
  "stop_sequences": {_STOP_SEQUENCES}
273
  }},
274
+ api_key=os.environ["API_KEY"],
275
  ),
276
  n_turns={num_turns},
277
  num_rows={num_rows},
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,5 +1,4 @@
1
  import random
2
- from pydantic import BaseModel, Field
3
  from typing import List
4
 
5
  from distilabel.llms import InferenceEndpointsLLM
@@ -8,12 +7,11 @@ from distilabel.steps.tasks import (
8
  TextClassification,
9
  TextGeneration,
10
  )
 
11
 
12
- from src.distilabel_dataset_generator.pipelines.base import (
13
- MODEL,
14
- _get_next_api_key,
15
- )
16
- from src.distilabel_dataset_generator.utils import get_preprocess_labels
17
 
18
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
19
 
@@ -73,7 +71,7 @@ def get_prompt_generator(temperature):
73
  llm=InferenceEndpointsLLM(
74
  api_key=_get_next_api_key(),
75
  model_id=MODEL,
76
- tokenizer_id=MODEL,
77
  structured_output={"format": "json", "schema": TextClassificationTask},
78
  generation_kwargs={
79
  "temperature": temperature,
@@ -92,7 +90,7 @@ def get_textcat_generator(difficulty, clarity, is_sample):
92
  textcat_generator = GenerateTextClassificationData(
93
  llm=InferenceEndpointsLLM(
94
  model_id=MODEL,
95
- tokenizer_id=MODEL,
96
  api_key=_get_next_api_key(),
97
  generation_kwargs={
98
  "temperature": 0.9,
@@ -114,7 +112,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
114
  labeller_generator = TextClassification(
115
  llm=InferenceEndpointsLLM(
116
  model_id=MODEL,
117
- tokenizer_id=MODEL,
118
  api_key=_get_next_api_key(),
119
  generation_kwargs={
120
  "temperature": 0.7,
@@ -149,8 +147,9 @@ from distilabel.steps import LoadDataFromDicts, KeepColumns
149
  from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
150
 
151
  MODEL = "{MODEL}"
 
152
  TEXT_CLASSIFICATION_TASK = "{system_prompt}"
153
- os.environ["HF_TOKEN"] = (
154
  "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
155
  )
156
 
@@ -161,8 +160,8 @@ with Pipeline(name="textcat") as pipeline:
161
  textcat_generation = GenerateTextClassificationData(
162
  llm=InferenceEndpointsLLM(
163
  model_id=MODEL,
164
- tokenizer_id=MODEL,
165
- api_key=os.environ["HF_TOKEN"],
166
  generation_kwargs={{
167
  "temperature": 0.8,
168
  "max_new_tokens": 2048,
@@ -205,8 +204,8 @@ with Pipeline(name="textcat") as pipeline:
205
  textcat_labeller = TextClassification(
206
  llm=InferenceEndpointsLLM(
207
  model_id=MODEL,
208
- tokenizer_id=MODEL,
209
- api_key=os.environ["HF_TOKEN"],
210
  generation_kwargs={{
211
  "temperature": 0.8,
212
  "max_new_tokens": 2048,
 
1
  import random
 
2
  from typing import List
3
 
4
  from distilabel.llms import InferenceEndpointsLLM
 
7
  TextClassification,
8
  TextGeneration,
9
  )
10
+ from pydantic import BaseModel, Field
11
 
12
+ from distilabel_dataset_generator.constants import BASE_URL, MODEL
13
+ from distilabel_dataset_generator.pipelines.base import _get_next_api_key
14
+ from distilabel_dataset_generator.utils import get_preprocess_labels
 
 
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
 
 
71
  llm=InferenceEndpointsLLM(
72
  api_key=_get_next_api_key(),
73
  model_id=MODEL,
74
+ base_url=BASE_URL,
75
  structured_output={"format": "json", "schema": TextClassificationTask},
76
  generation_kwargs={
77
  "temperature": temperature,
 
90
  textcat_generator = GenerateTextClassificationData(
91
  llm=InferenceEndpointsLLM(
92
  model_id=MODEL,
93
+ base_url=BASE_URL,
94
  api_key=_get_next_api_key(),
95
  generation_kwargs={
96
  "temperature": 0.9,
 
112
  labeller_generator = TextClassification(
113
  llm=InferenceEndpointsLLM(
114
  model_id=MODEL,
115
+ base_url=BASE_URL,
116
  api_key=_get_next_api_key(),
117
  generation_kwargs={
118
  "temperature": 0.7,
 
  from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}

  MODEL = "{MODEL}"
+ BASE_URL = "{BASE_URL}"
  TEXT_CLASSIFICATION_TASK = "{system_prompt}"
+ os.environ["API_KEY"] = (
      "hf_xxx"  # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
  )
 
  textcat_generation = GenerateTextClassificationData(
      llm=InferenceEndpointsLLM(
          model_id=MODEL,
+         base_url=BASE_URL,
+         api_key=os.environ["API_KEY"],
          generation_kwargs={{
              "temperature": 0.8,
              "max_new_tokens": 2048,
 
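The template lines above mix single braces (`{MODEL}`, `{BASE_URL}`) with doubled ones (`generation_kwargs={{`), which suggests the generated script is rendered through an f-string or `str.format`. As a rough illustration of that escaping rule, and not the project's actual template, the sketch below renders a tiny snippet and executes it. `render_snippet` and the example values are hypothetical.

```python
# Hypothetical example values, not taken from the project.
MODEL = "example-org/example-model"
BASE_URL = "https://example.invalid/v1/"


def render_snippet(model: str, base_url: str) -> str:
    """Render a small standalone Python snippet as text.

    Single braces are substituted by the f-string; doubled braces
    ({{ and }}) come out as literal { and } in the rendered code.
    """
    return f'''MODEL = "{model}"
BASE_URL = "{base_url}"
generation_kwargs = {{
    "temperature": 0.8,
    "max_new_tokens": 2048,
}}
'''


snippet = render_snippet(MODEL, BASE_URL)
print(snippet)

# Execute the rendered text to confirm it is valid Python.
namespace: dict = {}
exec(snippet, namespace)
```

Forgetting to double a brace in such a template would make Python try to evaluate the dictionary body as an f-string expression, so keeping the `{{ ... }}` pairs intact matters when adding new parameters.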
  textcat_labeller = TextClassification(
      llm=InferenceEndpointsLLM(
          model_id=MODEL,
+         base_url=BASE_URL,
+         api_key=os.environ["API_KEY"],
          generation_kwargs={{
              "temperature": 0.8,
              "max_new_tokens": 2048,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -6,40 +6,13 @@ import gradio as gr
  import numpy as np
  import pandas as pd
  from gradio.oauth import (
-     OAUTH_CLIENT_ID,
-     OAUTH_CLIENT_SECRET,
-     OAUTH_SCOPES,
-     OPENID_PROVIDER_URL,
+     OAuthToken,
      get_space,
  )
  from huggingface_hub import whoami
  from jinja2 import Environment, meta

- from src.distilabel_dataset_generator import argilla_client
-
- _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
-
-
- _CHECK_IF_SPACE_IS_SET = (
-     all(
-         [
-             OAUTH_CLIENT_ID,
-             OAUTH_CLIENT_SECRET,
-             OAUTH_SCOPES,
-             OPENID_PROVIDER_URL,
-         ]
-     )
-     or get_space() is None
- )
-
- if _CHECK_IF_SPACE_IS_SET:
-     from gradio.oauth import OAuthToken
- else:
-     OAuthToken = str
-
-
- def get_login_button():
-     return gr.LoginButton(value="Sign in!", size="sm", scale=2).activate()
+ from distilabel_dataset_generator.constants import argilla_client


  def get_duplicate_button():
@@ -85,13 +58,6 @@ def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
      )


- def get_token(oauth_token: Union[OAuthToken, None]):
-     if oauth_token:
-         return oauth_token.token
-     else:
-         return ""
-
-
  def swap_visibility(oauth_token: Union[OAuthToken, None]):
      if oauth_token:
          return gr.update(elem_classes=["main_ui_logged_in"])
@@ -99,28 +65,6 @@ def swap_visibility(oauth_token: Union[OAuthToken, None]):
          return gr.update(elem_classes=["main_ui_logged_out"])


- def get_base_app():
-     with gr.Blocks(
-         title="🧬 Synthetic Data Generator",
-         head="🧬 Synthetic Data Generator",
-         css=_LOGGED_OUT_CSS,
-     ) as app:
-         with gr.Row():
-             gr.Markdown(
-                 "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
-             )
-         with gr.Row():
-             gr.Column()
-             get_login_button()
-             gr.Column()
-
-         gr.Markdown("## Iterate on a sample dataset")
-         with gr.Column() as main_ui:
-             pass
-
-     return app
-
-
  def get_argilla_client() -> Union[rg.Argilla, None]:
      return argilla_client