Commit e044b6a
1 Parent(s): 4983843
add label randomification perf created sample
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import random
 import uuid
 from typing import List, Union
 
@@ -11,6 +12,7 @@ from huggingface_hub import HfApi
 
 from src.synthetic_dataset_generator.apps.base import (
     hide_success_message,
+    push_pipeline_code_to_hub,
     show_success_message,
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
@@ -119,9 +121,17 @@ def generate_dataset(
         )
         remaining_rows = num_rows - n_processed
         batch_size = min(batch_size, remaining_rows)
-        inputs = [
-
-
+        inputs = []
+        for _ in range(batch_size):
+            if num_labels == 1:
+                num_labels = 1
+            else:
+                num_labels = int(random.gammavariate(2, 2) * num_labels)
+            sampled_labels = random.sample(labels, num_labels)
+            random.shuffle(sampled_labels)
+            inputs.append(
+                {"task": f"{system_prompt}. Labels: {', '.join(sampled_labels)}"}
+            )
         batch = list(textcat_generator.process(inputs=inputs))
         textcat_results.extend(batch[0])
         n_processed += batch_size
@@ -160,6 +170,18 @@ def generate_dataset(
         dataframe["label"] = dataframe["label"].apply(
             lambda x: x.lower().strip() if x.lower().strip() in labels else None
         )
+    else:
+        dataframe["labels"] = dataframe["labels"].apply(
+            lambda x: list(
+                set(
+                    [
+                        label.lower().strip()
+                        for label in x
+                        if label.lower().strip() in labels
+                    ]
+                )
+            )
+        )
     progress(1.0, desc="Dataset generation completed")
     return dataframe
 
@@ -172,6 +194,7 @@ def push_dataset_to_hub(
     labels: List[str] = None,
     oauth_token: Union[gr.OAuthToken, None] = None,
     private: bool = False,
+    pipeline_code: str = "",
 ):
     repo_id = validate_push_to_hub(org_name, repo_name)
     labels = get_preprocess_labels(labels)
@@ -195,6 +218,7 @@ def push_dataset_to_hub(
         token=oauth_token.token,
         create_pr=False,
     )
+    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
 
 
 def push_dataset(
@@ -208,6 +232,7 @@ def push_dataset(
     labels: List[str] = None,
     private: bool = False,
     temperature: float = 0.8,
+    pipeline_code: str = "",
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -221,7 +246,14 @@ def push_dataset(
         temperature=temperature,
     )
     push_dataset_to_hub(
-        dataframe,
+        dataframe,
+        org_name,
+        repo_name,
+        num_labels,
+        labels,
+        oauth_token,
+        private,
+        pipeline_code,
     )
 
     dataframe = dataframe[
@@ -407,7 +439,7 @@ with gr.Blocks() as app:
                 ("Ambiguous", "ambiguous"),
                 ("Mixed", "mixed"),
             ],
-            value="
+            value="understandable with some effort",
             label="Clarity",
             info="Set how easily the correct label or labels can be identified.",
             interactive=True,
@@ -419,7 +451,7 @@ with gr.Blocks() as app:
                 ("PhD", "PhD"),
                 ("Mixed", "mixed"),
             ],
-            value="
+            value="high school",
            label="Difficulty",
            info="Select the comprehension level for the text. Ensure it matches the task context.",
            interactive=True,
@@ -544,6 +576,7 @@ with gr.Blocks() as app:
                 labels,
                 private,
                 temperature,
+                pipeline_code,
             ],
             outputs=[success_message],
             show_progress=True,
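The main change in this file is the per-example label randomization. Below is a minimal standalone sketch of that sampling logic, not part of the commit: the label names and the clamp of the sample size to [1, len(labels)] are added assumptions (random.sample raises ValueError if asked for more items than the population holds).

import random

labels = ["politics", "sports", "technology", "entertainment"]  # illustrative labels
num_labels = 2  # requested average number of labels per example

inputs = []
for _ in range(4):
    if num_labels == 1:
        k = 1
    else:
        # draw a per-example label count from a Gamma(2, 2) distribution scaled by num_labels
        k = int(random.gammavariate(2, 2) * num_labels)
    k = max(1, min(k, len(labels)))  # assumed clamp, not present in the committed code
    sampled_labels = random.sample(labels, k)
    random.shuffle(sampled_labels)
    inputs.append({"task": f"Classify the text. Labels: {', '.join(sampled_labels)}"})

print(inputs)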
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -15,35 +15,29 @@ from synthetic_dataset_generator.utils import get_preprocess_labels
 
 PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
 
-Your 
+Your should write a prompt following a the dataset description. Respond with the prompt and nothing else.
 
-The prompt 
+The prompt should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
 
-
-
-{"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
-
-{"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
-
-{"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
+Make sure to always include all of the detailed information from the description and the context of the company that is provided.
 
-
+Don't include the labels in the classification_task but only provide a high level description of the classification task.
 
-
-
-{"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
-
-{"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
+If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
 
-
+Description: DavidMovieHouse is a cinema that has been in business for 10 years.
+Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews. Classify the customer reviews as", "labels": ["positive", "negative"]}
 
-
+Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
+Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
 
-
+Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
+Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
 
-
+Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
+Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
 
-
+Description:
 """
 
 DEFAULT_DATASET_DESCRIPTIONS = [
@@ -66,6 +60,19 @@ class TextClassificationTask(BaseModel):
     )
 
 
+class DatasetDescription(BaseModel):
+    description: str = Field(
+        ...,
+        title="description",
+        description="The description of the dataset.",
+    )
+    labels: list[str] = Field(
+        ...,
+        title="labels",
+        description="The possible labels for the classification task.",
+    )
+
+
 def get_prompt_generator():
     prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
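For reference, a small usage sketch of the DatasetDescription model added above. This is illustrative only: it assumes the pydantic BaseModel/Field imports already used elsewhere in this file, and the field values are invented for the example.

from pydantic import BaseModel, Field

class DatasetDescription(BaseModel):
    description: str = Field(..., title="description", description="The description of the dataset.")
    labels: list[str] = Field(..., title="labels", description="The possible labels for the classification task.")

# illustrative values, not taken from the commit
example = DatasetDescription(
    description="Customer reviews for a cinema that has been in business for 10 years.",
    labels=["positive", "negative"],
)
print(example)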