Merge pull request #4 from argilla-io/feat/choose-models
Browse files- .python-version +0 -1
- README.md +32 -20
- app.py +1 -35
- pyproject.toml +12 -4
- src/distilabel_dataset_generator/__init__.py +50 -24
- src/distilabel_dataset_generator/_tabbedinterface.py +3 -1
- src/distilabel_dataset_generator/app.py +38 -0
- src/distilabel_dataset_generator/apps/__init__.py +0 -0
- src/distilabel_dataset_generator/apps/base.py +3 -351
- src/distilabel_dataset_generator/apps/eval.py +5 -7
- src/distilabel_dataset_generator/apps/sft.py +170 -166
- src/distilabel_dataset_generator/apps/textcat.py +2 -5
- src/distilabel_dataset_generator/constants.py +62 -0
- src/distilabel_dataset_generator/pipelines/__init__.py +0 -0
- src/distilabel_dataset_generator/pipelines/base.py +2 -4
- src/distilabel_dataset_generator/pipelines/embeddings.py +3 -2
- src/distilabel_dataset_generator/pipelines/eval.py +15 -14
- src/distilabel_dataset_generator/pipelines/sft.py +15 -6
- src/distilabel_dataset_generator/pipelines/textcat.py +13 -14
- src/distilabel_dataset_generator/utils.py +2 -58
.python-version
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
synthetic-data-generator
|
|
|
|
README.md
CHANGED
@@ -20,25 +20,13 @@ hf_oauth_scopes:
|
|
20 |
|
21 |
<h1 align="center">
|
22 |
<br>
|
23 |
-
Synthetic Data Generator
|
24 |
<br>
|
25 |
</h1>
|
26 |
<h3 align="center">Build datasets using natural language</h2>
|
27 |
|
28 |
![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png)
|
29 |
|
30 |
-
<p align="center">
|
31 |
-
<a href="https://pypi.org/project/synthetic-dataset-generator/">
|
32 |
-
<img alt="CI" src="https://img.shields.io/pypi/v/synthetic-dataset-generator.svg?style=flat-round&logo=pypi&logoColor=white">
|
33 |
-
</a>
|
34 |
-
<a href="https://pepy.tech/project/synthetic-dataset-generator">
|
35 |
-
<img alt="CI" src="https://static.pepy.tech/personalized-badge/synthetic-dataset-generator?period=month&units=international_system&left_color=grey&right_color=blue&left_text=pypi%20downloads/month">
|
36 |
-
</a>
|
37 |
-
<a href="https://huggingface.co/spaces/argilla/synthetic-data-generator?duplicate=true">
|
38 |
-
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg"/>
|
39 |
-
</a>
|
40 |
-
</p>
|
41 |
-
|
42 |
<p align="center">
|
43 |
<a href="https://twitter.com/argilla_io">
|
44 |
<img src="https://img.shields.io/badge/twitter-black?logo=x"/>
|
@@ -78,21 +66,29 @@ You can simply install the package with:
|
|
78 |
pip install synthetic-dataset-generator
|
79 |
```
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
### Environment Variables
|
82 |
|
83 |
-
- `HF_TOKEN`: Your Hugging Face token
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
|
86 |
|
87 |
- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
|
88 |
- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
|
89 |
|
90 |
-
## Quickstart
|
91 |
-
|
92 |
-
```bash
|
93 |
-
python app.py
|
94 |
-
```
|
95 |
-
|
96 |
### Argilla integration
|
97 |
|
98 |
Argilla is a open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
|
@@ -104,3 +100,19 @@ Argilla is a open source tool for data curation. It allows you to annotate and r
|
|
104 |
Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
|
105 |
|
106 |
Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
<h1 align="center">
|
22 |
<br>
|
23 |
+
🧬 Synthetic Data Generator
|
24 |
<br>
|
25 |
</h1>
|
26 |
<h3 align="center">Build datasets using natural language</h2>
|
27 |
|
28 |
![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
<p align="center">
|
31 |
<a href="https://twitter.com/argilla_io">
|
32 |
<img src="https://img.shields.io/badge/twitter-black?logo=x"/>
|
|
|
66 |
pip install synthetic-dataset-generator
|
67 |
```
|
68 |
|
69 |
+
### Quickstart
|
70 |
+
|
71 |
+
```python
|
72 |
+
from synthetic_dataset_generator.app import demo
|
73 |
+
|
74 |
+
demo.launch()
|
75 |
+
```
|
76 |
+
|
77 |
### Environment Variables
|
78 |
|
79 |
+
- `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints.
|
80 |
+
|
81 |
+
Optionally, you can set the following environment variables to customize the generation process.
|
82 |
+
|
83 |
+
- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
|
84 |
+
- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
|
85 |
+
- `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`.
|
86 |
|
87 |
Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
|
88 |
|
89 |
- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
|
90 |
- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
### Argilla integration
|
93 |
|
94 |
Argilla is a open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
|
|
|
100 |
Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
|
101 |
|
102 |
Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.
|
103 |
+
|
104 |
+
## Development
|
105 |
+
|
106 |
+
Install the dependencies:
|
107 |
+
|
108 |
+
```bash
|
109 |
+
python -m venv .venv
|
110 |
+
source .venv/bin/activate
|
111 |
+
pip install -e .
|
112 |
+
```
|
113 |
+
|
114 |
+
Run the app:
|
115 |
+
|
116 |
+
```bash
|
117 |
+
python app.py
|
118 |
+
```
|
app.py
CHANGED
@@ -1,38 +1,4 @@
|
|
1 |
-
from
|
2 |
-
from src.distilabel_dataset_generator.apps.eval import app as eval_app
|
3 |
-
from src.distilabel_dataset_generator.apps.faq import app as faq_app
|
4 |
-
from src.distilabel_dataset_generator.apps.sft import app as sft_app
|
5 |
-
from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
|
6 |
-
|
7 |
-
theme = "argilla/argilla-theme"
|
8 |
-
|
9 |
-
css = """
|
10 |
-
button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
|
11 |
-
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)}
|
12 |
-
button.hf-login {background: var(--neutral-800); color: white}
|
13 |
-
button.hf-login:hover {background: var(--neutral-700); color: white}
|
14 |
-
.tabitem { border: 0; padding-inline: 0}
|
15 |
-
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
16 |
-
.group_padding{padding: .55em}
|
17 |
-
.gallery-item {background: var(--background-fill-secondary); text-align: left}
|
18 |
-
.gallery {white-space: wrap}
|
19 |
-
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
|
20 |
-
#system_prompt_examples {
|
21 |
-
color: var(--body-text-color) !important;
|
22 |
-
background-color: var(--block-background-fill) !important;
|
23 |
-
}
|
24 |
-
.container {padding-inline: 0 !important}
|
25 |
-
"""
|
26 |
-
|
27 |
-
demo = TabbedInterface(
|
28 |
-
[textcat_app, sft_app, eval_app, faq_app],
|
29 |
-
["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
|
30 |
-
css=css,
|
31 |
-
title="Synthetic Data Generator",
|
32 |
-
head="Synthetic Data Generator",
|
33 |
-
theme=theme,
|
34 |
-
)
|
35 |
-
|
36 |
|
37 |
if __name__ == "__main__":
|
38 |
demo.launch()
|
|
|
1 |
+
from distilabel_dataset_generator.app import demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
if __name__ == "__main__":
|
4 |
demo.launch()
|
pyproject.toml
CHANGED
@@ -5,6 +5,18 @@ description = "Build datasets using natural language"
|
|
5 |
authors = [
|
6 |
{name = "davidberenstein1957", email = "[email protected]"},
|
7 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
dependencies = [
|
9 |
"distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
|
10 |
"gradio[oauth]<5.0.0",
|
@@ -14,14 +26,10 @@ dependencies = [
|
|
14 |
"gradio-huggingfacehub-search>=0.0.7",
|
15 |
"argilla>=2.4.0",
|
16 |
]
|
17 |
-
requires-python = "<3.13,>=3.10"
|
18 |
-
readme = "README.md"
|
19 |
-
license = {text = "apache 2"}
|
20 |
|
21 |
[build-system]
|
22 |
requires = ["pdm-backend"]
|
23 |
build-backend = "pdm.backend"
|
24 |
|
25 |
-
|
26 |
[tool.pdm]
|
27 |
distribution = true
|
|
|
5 |
authors = [
|
6 |
{name = "davidberenstein1957", email = "[email protected]"},
|
7 |
]
|
8 |
+
tags = [
|
9 |
+
"gradio",
|
10 |
+
"synthetic-data",
|
11 |
+
"huggingface",
|
12 |
+
"argilla",
|
13 |
+
"generative-ai",
|
14 |
+
"ai",
|
15 |
+
]
|
16 |
+
requires-python = "<3.13,>=3.10"
|
17 |
+
readme = "README.md"
|
18 |
+
license = {text = "Apache 2"}
|
19 |
+
|
20 |
dependencies = [
|
21 |
"distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
|
22 |
"gradio[oauth]<5.0.0",
|
|
|
26 |
"gradio-huggingfacehub-search>=0.0.7",
|
27 |
"argilla>=2.4.0",
|
28 |
]
|
|
|
|
|
|
|
29 |
|
30 |
[build-system]
|
31 |
requires = ["pdm-backend"]
|
32 |
build-backend = "pdm.backend"
|
33 |
|
|
|
34 |
[tool.pdm]
|
35 |
distribution = true
|
src/distilabel_dataset_generator/__init__.py
CHANGED
@@ -1,39 +1,64 @@
|
|
1 |
-
import os
|
2 |
import warnings
|
3 |
-
from
|
4 |
-
from typing import Optional, Union
|
5 |
|
6 |
-
import argilla as rg
|
7 |
import distilabel
|
8 |
import distilabel.distiset
|
|
|
9 |
from distilabel.utils.card.dataset_card import (
|
10 |
DistilabelDatasetCard,
|
11 |
size_categories_parser,
|
12 |
)
|
13 |
-
from huggingface_hub import DatasetCardData, HfApi
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
|
16 |
-
HF_TOKENS = [token for token in HF_TOKENS if token]
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
28 |
|
29 |
-
if
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
|
@@ -138,3 +163,4 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
|
|
138 |
|
139 |
|
140 |
distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
|
|
|
|
|
|
1 |
import warnings
|
2 |
+
from typing import Optional
|
|
|
3 |
|
|
|
4 |
import distilabel
|
5 |
import distilabel.distiset
|
6 |
+
from distilabel.llms import InferenceEndpointsLLM
|
7 |
from distilabel.utils.card.dataset_card import (
|
8 |
DistilabelDatasetCard,
|
9 |
size_categories_parser,
|
10 |
)
|
11 |
+
from huggingface_hub import DatasetCardData, HfApi
|
12 |
+
from pydantic import (
|
13 |
+
ValidationError,
|
14 |
+
model_validator,
|
15 |
+
)
|
16 |
|
|
|
|
|
17 |
|
18 |
+
class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
|
19 |
+
@model_validator(mode="after") # type: ignore
|
20 |
+
def only_one_of_model_id_endpoint_name_or_base_url_provided(
|
21 |
+
self,
|
22 |
+
) -> "InferenceEndpointsLLM":
|
23 |
+
"""Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
|
24 |
+
provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
|
25 |
+
favour of the dynamically calculated one.."""
|
26 |
+
|
27 |
+
if self.base_url and (self.model_id or self.endpoint_name):
|
28 |
+
warnings.warn( # type: ignore
|
29 |
+
f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
|
30 |
+
" or `endpoint_name` is also provided, the `base_url` will either be ignored"
|
31 |
+
" or overwritten with the one generated from either of those args, for serverless"
|
32 |
+
" or dedicated inference endpoints, respectively."
|
33 |
+
)
|
34 |
+
|
35 |
+
if self.use_magpie_template and self.tokenizer_id is None:
|
36 |
+
raise ValueError(
|
37 |
+
"`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
|
38 |
+
" set a `tokenizer_id` and try again."
|
39 |
+
)
|
40 |
|
41 |
+
if (
|
42 |
+
self.model_id
|
43 |
+
and self.tokenizer_id is None
|
44 |
+
and self.structured_output is not None
|
45 |
+
):
|
46 |
+
self.tokenizer_id = self.model_id
|
47 |
|
48 |
+
if self.base_url and not (self.model_id or self.endpoint_name):
|
49 |
+
return self
|
50 |
+
|
51 |
+
if self.model_id and not self.endpoint_name:
|
52 |
+
return self
|
53 |
+
|
54 |
+
if self.endpoint_name and not self.model_id:
|
55 |
+
return self
|
56 |
+
|
57 |
+
raise ValidationError(
|
58 |
+
f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
|
59 |
+
f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
|
60 |
+
f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
|
61 |
+
)
|
62 |
|
63 |
|
64 |
class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
|
|
|
163 |
|
164 |
|
165 |
distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
|
166 |
+
distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
|
src/distilabel_dataset_generator/_tabbedinterface.py
CHANGED
@@ -68,7 +68,9 @@ class TabbedInterface(Blocks):
|
|
68 |
with gr.Column(scale=3):
|
69 |
pass
|
70 |
with gr.Column(scale=2):
|
71 |
-
gr.LoginButton(
|
|
|
|
|
72 |
with Tabs():
|
73 |
for interface, tab_name in zip(interface_list, tab_names, strict=False):
|
74 |
with Tab(label=tab_name):
|
|
|
68 |
with gr.Column(scale=3):
|
69 |
pass
|
70 |
with gr.Column(scale=2):
|
71 |
+
gr.LoginButton(
|
72 |
+
value="Sign in", variant="hf-login", size="sm", scale=2
|
73 |
+
)
|
74 |
with Tabs():
|
75 |
for interface, tab_name in zip(interface_list, tab_names, strict=False):
|
76 |
with Tab(label=tab_name):
|
src/distilabel_dataset_generator/app.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from distilabel_dataset_generator._tabbedinterface import TabbedInterface
|
2 |
+
from distilabel_dataset_generator.apps.eval import app as eval_app
|
3 |
+
from distilabel_dataset_generator.apps.faq import app as faq_app
|
4 |
+
from distilabel_dataset_generator.apps.sft import app as sft_app
|
5 |
+
from distilabel_dataset_generator.apps.textcat import app as textcat_app
|
6 |
+
|
7 |
+
theme = "argilla/argilla-theme"
|
8 |
+
|
9 |
+
css = """
|
10 |
+
button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
|
11 |
+
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)}
|
12 |
+
button.hf-login {background: var(--neutral-800); color: white}
|
13 |
+
button.hf-login:hover {background: var(--neutral-700); color: white}
|
14 |
+
.tabitem { border: 0; padding-inline: 0}
|
15 |
+
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
16 |
+
.group_padding{padding: .55em}
|
17 |
+
.gallery-item {background: var(--background-fill-secondary); text-align: left}
|
18 |
+
.gallery {white-space: wrap}
|
19 |
+
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
|
20 |
+
#system_prompt_examples {
|
21 |
+
color: var(--body-text-color) !important;
|
22 |
+
background-color: var(--block-background-fill) !important;
|
23 |
+
}
|
24 |
+
.container {padding-inline: 0 !important}
|
25 |
+
"""
|
26 |
+
|
27 |
+
demo = TabbedInterface(
|
28 |
+
[textcat_app, sft_app, eval_app, faq_app],
|
29 |
+
["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
|
30 |
+
css=css,
|
31 |
+
title="Synthetic Data Generator",
|
32 |
+
head="Synthetic Data Generator",
|
33 |
+
theme=theme,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
if __name__ == "__main__":
|
38 |
+
demo.launch()
|
src/distilabel_dataset_generator/apps/__init__.py
ADDED
File without changes
|
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import io
|
2 |
import uuid
|
3 |
-
from typing import
|
4 |
|
5 |
import argilla as rg
|
6 |
import gradio as gr
|
@@ -10,161 +10,11 @@ from distilabel.distiset import Distiset
|
|
10 |
from gradio import OAuthToken
|
11 |
from huggingface_hub import HfApi, upload_file
|
12 |
|
13 |
-
from
|
14 |
-
|
15 |
get_argilla_client,
|
16 |
-
get_login_button,
|
17 |
-
list_orgs,
|
18 |
-
swap_visibility,
|
19 |
)
|
20 |
|
21 |
-
TEXTCAT_TASK = "text_classification"
|
22 |
-
SFT_TASK = "supervised_fine_tuning"
|
23 |
-
|
24 |
-
|
25 |
-
def get_main_ui(
|
26 |
-
default_dataset_descriptions: List[str],
|
27 |
-
default_system_prompts: List[str],
|
28 |
-
default_datasets: List[pd.DataFrame],
|
29 |
-
fn_generate_system_prompt: Callable,
|
30 |
-
fn_generate_dataset: Callable,
|
31 |
-
task: str,
|
32 |
-
):
|
33 |
-
def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()):
|
34 |
-
if system_prompt in default_system_prompts:
|
35 |
-
index = default_system_prompts.index(system_prompt)
|
36 |
-
if index < len(default_datasets):
|
37 |
-
return default_datasets[index]
|
38 |
-
if task == TEXTCAT_TASK:
|
39 |
-
result = fn_generate_dataset(
|
40 |
-
system_prompt=system_prompt,
|
41 |
-
difficulty="high school",
|
42 |
-
clarity="clear",
|
43 |
-
labels=[],
|
44 |
-
num_labels=1,
|
45 |
-
num_rows=1,
|
46 |
-
progress=progress,
|
47 |
-
is_sample=True,
|
48 |
-
)
|
49 |
-
else:
|
50 |
-
result = fn_generate_dataset(
|
51 |
-
system_prompt=system_prompt,
|
52 |
-
num_turns=1,
|
53 |
-
num_rows=1,
|
54 |
-
progress=progress,
|
55 |
-
is_sample=True,
|
56 |
-
)
|
57 |
-
return result
|
58 |
-
|
59 |
-
with gr.Blocks(
|
60 |
-
title="🧬 Synthetic Data Generator",
|
61 |
-
head="🧬 Synthetic Data Generator",
|
62 |
-
css=_LOGGED_OUT_CSS,
|
63 |
-
) as app:
|
64 |
-
with gr.Row():
|
65 |
-
gr.HTML(
|
66 |
-
"""<details style='display: inline-block;'><summary><h2 style='display: inline;'>How does it work?</h2></summary><img src='https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/flow.png' width='100%' style='margin: 0 auto; display: block;'></details>"""
|
67 |
-
)
|
68 |
-
with gr.Row():
|
69 |
-
gr.Markdown(
|
70 |
-
"Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
|
71 |
-
)
|
72 |
-
with gr.Row():
|
73 |
-
gr.Column()
|
74 |
-
get_login_button()
|
75 |
-
gr.Column()
|
76 |
-
|
77 |
-
gr.Markdown("## Iterate on a sample dataset")
|
78 |
-
with gr.Column() as main_ui:
|
79 |
-
(
|
80 |
-
dataset_description,
|
81 |
-
examples,
|
82 |
-
btn_generate_system_prompt,
|
83 |
-
system_prompt,
|
84 |
-
sample_dataset,
|
85 |
-
btn_generate_sample_dataset,
|
86 |
-
) = get_iterate_on_sample_dataset_ui(
|
87 |
-
default_dataset_descriptions=default_dataset_descriptions,
|
88 |
-
default_system_prompts=default_system_prompts,
|
89 |
-
default_datasets=default_datasets,
|
90 |
-
task=task,
|
91 |
-
)
|
92 |
-
gr.Markdown("## Generate full dataset")
|
93 |
-
gr.Markdown(
|
94 |
-
"Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub."
|
95 |
-
)
|
96 |
-
with gr.Row(variant="panel") as custom_input_ui:
|
97 |
-
pass
|
98 |
-
|
99 |
-
(
|
100 |
-
dataset_name,
|
101 |
-
add_to_existing_dataset,
|
102 |
-
btn_generate_full_dataset_argilla,
|
103 |
-
btn_generate_and_push_to_argilla,
|
104 |
-
btn_push_to_argilla,
|
105 |
-
org_name,
|
106 |
-
repo_name,
|
107 |
-
private,
|
108 |
-
btn_generate_full_dataset,
|
109 |
-
btn_generate_and_push_to_hub,
|
110 |
-
btn_push_to_hub,
|
111 |
-
final_dataset,
|
112 |
-
success_message,
|
113 |
-
) = get_push_to_ui(default_datasets)
|
114 |
-
|
115 |
-
sample_dataset.change(
|
116 |
-
fn=lambda x: x,
|
117 |
-
inputs=[sample_dataset],
|
118 |
-
outputs=[final_dataset],
|
119 |
-
)
|
120 |
-
|
121 |
-
btn_generate_system_prompt.click(
|
122 |
-
fn=fn_generate_system_prompt,
|
123 |
-
inputs=[dataset_description],
|
124 |
-
outputs=[system_prompt],
|
125 |
-
show_progress=True,
|
126 |
-
).then(
|
127 |
-
fn=fn_generate_sample_dataset,
|
128 |
-
inputs=[system_prompt],
|
129 |
-
outputs=[sample_dataset],
|
130 |
-
show_progress=True,
|
131 |
-
)
|
132 |
-
|
133 |
-
btn_generate_sample_dataset.click(
|
134 |
-
fn=fn_generate_sample_dataset,
|
135 |
-
inputs=[system_prompt],
|
136 |
-
outputs=[sample_dataset],
|
137 |
-
show_progress=True,
|
138 |
-
)
|
139 |
-
|
140 |
-
app.load(fn=swap_visibility, outputs=main_ui)
|
141 |
-
app.load(get_org_dropdown, outputs=[org_name])
|
142 |
-
|
143 |
-
return (
|
144 |
-
app,
|
145 |
-
main_ui,
|
146 |
-
custom_input_ui,
|
147 |
-
dataset_description,
|
148 |
-
examples,
|
149 |
-
btn_generate_system_prompt,
|
150 |
-
system_prompt,
|
151 |
-
sample_dataset,
|
152 |
-
btn_generate_sample_dataset,
|
153 |
-
dataset_name,
|
154 |
-
add_to_existing_dataset,
|
155 |
-
btn_generate_full_dataset_argilla,
|
156 |
-
btn_generate_and_push_to_argilla,
|
157 |
-
btn_push_to_argilla,
|
158 |
-
org_name,
|
159 |
-
repo_name,
|
160 |
-
private,
|
161 |
-
btn_generate_full_dataset,
|
162 |
-
btn_generate_and_push_to_hub,
|
163 |
-
btn_push_to_hub,
|
164 |
-
final_dataset,
|
165 |
-
success_message,
|
166 |
-
)
|
167 |
-
|
168 |
|
169 |
def validate_argilla_user_workspace_dataset(
|
170 |
dataset_name: str,
|
@@ -195,186 +45,6 @@ def validate_argilla_user_workspace_dataset(
|
|
195 |
return ""
|
196 |
|
197 |
|
198 |
-
def get_org_dropdown(oauth_token: Union[OAuthToken, None]):
|
199 |
-
orgs = list_orgs(oauth_token)
|
200 |
-
return gr.Dropdown(
|
201 |
-
label="Organization",
|
202 |
-
choices=orgs,
|
203 |
-
value=orgs[0] if orgs else None,
|
204 |
-
allow_custom_value=True,
|
205 |
-
)
|
206 |
-
|
207 |
-
|
208 |
-
def get_push_to_ui(default_datasets):
|
209 |
-
with gr.Column() as push_to_ui:
|
210 |
-
(
|
211 |
-
dataset_name,
|
212 |
-
add_to_existing_dataset,
|
213 |
-
btn_generate_full_dataset_argilla,
|
214 |
-
btn_generate_and_push_to_argilla,
|
215 |
-
btn_push_to_argilla,
|
216 |
-
) = get_argilla_tab()
|
217 |
-
(
|
218 |
-
org_name,
|
219 |
-
repo_name,
|
220 |
-
private,
|
221 |
-
btn_generate_full_dataset,
|
222 |
-
btn_generate_and_push_to_hub,
|
223 |
-
btn_push_to_hub,
|
224 |
-
) = get_hf_tab()
|
225 |
-
final_dataset = get_final_dataset_row(default_datasets)
|
226 |
-
success_message = get_success_message_row()
|
227 |
-
return (
|
228 |
-
dataset_name,
|
229 |
-
add_to_existing_dataset,
|
230 |
-
btn_generate_full_dataset_argilla,
|
231 |
-
btn_generate_and_push_to_argilla,
|
232 |
-
btn_push_to_argilla,
|
233 |
-
org_name,
|
234 |
-
repo_name,
|
235 |
-
private,
|
236 |
-
btn_generate_full_dataset,
|
237 |
-
btn_generate_and_push_to_hub,
|
238 |
-
btn_push_to_hub,
|
239 |
-
final_dataset,
|
240 |
-
success_message,
|
241 |
-
)
|
242 |
-
|
243 |
-
|
244 |
-
def get_iterate_on_sample_dataset_ui(
|
245 |
-
default_dataset_descriptions: List[str],
|
246 |
-
default_system_prompts: List[str],
|
247 |
-
default_datasets: List[pd.DataFrame],
|
248 |
-
task: str,
|
249 |
-
):
|
250 |
-
with gr.Column():
|
251 |
-
dataset_description = gr.TextArea(
|
252 |
-
label="Give a precise description of your desired application. Check the examples for inspiration.",
|
253 |
-
value=default_dataset_descriptions[0],
|
254 |
-
lines=2,
|
255 |
-
)
|
256 |
-
examples = gr.Examples(
|
257 |
-
elem_id="system_prompt_examples",
|
258 |
-
examples=[[example] for example in default_dataset_descriptions],
|
259 |
-
inputs=[dataset_description],
|
260 |
-
)
|
261 |
-
with gr.Row():
|
262 |
-
gr.Column(scale=1)
|
263 |
-
btn_generate_system_prompt = gr.Button(
|
264 |
-
value="Generate system prompt and sample dataset", variant="primary"
|
265 |
-
)
|
266 |
-
gr.Column(scale=1)
|
267 |
-
|
268 |
-
system_prompt = gr.TextArea(
|
269 |
-
label="System prompt for dataset generation. You can tune it and regenerate the sample.",
|
270 |
-
value=default_system_prompts[0],
|
271 |
-
lines=2 if task == TEXTCAT_TASK else 5,
|
272 |
-
)
|
273 |
-
|
274 |
-
with gr.Row():
|
275 |
-
sample_dataset = gr.Dataframe(
|
276 |
-
value=default_datasets[0],
|
277 |
-
label=(
|
278 |
-
"Sample dataset. Text truncated to 256 tokens."
|
279 |
-
if task == TEXTCAT_TASK
|
280 |
-
else "Sample dataset. Prompts and completions truncated to 256 tokens."
|
281 |
-
),
|
282 |
-
interactive=False,
|
283 |
-
wrap=True,
|
284 |
-
)
|
285 |
-
|
286 |
-
with gr.Row():
|
287 |
-
gr.Column(scale=1)
|
288 |
-
btn_generate_sample_dataset = gr.Button(
|
289 |
-
value="Generate sample dataset", variant="primary"
|
290 |
-
)
|
291 |
-
gr.Column(scale=1)
|
292 |
-
|
293 |
-
return (
|
294 |
-
dataset_description,
|
295 |
-
examples,
|
296 |
-
btn_generate_system_prompt,
|
297 |
-
system_prompt,
|
298 |
-
sample_dataset,
|
299 |
-
btn_generate_sample_dataset,
|
300 |
-
)
|
301 |
-
|
302 |
-
|
303 |
-
def get_argilla_tab() -> Tuple[Any]:
|
304 |
-
with gr.Tab(label="Argilla"):
|
305 |
-
if get_argilla_client() is not None:
|
306 |
-
with gr.Row(variant="panel"):
|
307 |
-
dataset_name = gr.Textbox(
|
308 |
-
label="Dataset name",
|
309 |
-
placeholder="dataset_name",
|
310 |
-
value="my-distiset",
|
311 |
-
)
|
312 |
-
add_to_existing_dataset = gr.Checkbox(
|
313 |
-
label="Allow adding records to existing dataset",
|
314 |
-
info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.",
|
315 |
-
value=False,
|
316 |
-
interactive=True,
|
317 |
-
scale=1,
|
318 |
-
)
|
319 |
-
|
320 |
-
with gr.Row(variant="panel"):
|
321 |
-
btn_generate_full_dataset_argilla = gr.Button(
|
322 |
-
value="Generate", variant="primary", scale=2
|
323 |
-
)
|
324 |
-
btn_generate_and_push_to_argilla = gr.Button(
|
325 |
-
value="Generate and Push to Argilla",
|
326 |
-
variant="primary",
|
327 |
-
scale=2,
|
328 |
-
)
|
329 |
-
btn_push_to_argilla = gr.Button(
|
330 |
-
value="Push to Argilla", variant="primary", scale=2
|
331 |
-
)
|
332 |
-
else:
|
333 |
-
gr.Markdown(
|
334 |
-
"Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub."
|
335 |
-
)
|
336 |
-
return (
|
337 |
-
dataset_name,
|
338 |
-
add_to_existing_dataset,
|
339 |
-
btn_generate_full_dataset_argilla,
|
340 |
-
btn_generate_and_push_to_argilla,
|
341 |
-
btn_push_to_argilla,
|
342 |
-
)
|
343 |
-
|
344 |
-
|
345 |
-
def get_hf_tab() -> Tuple[Any]:
|
346 |
-
with gr.Tab("Hugging Face Hub"):
|
347 |
-
with gr.Row(variant="panel"):
|
348 |
-
org_name = get_org_dropdown()
|
349 |
-
repo_name = gr.Textbox(
|
350 |
-
label="Repo name",
|
351 |
-
placeholder="dataset_name",
|
352 |
-
value="my-distiset",
|
353 |
-
)
|
354 |
-
private = gr.Checkbox(
|
355 |
-
label="Private dataset",
|
356 |
-
value=True,
|
357 |
-
interactive=True,
|
358 |
-
scale=1,
|
359 |
-
)
|
360 |
-
with gr.Row(variant="panel"):
|
361 |
-
btn_generate_full_dataset = gr.Button(
|
362 |
-
value="Generate", variant="primary", scale=2
|
363 |
-
)
|
364 |
-
btn_generate_and_push_to_hub = gr.Button(
|
365 |
-
value="Generate and Push to Hub", variant="primary", scale=2
|
366 |
-
)
|
367 |
-
btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2)
|
368 |
-
return (
|
369 |
-
org_name,
|
370 |
-
repo_name,
|
371 |
-
private,
|
372 |
-
btn_generate_full_dataset,
|
373 |
-
btn_generate_and_push_to_hub,
|
374 |
-
btn_push_to_hub,
|
375 |
-
)
|
376 |
-
|
377 |
-
|
378 |
def push_pipeline_code_to_hub(
|
379 |
pipeline_code: str,
|
380 |
org_name: str,
|
@@ -455,24 +125,6 @@ def validate_push_to_hub(org_name, repo_name):
|
|
455 |
return repo_id
|
456 |
|
457 |
|
458 |
-
def get_final_dataset_row(default_datasets) -> gr.Dataframe:
|
459 |
-
with gr.Row():
|
460 |
-
final_dataset = gr.Dataframe(
|
461 |
-
value=default_datasets[0],
|
462 |
-
label="Generated dataset",
|
463 |
-
interactive=False,
|
464 |
-
wrap=True,
|
465 |
-
min_width=300,
|
466 |
-
)
|
467 |
-
return final_dataset
|
468 |
-
|
469 |
-
|
470 |
-
def get_success_message_row() -> gr.Markdown:
|
471 |
-
with gr.Row():
|
472 |
-
success_message = gr.Markdown(visible=False)
|
473 |
-
return success_message
|
474 |
-
|
475 |
-
|
476 |
def show_success_message(org_name, repo_name) -> gr.Markdown:
|
477 |
client = get_argilla_client()
|
478 |
if client is None:
|
|
|
1 |
import io
|
2 |
import uuid
|
3 |
+
from typing import List, Union
|
4 |
|
5 |
import argilla as rg
|
6 |
import gradio as gr
|
|
|
10 |
from gradio import OAuthToken
|
11 |
from huggingface_hub import HfApi, upload_file
|
12 |
|
13 |
+
from distilabel_dataset_generator.constants import TEXTCAT_TASK
|
14 |
+
from distilabel_dataset_generator.utils import (
|
15 |
get_argilla_client,
|
|
|
|
|
|
|
16 |
)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def validate_argilla_user_workspace_dataset(
|
20 |
dataset_name: str,
|
|
|
45 |
return ""
|
46 |
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def push_pipeline_code_to_hub(
|
49 |
pipeline_code: str,
|
50 |
org_name: str,
|
|
|
125 |
return repo_id
|
126 |
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
def show_success_message(org_name, repo_name) -> gr.Markdown:
|
129 |
client = get_argilla_client()
|
130 |
if client is None:
|
src/distilabel_dataset_generator/apps/eval.py
CHANGED
@@ -16,25 +16,23 @@ from distilabel.distiset import Distiset
|
|
16 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
17 |
from huggingface_hub import HfApi
|
18 |
|
19 |
-
from
|
20 |
hide_success_message,
|
21 |
show_success_message,
|
22 |
validate_argilla_user_workspace_dataset,
|
23 |
validate_push_to_hub,
|
24 |
)
|
25 |
-
from
|
26 |
-
|
27 |
-
)
|
28 |
-
from src.distilabel_dataset_generator.pipelines.embeddings import (
|
29 |
get_embeddings,
|
30 |
get_sentence_embedding_dimensions,
|
31 |
)
|
32 |
-
from
|
33 |
generate_pipeline_code,
|
34 |
get_custom_evaluator,
|
35 |
get_ultrafeedback_evaluator,
|
36 |
)
|
37 |
-
from
|
38 |
column_to_list,
|
39 |
extract_column_names,
|
40 |
get_argilla_client,
|
|
|
16 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
17 |
from huggingface_hub import HfApi
|
18 |
|
19 |
+
from distilabel_dataset_generator.apps.base import (
|
20 |
hide_success_message,
|
21 |
show_success_message,
|
22 |
validate_argilla_user_workspace_dataset,
|
23 |
validate_push_to_hub,
|
24 |
)
|
25 |
+
from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
|
26 |
+
from distilabel_dataset_generator.pipelines.embeddings import (
|
|
|
|
|
27 |
get_embeddings,
|
28 |
get_sentence_embedding_dimensions,
|
29 |
)
|
30 |
+
from distilabel_dataset_generator.pipelines.eval import (
|
31 |
generate_pipeline_code,
|
32 |
get_custom_evaluator,
|
33 |
get_ultrafeedback_evaluator,
|
34 |
)
|
35 |
+
from distilabel_dataset_generator.utils import (
|
36 |
column_to_list,
|
37 |
extract_column_names,
|
38 |
get_argilla_client,
|
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -9,28 +9,25 @@ from datasets import Dataset
|
|
9 |
from distilabel.distiset import Distiset
|
10 |
from huggingface_hub import HfApi
|
11 |
|
12 |
-
from
|
13 |
hide_success_message,
|
14 |
show_success_message,
|
15 |
validate_argilla_user_workspace_dataset,
|
16 |
validate_push_to_hub,
|
17 |
)
|
18 |
-
from
|
19 |
-
|
20 |
-
)
|
21 |
-
from src.distilabel_dataset_generator.pipelines.embeddings import (
|
22 |
get_embeddings,
|
23 |
get_sentence_embedding_dimensions,
|
24 |
)
|
25 |
-
from
|
26 |
DEFAULT_DATASET_DESCRIPTIONS,
|
27 |
generate_pipeline_code,
|
28 |
get_magpie_generator,
|
29 |
get_prompt_generator,
|
30 |
get_response_generator,
|
31 |
)
|
32 |
-
from
|
33 |
-
_LOGGED_OUT_CSS,
|
34 |
get_argilla_client,
|
35 |
get_org_dropdown,
|
36 |
swap_visibility,
|
@@ -352,170 +349,177 @@ def hide_pipeline_code_visibility():
|
|
352 |
######################
|
353 |
|
354 |
|
355 |
-
with gr.Blocks(
|
356 |
with gr.Column() as main_ui:
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
with gr.
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
interactive=True,
|
371 |
-
|
372 |
)
|
373 |
-
|
374 |
-
|
375 |
-
variant="primary",
|
376 |
-
)
|
377 |
-
with gr.Column(scale=2):
|
378 |
-
examples = gr.Examples(
|
379 |
-
examples=DEFAULT_DATASET_DESCRIPTIONS,
|
380 |
-
inputs=[dataset_description],
|
381 |
-
cache_examples=False,
|
382 |
-
label="Examples",
|
383 |
-
)
|
384 |
-
with gr.Column(scale=1):
|
385 |
-
pass
|
386 |
-
|
387 |
-
gr.HTML(value="<hr>")
|
388 |
-
gr.Markdown(value="## 2. Configure your dataset")
|
389 |
-
with gr.Row(equal_height=False):
|
390 |
-
with gr.Column(scale=2):
|
391 |
-
system_prompt = gr.Textbox(
|
392 |
-
label="System prompt",
|
393 |
-
placeholder="You are a helpful assistant.",
|
394 |
-
)
|
395 |
-
num_turns = gr.Number(
|
396 |
-
value=1,
|
397 |
-
label="Number of turns in the conversation",
|
398 |
-
minimum=1,
|
399 |
-
maximum=4,
|
400 |
-
step=1,
|
401 |
-
interactive=True,
|
402 |
-
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
403 |
-
)
|
404 |
-
btn_apply_to_sample_dataset = gr.Button(
|
405 |
-
"Refresh dataset", variant="secondary"
|
406 |
-
)
|
407 |
-
with gr.Column(scale=3):
|
408 |
-
dataframe = gr.Dataframe(
|
409 |
-
headers=["prompt", "completion"],
|
410 |
-
wrap=True,
|
411 |
-
height=500,
|
412 |
-
interactive=False,
|
413 |
-
)
|
414 |
-
|
415 |
-
gr.HTML(value="<hr>")
|
416 |
-
gr.Markdown(value="## 3. Generate your dataset")
|
417 |
-
with gr.Row(equal_height=False):
|
418 |
-
with gr.Column(scale=2):
|
419 |
-
org_name = get_org_dropdown()
|
420 |
-
repo_name = gr.Textbox(
|
421 |
-
label="Repo name",
|
422 |
-
placeholder="dataset_name",
|
423 |
-
value=f"my-distiset-{str(uuid.uuid4())[:8]}",
|
424 |
-
interactive=True,
|
425 |
-
)
|
426 |
-
num_rows = gr.Number(
|
427 |
-
label="Number of rows",
|
428 |
-
value=10,
|
429 |
-
interactive=True,
|
430 |
-
scale=1,
|
431 |
-
)
|
432 |
-
private = gr.Checkbox(
|
433 |
-
label="Private dataset",
|
434 |
-
value=False,
|
435 |
-
interactive=True,
|
436 |
-
scale=1,
|
437 |
-
)
|
438 |
-
btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
|
439 |
-
with gr.Column(scale=3):
|
440 |
-
success_message = gr.Markdown(visible=True)
|
441 |
-
with gr.Accordion(
|
442 |
-
"Do you want to go further? Customize and run with Distilabel",
|
443 |
-
open=False,
|
444 |
-
visible=False,
|
445 |
-
) as pipeline_code_ui:
|
446 |
-
code = generate_pipeline_code(
|
447 |
-
system_prompt=system_prompt.value,
|
448 |
-
num_turns=num_turns.value,
|
449 |
-
num_rows=num_rows.value,
|
450 |
)
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
|
|
455 |
)
|
456 |
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
|
520 |
-
|
521 |
-
|
|
|
9 |
from distilabel.distiset import Distiset
|
10 |
from huggingface_hub import HfApi
|
11 |
|
12 |
+
from distilabel_dataset_generator.apps.base import (
|
13 |
hide_success_message,
|
14 |
show_success_message,
|
15 |
validate_argilla_user_workspace_dataset,
|
16 |
validate_push_to_hub,
|
17 |
)
|
18 |
+
from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE
|
19 |
+
from distilabel_dataset_generator.pipelines.embeddings import (
|
|
|
|
|
20 |
get_embeddings,
|
21 |
get_sentence_embedding_dimensions,
|
22 |
)
|
23 |
+
from distilabel_dataset_generator.pipelines.sft import (
|
24 |
DEFAULT_DATASET_DESCRIPTIONS,
|
25 |
generate_pipeline_code,
|
26 |
get_magpie_generator,
|
27 |
get_prompt_generator,
|
28 |
get_response_generator,
|
29 |
)
|
30 |
+
from distilabel_dataset_generator.utils import (
|
|
|
31 |
get_argilla_client,
|
32 |
get_org_dropdown,
|
33 |
swap_visibility,
|
|
|
349 |
######################
|
350 |
|
351 |
|
352 |
+
with gr.Blocks() as app:
|
353 |
with gr.Column() as main_ui:
|
354 |
+
if not SFT_AVAILABLE:
|
355 |
+
gr.Markdown(
|
356 |
+
value=f"## Supervised Fine-Tuning is not available for the {MODEL} model. Use Hugging Face Llama3 or Qwen2 models."
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
gr.Markdown(value="## 1. Describe the dataset you want")
|
360 |
+
with gr.Row():
|
361 |
+
with gr.Column(scale=2):
|
362 |
+
dataset_description = gr.Textbox(
|
363 |
+
label="Dataset description",
|
364 |
+
placeholder="Give a precise description of your desired dataset.",
|
365 |
+
)
|
366 |
+
with gr.Accordion("Temperature", open=False):
|
367 |
+
temperature = gr.Slider(
|
368 |
+
minimum=0.1,
|
369 |
+
maximum=1,
|
370 |
+
value=0.8,
|
371 |
+
step=0.1,
|
372 |
+
interactive=True,
|
373 |
+
show_label=False,
|
374 |
+
)
|
375 |
+
load_btn = gr.Button(
|
376 |
+
"Create dataset",
|
377 |
+
variant="primary",
|
378 |
+
)
|
379 |
+
with gr.Column(scale=2):
|
380 |
+
examples = gr.Examples(
|
381 |
+
examples=DEFAULT_DATASET_DESCRIPTIONS,
|
382 |
+
inputs=[dataset_description],
|
383 |
+
cache_examples=False,
|
384 |
+
label="Examples",
|
385 |
+
)
|
386 |
+
with gr.Column(scale=1):
|
387 |
+
pass
|
388 |
+
|
389 |
+
gr.HTML(value="<hr>")
|
390 |
+
gr.Markdown(value="## 2. Configure your dataset")
|
391 |
+
with gr.Row(equal_height=False):
|
392 |
+
with gr.Column(scale=2):
|
393 |
+
system_prompt = gr.Textbox(
|
394 |
+
label="System prompt",
|
395 |
+
placeholder="You are a helpful assistant.",
|
396 |
+
)
|
397 |
+
num_turns = gr.Number(
|
398 |
+
value=1,
|
399 |
+
label="Number of turns in the conversation",
|
400 |
+
minimum=1,
|
401 |
+
maximum=4,
|
402 |
+
step=1,
|
403 |
interactive=True,
|
404 |
+
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
405 |
)
|
406 |
+
btn_apply_to_sample_dataset = gr.Button(
|
407 |
+
"Refresh dataset", variant="secondary"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
408 |
)
|
409 |
+
with gr.Column(scale=3):
|
410 |
+
dataframe = gr.Dataframe(
|
411 |
+
headers=["prompt", "completion"],
|
412 |
+
wrap=True,
|
413 |
+
height=500,
|
414 |
+
interactive=False,
|
415 |
)
|
416 |
|
417 |
+
gr.HTML(value="<hr>")
|
418 |
+
gr.Markdown(value="## 3. Generate your dataset")
|
419 |
+
with gr.Row(equal_height=False):
|
420 |
+
with gr.Column(scale=2):
|
421 |
+
org_name = get_org_dropdown()
|
422 |
+
repo_name = gr.Textbox(
|
423 |
+
label="Repo name",
|
424 |
+
placeholder="dataset_name",
|
425 |
+
value=f"my-distiset-{str(uuid.uuid4())[:8]}",
|
426 |
+
interactive=True,
|
427 |
+
)
|
428 |
+
num_rows = gr.Number(
|
429 |
+
label="Number of rows",
|
430 |
+
value=10,
|
431 |
+
interactive=True,
|
432 |
+
scale=1,
|
433 |
+
)
|
434 |
+
private = gr.Checkbox(
|
435 |
+
label="Private dataset",
|
436 |
+
value=False,
|
437 |
+
interactive=True,
|
438 |
+
scale=1,
|
439 |
+
)
|
440 |
+
btn_push_to_hub = gr.Button(
|
441 |
+
"Push to Hub", variant="primary", scale=2
|
442 |
+
)
|
443 |
+
with gr.Column(scale=3):
|
444 |
+
success_message = gr.Markdown(visible=True)
|
445 |
+
with gr.Accordion(
|
446 |
+
"Do you want to go further? Customize and run with Distilabel",
|
447 |
+
open=False,
|
448 |
+
visible=False,
|
449 |
+
) as pipeline_code_ui:
|
450 |
+
code = generate_pipeline_code(
|
451 |
+
system_prompt=system_prompt.value,
|
452 |
+
num_turns=num_turns.value,
|
453 |
+
num_rows=num_rows.value,
|
454 |
+
)
|
455 |
+
pipeline_code = gr.Code(
|
456 |
+
value=code,
|
457 |
+
language="python",
|
458 |
+
label="Distilabel Pipeline Code",
|
459 |
+
)
|
460 |
+
|
461 |
+
load_btn.click(
|
462 |
+
fn=generate_system_prompt,
|
463 |
+
inputs=[dataset_description, temperature],
|
464 |
+
outputs=[system_prompt],
|
465 |
+
show_progress=True,
|
466 |
+
).then(
|
467 |
+
fn=generate_sample_dataset,
|
468 |
+
inputs=[system_prompt, num_turns],
|
469 |
+
outputs=[dataframe],
|
470 |
+
show_progress=True,
|
471 |
+
)
|
472 |
|
473 |
+
btn_apply_to_sample_dataset.click(
|
474 |
+
fn=generate_sample_dataset,
|
475 |
+
inputs=[system_prompt, num_turns],
|
476 |
+
outputs=[dataframe],
|
477 |
+
show_progress=True,
|
478 |
+
)
|
479 |
|
480 |
+
btn_push_to_hub.click(
|
481 |
+
fn=validate_argilla_user_workspace_dataset,
|
482 |
+
inputs=[repo_name],
|
483 |
+
outputs=[success_message],
|
484 |
+
show_progress=True,
|
485 |
+
).then(
|
486 |
+
fn=validate_push_to_hub,
|
487 |
+
inputs=[org_name, repo_name],
|
488 |
+
outputs=[success_message],
|
489 |
+
show_progress=True,
|
490 |
+
).success(
|
491 |
+
fn=hide_success_message,
|
492 |
+
outputs=[success_message],
|
493 |
+
show_progress=True,
|
494 |
+
).success(
|
495 |
+
fn=hide_pipeline_code_visibility,
|
496 |
+
inputs=[],
|
497 |
+
outputs=[pipeline_code_ui],
|
498 |
+
).success(
|
499 |
+
fn=push_dataset,
|
500 |
+
inputs=[
|
501 |
+
org_name,
|
502 |
+
repo_name,
|
503 |
+
system_prompt,
|
504 |
+
num_turns,
|
505 |
+
num_rows,
|
506 |
+
private,
|
507 |
+
],
|
508 |
+
outputs=[success_message],
|
509 |
+
show_progress=True,
|
510 |
+
).success(
|
511 |
+
fn=show_success_message,
|
512 |
+
inputs=[org_name, repo_name],
|
513 |
+
outputs=[success_message],
|
514 |
+
).success(
|
515 |
+
fn=generate_pipeline_code,
|
516 |
+
inputs=[system_prompt, num_turns, num_rows],
|
517 |
+
outputs=[pipeline_code],
|
518 |
+
).success(
|
519 |
+
fn=show_pipeline_code_visibility,
|
520 |
+
inputs=[],
|
521 |
+
outputs=[pipeline_code_ui],
|
522 |
+
)
|
523 |
|
524 |
+
app.load(fn=swap_visibility, outputs=main_ui)
|
525 |
+
app.load(fn=get_org_dropdown, outputs=[org_name])
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -9,15 +9,13 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value
|
|
9 |
from distilabel.distiset import Distiset
|
10 |
from huggingface_hub import HfApi
|
11 |
|
|
|
12 |
from src.distilabel_dataset_generator.apps.base import (
|
13 |
hide_success_message,
|
14 |
show_success_message,
|
15 |
validate_argilla_user_workspace_dataset,
|
16 |
validate_push_to_hub,
|
17 |
)
|
18 |
-
from src.distilabel_dataset_generator.pipelines.base import (
|
19 |
-
DEFAULT_BATCH_SIZE,
|
20 |
-
)
|
21 |
from src.distilabel_dataset_generator.pipelines.embeddings import (
|
22 |
get_embeddings,
|
23 |
get_sentence_embedding_dimensions,
|
@@ -30,7 +28,6 @@ from src.distilabel_dataset_generator.pipelines.textcat import (
|
|
30 |
get_textcat_generator,
|
31 |
)
|
32 |
from src.distilabel_dataset_generator.utils import (
|
33 |
-
_LOGGED_OUT_CSS,
|
34 |
get_argilla_client,
|
35 |
get_org_dropdown,
|
36 |
get_preprocess_labels,
|
@@ -334,7 +331,7 @@ def hide_pipeline_code_visibility():
|
|
334 |
######################
|
335 |
|
336 |
|
337 |
-
with gr.Blocks(
|
338 |
with gr.Column() as main_ui:
|
339 |
gr.Markdown("## 1. Describe the dataset you want")
|
340 |
with gr.Row():
|
|
|
9 |
from distilabel.distiset import Distiset
|
10 |
from huggingface_hub import HfApi
|
11 |
|
12 |
+
from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
|
13 |
from src.distilabel_dataset_generator.apps.base import (
|
14 |
hide_success_message,
|
15 |
show_success_message,
|
16 |
validate_argilla_user_workspace_dataset,
|
17 |
validate_push_to_hub,
|
18 |
)
|
|
|
|
|
|
|
19 |
from src.distilabel_dataset_generator.pipelines.embeddings import (
|
20 |
get_embeddings,
|
21 |
get_sentence_embedding_dimensions,
|
|
|
28 |
get_textcat_generator,
|
29 |
)
|
30 |
from src.distilabel_dataset_generator.utils import (
|
|
|
31 |
get_argilla_client,
|
32 |
get_org_dropdown,
|
33 |
get_preprocess_labels,
|
|
|
331 |
######################
|
332 |
|
333 |
|
334 |
+
with gr.Blocks() as app:
|
335 |
with gr.Column() as main_ui:
|
336 |
gr.Markdown("## 1. Describe the dataset you want")
|
337 |
with gr.Row():
|
src/distilabel_dataset_generator/constants.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import argilla as rg
|
5 |
+
|
6 |
+
# Tasks
|
7 |
+
TEXTCAT_TASK = "text_classification"
|
8 |
+
SFT_TASK = "supervised_fine_tuning"
|
9 |
+
|
10 |
+
# Hugging Face
|
11 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
12 |
+
if HF_TOKEN is None:
|
13 |
+
raise ValueError(
|
14 |
+
"HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
|
15 |
+
)
|
16 |
+
|
17 |
+
# Inference
|
18 |
+
DEFAULT_BATCH_SIZE = 5
|
19 |
+
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
|
20 |
+
API_KEYS = (
|
21 |
+
[os.getenv("HF_TOKEN")]
|
22 |
+
+ [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
|
23 |
+
+ [os.getenv("API_KEY")]
|
24 |
+
)
|
25 |
+
API_KEYS = [token for token in API_KEYS if token]
|
26 |
+
BASE_URL = os.getenv("BASE_URL", "https://api-inference.huggingface.co/v1/")
|
27 |
+
|
28 |
+
if BASE_URL != "https://api-inference.huggingface.co/v1/" and len(API_KEYS) == 0:
|
29 |
+
raise ValueError(
|
30 |
+
"API_KEY is not set. Ensure you have set the API_KEY environment variable that has access to the Hugging Face Inference Endpoints."
|
31 |
+
)
|
32 |
+
if "Qwen2" not in MODEL and "Llama-3" not in MODEL:
|
33 |
+
SFT_AVAILABLE = False
|
34 |
+
warnings.warn(
|
35 |
+
"SFT_AVAILABLE is set to False because the model is not a Qwen or Llama model."
|
36 |
+
)
|
37 |
+
MAGPIE_PRE_QUERY_TEMPLATE = None
|
38 |
+
else:
|
39 |
+
SFT_AVAILABLE = True
|
40 |
+
if "Qwen2" in MODEL:
|
41 |
+
MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
|
42 |
+
else:
|
43 |
+
MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
|
44 |
+
|
45 |
+
# Embeddings
|
46 |
+
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
|
47 |
+
|
48 |
+
# Argilla
|
49 |
+
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
|
50 |
+
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
|
51 |
+
if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
|
52 |
+
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
|
53 |
+
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
|
54 |
+
|
55 |
+
if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
|
56 |
+
warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
|
57 |
+
argilla_client = None
|
58 |
+
else:
|
59 |
+
argilla_client = rg.Argilla(
|
60 |
+
api_url=ARGILLA_API_URL,
|
61 |
+
api_key=ARGILLA_API_KEY,
|
62 |
+
)
|
src/distilabel_dataset_generator/pipelines/__init__.py
ADDED
File without changes
|
src/distilabel_dataset_generator/pipelines/base.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
-
from
|
2 |
|
3 |
-
DEFAULT_BATCH_SIZE = 5
|
4 |
TOKEN_INDEX = 0
|
5 |
-
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
6 |
|
7 |
|
8 |
def _get_next_api_key():
|
9 |
global TOKEN_INDEX
|
10 |
-
api_key =
|
11 |
TOKEN_INDEX += 1
|
12 |
return api_key
|
|
|
1 |
+
from distilabel_dataset_generator.constants import API_KEYS
|
2 |
|
|
|
3 |
TOKEN_INDEX = 0
|
|
|
4 |
|
5 |
|
6 |
def _get_next_api_key():
|
7 |
global TOKEN_INDEX
|
8 |
+
api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
|
9 |
TOKEN_INDEX += 1
|
10 |
return api_key
|
src/distilabel_dataset_generator/pipelines/embeddings.py
CHANGED
@@ -3,8 +3,9 @@ from typing import List
|
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
from sentence_transformers.models import StaticEmbedding
|
5 |
|
6 |
-
|
7 |
-
|
|
|
8 |
model = SentenceTransformer(modules=[static_embedding])
|
9 |
|
10 |
|
|
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
from sentence_transformers.models import StaticEmbedding
|
5 |
|
6 |
+
from distilabel_dataset_generator.constants import STATIC_EMBEDDING_MODEL
|
7 |
+
|
8 |
+
static_embedding = StaticEmbedding.from_model2vec(STATIC_EMBEDDING_MODEL)
|
9 |
model = SentenceTransformer(modules=[static_embedding])
|
10 |
|
11 |
|
src/distilabel_dataset_generator/pipelines/eval.py
CHANGED
@@ -5,18 +5,16 @@ from distilabel.steps.tasks import (
|
|
5 |
UltraFeedback,
|
6 |
)
|
7 |
|
8 |
-
from
|
9 |
-
|
10 |
-
|
11 |
-
)
|
12 |
-
from src.distilabel_dataset_generator.utils import extract_column_names
|
13 |
|
14 |
|
15 |
def get_ultrafeedback_evaluator(aspect, is_sample):
|
16 |
ultrafeedback_evaluator = UltraFeedback(
|
17 |
llm=InferenceEndpointsLLM(
|
18 |
model_id=MODEL,
|
19 |
-
|
20 |
api_key=_get_next_api_key(),
|
21 |
generation_kwargs={
|
22 |
"temperature": 0,
|
@@ -33,7 +31,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
|
|
33 |
custom_evaluator = TextGeneration(
|
34 |
llm=InferenceEndpointsLLM(
|
35 |
model_id=MODEL,
|
36 |
-
|
37 |
api_key=_get_next_api_key(),
|
38 |
structured_output={"format": "json", "schema": structured_output},
|
39 |
generation_kwargs={
|
@@ -62,7 +60,8 @@ from distilabel.steps.tasks import UltraFeedback
|
|
62 |
from distilabel.llms import InferenceEndpointsLLM
|
63 |
|
64 |
MODEL = "{MODEL}"
|
65 |
-
|
|
|
66 |
|
67 |
hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
|
68 |
data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
|
@@ -76,8 +75,8 @@ with Pipeline(name="ultrafeedback") as pipeline:
|
|
76 |
ultrafeedback_evaluator = UltraFeedback(
|
77 |
llm=InferenceEndpointsLLM(
|
78 |
model_id=MODEL,
|
79 |
-
|
80 |
-
api_key=os.environ["
|
81 |
generation_kwargs={{
|
82 |
"temperature": 0,
|
83 |
"max_new_tokens": 2048,
|
@@ -101,7 +100,8 @@ from distilabel.steps.tasks import UltraFeedback
|
|
101 |
from distilabel.llms import InferenceEndpointsLLM
|
102 |
|
103 |
MODEL = "{MODEL}"
|
104 |
-
|
|
|
105 |
|
106 |
hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
|
107 |
data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
|
@@ -119,8 +119,8 @@ with Pipeline(name="ultrafeedback") as pipeline:
|
|
119 |
aspect=aspect,
|
120 |
llm=InferenceEndpointsLLM(
|
121 |
model_id=MODEL,
|
122 |
-
|
123 |
-
api_key=os.environ["
|
124 |
generation_kwargs={{
|
125 |
"temperature": 0,
|
126 |
"max_new_tokens": 2048,
|
@@ -157,6 +157,7 @@ from distilabel.steps.tasks import TextGeneration
|
|
157 |
from distilabel.llms import InferenceEndpointsLLM
|
158 |
|
159 |
MODEL = "{MODEL}"
|
|
|
160 |
CUSTOM_TEMPLATE = "{prompt_template}"
|
161 |
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
162 |
|
@@ -171,7 +172,7 @@ with Pipeline(name="custom-evaluation") as pipeline:
|
|
171 |
custom_evaluator = TextGeneration(
|
172 |
llm=InferenceEndpointsLLM(
|
173 |
model_id=MODEL,
|
174 |
-
|
175 |
api_key=os.environ["HF_TOKEN"],
|
176 |
structured_output={{"format": "json", "schema": {structured_output}}},
|
177 |
generation_kwargs={{
|
|
|
5 |
UltraFeedback,
|
6 |
)
|
7 |
|
8 |
+
from distilabel_dataset_generator.constants import BASE_URL, MODEL
|
9 |
+
from distilabel_dataset_generator.pipelines.base import _get_next_api_key
|
10 |
+
from distilabel_dataset_generator.utils import extract_column_names
|
|
|
|
|
11 |
|
12 |
|
13 |
def get_ultrafeedback_evaluator(aspect, is_sample):
|
14 |
ultrafeedback_evaluator = UltraFeedback(
|
15 |
llm=InferenceEndpointsLLM(
|
16 |
model_id=MODEL,
|
17 |
+
base_url=BASE_URL,
|
18 |
api_key=_get_next_api_key(),
|
19 |
generation_kwargs={
|
20 |
"temperature": 0,
|
|
|
31 |
custom_evaluator = TextGeneration(
|
32 |
llm=InferenceEndpointsLLM(
|
33 |
model_id=MODEL,
|
34 |
+
base_url=BASE_URL,
|
35 |
api_key=_get_next_api_key(),
|
36 |
structured_output={"format": "json", "schema": structured_output},
|
37 |
generation_kwargs={
|
|
|
60 |
from distilabel.llms import InferenceEndpointsLLM
|
61 |
|
62 |
MODEL = "{MODEL}"
|
63 |
+
BASE_URL = "{BASE_URL}"
|
64 |
+
os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
65 |
|
66 |
hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
|
67 |
data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
|
|
|
75 |
ultrafeedback_evaluator = UltraFeedback(
|
76 |
llm=InferenceEndpointsLLM(
|
77 |
model_id=MODEL,
|
78 |
+
base_url=BASE_URL,
|
79 |
+
api_key=os.environ["API_KEY"],
|
80 |
generation_kwargs={{
|
81 |
"temperature": 0,
|
82 |
"max_new_tokens": 2048,
|
|
|
100 |
from distilabel.llms import InferenceEndpointsLLM
|
101 |
|
102 |
MODEL = "{MODEL}"
|
103 |
+
BASE_URL = "{BASE_URL}"
|
104 |
+
os.environ["BASE_URL"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
105 |
|
106 |
hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
|
107 |
data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
|
|
|
119 |
aspect=aspect,
|
120 |
llm=InferenceEndpointsLLM(
|
121 |
model_id=MODEL,
|
122 |
+
base_url=BASE_URL,
|
123 |
+
api_key=os.environ["BASE_URL"],
|
124 |
generation_kwargs={{
|
125 |
"temperature": 0,
|
126 |
"max_new_tokens": 2048,
|
|
|
157 |
from distilabel.llms import InferenceEndpointsLLM
|
158 |
|
159 |
MODEL = "{MODEL}"
|
160 |
+
BASE_URL = "{BASE_URL}"
|
161 |
CUSTOM_TEMPLATE = "{prompt_template}"
|
162 |
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
163 |
|
|
|
172 |
custom_evaluator = TextGeneration(
|
173 |
llm=InferenceEndpointsLLM(
|
174 |
model_id=MODEL,
|
175 |
+
base_url=BASE_URL,
|
176 |
api_key=os.environ["HF_TOKEN"],
|
177 |
structured_output={{"format": "json", "schema": {structured_output}}},
|
178 |
generation_kwargs={{
|
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
from distilabel.llms import InferenceEndpointsLLM
|
2 |
from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
|
3 |
|
4 |
-
from
|
|
|
|
|
5 |
MODEL,
|
6 |
-
_get_next_api_key,
|
7 |
)
|
|
|
8 |
|
9 |
INFORMATION_SEEKING_PROMPT = (
|
10 |
"You are an AI assistant designed to provide accurate and concise information on a wide"
|
@@ -144,6 +146,7 @@ def get_prompt_generator(temperature):
|
|
144 |
api_key=_get_next_api_key(),
|
145 |
model_id=MODEL,
|
146 |
tokenizer_id=MODEL,
|
|
|
147 |
generation_kwargs={
|
148 |
"temperature": temperature,
|
149 |
"max_new_tokens": 2048,
|
@@ -165,8 +168,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
|
|
165 |
llm=InferenceEndpointsLLM(
|
166 |
model_id=MODEL,
|
167 |
tokenizer_id=MODEL,
|
|
|
168 |
api_key=_get_next_api_key(),
|
169 |
-
magpie_pre_query_template=
|
170 |
generation_kwargs={
|
171 |
"temperature": 0.9,
|
172 |
"do_sample": True,
|
@@ -184,8 +188,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
|
|
184 |
llm=InferenceEndpointsLLM(
|
185 |
model_id=MODEL,
|
186 |
tokenizer_id=MODEL,
|
|
|
187 |
api_key=_get_next_api_key(),
|
188 |
-
magpie_pre_query_template=
|
189 |
generation_kwargs={
|
190 |
"temperature": 0.9,
|
191 |
"do_sample": True,
|
@@ -208,6 +213,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
|
|
208 |
llm=InferenceEndpointsLLM(
|
209 |
model_id=MODEL,
|
210 |
tokenizer_id=MODEL,
|
|
|
211 |
api_key=_get_next_api_key(),
|
212 |
generation_kwargs={
|
213 |
"temperature": 0.8,
|
@@ -223,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
|
|
223 |
llm=InferenceEndpointsLLM(
|
224 |
model_id=MODEL,
|
225 |
tokenizer_id=MODEL,
|
|
|
226 |
api_key=_get_next_api_key(),
|
227 |
generation_kwargs={
|
228 |
"temperature": 0.8,
|
@@ -247,14 +254,16 @@ from distilabel.steps.tasks import MagpieGenerator
|
|
247 |
from distilabel.llms import InferenceEndpointsLLM
|
248 |
|
249 |
MODEL = "{MODEL}"
|
|
|
250 |
SYSTEM_PROMPT = "{system_prompt}"
|
251 |
-
os.environ["
|
252 |
|
253 |
with Pipeline(name="sft") as pipeline:
|
254 |
magpie = MagpieGenerator(
|
255 |
llm=InferenceEndpointsLLM(
|
256 |
model_id=MODEL,
|
257 |
tokenizer_id=MODEL,
|
|
|
258 |
magpie_pre_query_template="llama3",
|
259 |
generation_kwargs={{
|
260 |
"temperature": 0.9,
|
@@ -262,7 +271,7 @@ with Pipeline(name="sft") as pipeline:
|
|
262 |
"max_new_tokens": 2048,
|
263 |
"stop_sequences": {_STOP_SEQUENCES}
|
264 |
}},
|
265 |
-
api_key=os.environ["
|
266 |
),
|
267 |
n_turns={num_turns},
|
268 |
num_rows={num_rows},
|
|
|
1 |
from distilabel.llms import InferenceEndpointsLLM
|
2 |
from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
|
3 |
|
4 |
+
from distilabel_dataset_generator.constants import (
|
5 |
+
BASE_URL,
|
6 |
+
MAGPIE_PRE_QUERY_TEMPLATE,
|
7 |
MODEL,
|
|
|
8 |
)
|
9 |
+
from distilabel_dataset_generator.pipelines.base import _get_next_api_key
|
10 |
|
11 |
INFORMATION_SEEKING_PROMPT = (
|
12 |
"You are an AI assistant designed to provide accurate and concise information on a wide"
|
|
|
146 |
api_key=_get_next_api_key(),
|
147 |
model_id=MODEL,
|
148 |
tokenizer_id=MODEL,
|
149 |
+
base_url=BASE_URL,
|
150 |
generation_kwargs={
|
151 |
"temperature": temperature,
|
152 |
"max_new_tokens": 2048,
|
|
|
168 |
llm=InferenceEndpointsLLM(
|
169 |
model_id=MODEL,
|
170 |
tokenizer_id=MODEL,
|
171 |
+
base_url=BASE_URL,
|
172 |
api_key=_get_next_api_key(),
|
173 |
+
magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
|
174 |
generation_kwargs={
|
175 |
"temperature": 0.9,
|
176 |
"do_sample": True,
|
|
|
188 |
llm=InferenceEndpointsLLM(
|
189 |
model_id=MODEL,
|
190 |
tokenizer_id=MODEL,
|
191 |
+
base_url=BASE_URL,
|
192 |
api_key=_get_next_api_key(),
|
193 |
+
magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
|
194 |
generation_kwargs={
|
195 |
"temperature": 0.9,
|
196 |
"do_sample": True,
|
|
|
213 |
llm=InferenceEndpointsLLM(
|
214 |
model_id=MODEL,
|
215 |
tokenizer_id=MODEL,
|
216 |
+
base_url=BASE_URL,
|
217 |
api_key=_get_next_api_key(),
|
218 |
generation_kwargs={
|
219 |
"temperature": 0.8,
|
|
|
229 |
llm=InferenceEndpointsLLM(
|
230 |
model_id=MODEL,
|
231 |
tokenizer_id=MODEL,
|
232 |
+
base_url=BASE_URL,
|
233 |
api_key=_get_next_api_key(),
|
234 |
generation_kwargs={
|
235 |
"temperature": 0.8,
|
|
|
254 |
from distilabel.llms import InferenceEndpointsLLM
|
255 |
|
256 |
MODEL = "{MODEL}"
|
257 |
+
BASE_URL = "{BASE_URL}"
|
258 |
SYSTEM_PROMPT = "{system_prompt}"
|
259 |
+
os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
260 |
|
261 |
with Pipeline(name="sft") as pipeline:
|
262 |
magpie = MagpieGenerator(
|
263 |
llm=InferenceEndpointsLLM(
|
264 |
model_id=MODEL,
|
265 |
tokenizer_id=MODEL,
|
266 |
+
base_url=BASE_URL,
|
267 |
magpie_pre_query_template="llama3",
|
268 |
generation_kwargs={{
|
269 |
"temperature": 0.9,
|
|
|
271 |
"max_new_tokens": 2048,
|
272 |
"stop_sequences": {_STOP_SEQUENCES}
|
273 |
}},
|
274 |
+
api_key=os.environ["BASE_URL"],
|
275 |
),
|
276 |
n_turns={num_turns},
|
277 |
num_rows={num_rows},
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import random
|
2 |
-
from pydantic import BaseModel, Field
|
3 |
from typing import List
|
4 |
|
5 |
from distilabel.llms import InferenceEndpointsLLM
|
@@ -8,12 +7,11 @@ from distilabel.steps.tasks import (
|
|
8 |
TextClassification,
|
9 |
TextGeneration,
|
10 |
)
|
|
|
11 |
|
12 |
-
from
|
13 |
-
|
14 |
-
|
15 |
-
)
|
16 |
-
from src.distilabel_dataset_generator.utils import get_preprocess_labels
|
17 |
|
18 |
PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
|
19 |
|
@@ -73,7 +71,7 @@ def get_prompt_generator(temperature):
|
|
73 |
llm=InferenceEndpointsLLM(
|
74 |
api_key=_get_next_api_key(),
|
75 |
model_id=MODEL,
|
76 |
-
|
77 |
structured_output={"format": "json", "schema": TextClassificationTask},
|
78 |
generation_kwargs={
|
79 |
"temperature": temperature,
|
@@ -92,7 +90,7 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
92 |
textcat_generator = GenerateTextClassificationData(
|
93 |
llm=InferenceEndpointsLLM(
|
94 |
model_id=MODEL,
|
95 |
-
|
96 |
api_key=_get_next_api_key(),
|
97 |
generation_kwargs={
|
98 |
"temperature": 0.9,
|
@@ -114,7 +112,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
|
|
114 |
labeller_generator = TextClassification(
|
115 |
llm=InferenceEndpointsLLM(
|
116 |
model_id=MODEL,
|
117 |
-
|
118 |
api_key=_get_next_api_key(),
|
119 |
generation_kwargs={
|
120 |
"temperature": 0.7,
|
@@ -149,8 +147,9 @@ from distilabel.steps import LoadDataFromDicts, KeepColumns
|
|
149 |
from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
|
150 |
|
151 |
MODEL = "{MODEL}"
|
|
|
152 |
TEXT_CLASSIFICATION_TASK = "{system_prompt}"
|
153 |
-
os.environ["
|
154 |
"hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
155 |
)
|
156 |
|
@@ -161,8 +160,8 @@ with Pipeline(name="textcat") as pipeline:
|
|
161 |
textcat_generation = GenerateTextClassificationData(
|
162 |
llm=InferenceEndpointsLLM(
|
163 |
model_id=MODEL,
|
164 |
-
|
165 |
-
api_key=os.environ["
|
166 |
generation_kwargs={{
|
167 |
"temperature": 0.8,
|
168 |
"max_new_tokens": 2048,
|
@@ -205,8 +204,8 @@ with Pipeline(name="textcat") as pipeline:
|
|
205 |
textcat_labeller = TextClassification(
|
206 |
llm=InferenceEndpointsLLM(
|
207 |
model_id=MODEL,
|
208 |
-
|
209 |
-
api_key=os.environ["
|
210 |
generation_kwargs={{
|
211 |
"temperature": 0.8,
|
212 |
"max_new_tokens": 2048,
|
|
|
1 |
import random
|
|
|
2 |
from typing import List
|
3 |
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
|
|
7 |
TextClassification,
|
8 |
TextGeneration,
|
9 |
)
|
10 |
+
from pydantic import BaseModel, Field
|
11 |
|
12 |
+
from distilabel_dataset_generator.constants import BASE_URL, MODEL
|
13 |
+
from distilabel_dataset_generator.pipelines.base import _get_next_api_key
|
14 |
+
from distilabel_dataset_generator.utils import get_preprocess_labels
|
|
|
|
|
15 |
|
16 |
PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
|
17 |
|
|
|
71 |
llm=InferenceEndpointsLLM(
|
72 |
api_key=_get_next_api_key(),
|
73 |
model_id=MODEL,
|
74 |
+
base_url=BASE_URL,
|
75 |
structured_output={"format": "json", "schema": TextClassificationTask},
|
76 |
generation_kwargs={
|
77 |
"temperature": temperature,
|
|
|
90 |
textcat_generator = GenerateTextClassificationData(
|
91 |
llm=InferenceEndpointsLLM(
|
92 |
model_id=MODEL,
|
93 |
+
base_url=BASE_URL,
|
94 |
api_key=_get_next_api_key(),
|
95 |
generation_kwargs={
|
96 |
"temperature": 0.9,
|
|
|
112 |
labeller_generator = TextClassification(
|
113 |
llm=InferenceEndpointsLLM(
|
114 |
model_id=MODEL,
|
115 |
+
base_url=BASE_URL,
|
116 |
api_key=_get_next_api_key(),
|
117 |
generation_kwargs={
|
118 |
"temperature": 0.7,
|
|
|
147 |
from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
|
148 |
|
149 |
MODEL = "{MODEL}"
|
150 |
+
BASE_URL = "{BASE_URL}"
|
151 |
TEXT_CLASSIFICATION_TASK = "{system_prompt}"
|
152 |
+
os.environ["API_KEY"] = (
|
153 |
"hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
154 |
)
|
155 |
|
|
|
160 |
textcat_generation = GenerateTextClassificationData(
|
161 |
llm=InferenceEndpointsLLM(
|
162 |
model_id=MODEL,
|
163 |
+
base_url=BASE_URL,
|
164 |
+
api_key=os.environ["API_KEY"],
|
165 |
generation_kwargs={{
|
166 |
"temperature": 0.8,
|
167 |
"max_new_tokens": 2048,
|
|
|
204 |
textcat_labeller = TextClassification(
|
205 |
llm=InferenceEndpointsLLM(
|
206 |
model_id=MODEL,
|
207 |
+
base_url=BASE_URL,
|
208 |
+
api_key=os.environ["API_KEY"],
|
209 |
generation_kwargs={{
|
210 |
"temperature": 0.8,
|
211 |
"max_new_tokens": 2048,
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -6,40 +6,13 @@ import gradio as gr
|
|
6 |
import numpy as np
|
7 |
import pandas as pd
|
8 |
from gradio.oauth import (
|
9 |
-
|
10 |
-
OAUTH_CLIENT_SECRET,
|
11 |
-
OAUTH_SCOPES,
|
12 |
-
OPENID_PROVIDER_URL,
|
13 |
get_space,
|
14 |
)
|
15 |
from huggingface_hub import whoami
|
16 |
from jinja2 import Environment, meta
|
17 |
|
18 |
-
from
|
19 |
-
|
20 |
-
_LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
|
21 |
-
|
22 |
-
|
23 |
-
_CHECK_IF_SPACE_IS_SET = (
|
24 |
-
all(
|
25 |
-
[
|
26 |
-
OAUTH_CLIENT_ID,
|
27 |
-
OAUTH_CLIENT_SECRET,
|
28 |
-
OAUTH_SCOPES,
|
29 |
-
OPENID_PROVIDER_URL,
|
30 |
-
]
|
31 |
-
)
|
32 |
-
or get_space() is None
|
33 |
-
)
|
34 |
-
|
35 |
-
if _CHECK_IF_SPACE_IS_SET:
|
36 |
-
from gradio.oauth import OAuthToken
|
37 |
-
else:
|
38 |
-
OAuthToken = str
|
39 |
-
|
40 |
-
|
41 |
-
def get_login_button():
|
42 |
-
return gr.LoginButton(value="Sign in!", size="sm", scale=2).activate()
|
43 |
|
44 |
|
45 |
def get_duplicate_button():
|
@@ -85,13 +58,6 @@ def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
|
|
85 |
)
|
86 |
|
87 |
|
88 |
-
def get_token(oauth_token: Union[OAuthToken, None]):
|
89 |
-
if oauth_token:
|
90 |
-
return oauth_token.token
|
91 |
-
else:
|
92 |
-
return ""
|
93 |
-
|
94 |
-
|
95 |
def swap_visibility(oauth_token: Union[OAuthToken, None]):
|
96 |
if oauth_token:
|
97 |
return gr.update(elem_classes=["main_ui_logged_in"])
|
@@ -99,28 +65,6 @@ def swap_visibility(oauth_token: Union[OAuthToken, None]):
|
|
99 |
return gr.update(elem_classes=["main_ui_logged_out"])
|
100 |
|
101 |
|
102 |
-
def get_base_app():
|
103 |
-
with gr.Blocks(
|
104 |
-
title="🧬 Synthetic Data Generator",
|
105 |
-
head="🧬 Synthetic Data Generator",
|
106 |
-
css=_LOGGED_OUT_CSS,
|
107 |
-
) as app:
|
108 |
-
with gr.Row():
|
109 |
-
gr.Markdown(
|
110 |
-
"Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
|
111 |
-
)
|
112 |
-
with gr.Row():
|
113 |
-
gr.Column()
|
114 |
-
get_login_button()
|
115 |
-
gr.Column()
|
116 |
-
|
117 |
-
gr.Markdown("## Iterate on a sample dataset")
|
118 |
-
with gr.Column() as main_ui:
|
119 |
-
pass
|
120 |
-
|
121 |
-
return app
|
122 |
-
|
123 |
-
|
124 |
def get_argilla_client() -> Union[rg.Argilla, None]:
|
125 |
return argilla_client
|
126 |
|
|
|
6 |
import numpy as np
|
7 |
import pandas as pd
|
8 |
from gradio.oauth import (
|
9 |
+
OAuthToken,
|
|
|
|
|
|
|
10 |
get_space,
|
11 |
)
|
12 |
from huggingface_hub import whoami
|
13 |
from jinja2 import Environment, meta
|
14 |
|
15 |
+
from distilabel_dataset_generator.constants import argilla_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
def get_duplicate_button():
|
|
|
58 |
)
|
59 |
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def swap_visibility(oauth_token: Union[OAuthToken, None]):
|
62 |
if oauth_token:
|
63 |
return gr.update(elem_classes=["main_ui_logged_in"])
|
|
|
65 |
return gr.update(elem_classes=["main_ui_logged_out"])
|
66 |
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def get_argilla_client() -> Union[rg.Argilla, None]:
|
69 |
return argilla_client
|
70 |
|