Spaces:
Runtime error
Runtime error
Commit
·
9b4773a
1
Parent(s):
5829740
feat: update batch size
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
|
@@ -74,6 +74,8 @@ def generate_dataset(
|
|
| 74 |
if repo_name is not None and org_name is not None
|
| 75 |
else None
|
| 76 |
)
|
|
|
|
|
|
|
| 77 |
if repo_id is not None:
|
| 78 |
if not all([repo_id, org_name, repo_name]):
|
| 79 |
raise gr.Error(
|
|
@@ -295,7 +297,7 @@ with gr.Blocks(
|
|
| 295 |
],
|
| 296 |
outputs=[table],
|
| 297 |
show_progress=True,
|
| 298 |
-
).
|
| 299 |
fn=show_success_message,
|
| 300 |
inputs=[org_name, repo_name],
|
| 301 |
outputs=[success_message],
|
|
|
|
| 74 |
if repo_name is not None and org_name is not None
|
| 75 |
else None
|
| 76 |
)
|
| 77 |
+
if oauth_token is None or oauth_token == "":
|
| 78 |
+
print(oauth_token, repo_id)
|
| 79 |
if repo_id is not None:
|
| 80 |
if not all([repo_id, org_name, repo_name]):
|
| 81 |
raise gr.Error(
|
|
|
|
| 297 |
],
|
| 298 |
outputs=[table],
|
| 299 |
show_progress=True,
|
| 300 |
+
).success(
|
| 301 |
fn=show_success_message,
|
| 302 |
inputs=[org_name, repo_name],
|
| 303 |
outputs=[success_message],
|
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
|
@@ -138,6 +138,7 @@ _STOP_SEQUENCES = [
|
|
| 138 |
"assistant",
|
| 139 |
" \n\n",
|
| 140 |
]
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
def _get_output_mappings(num_turns):
|
|
@@ -205,33 +206,30 @@ def get_pipeline(num_turns, num_rows, system_prompt):
|
|
| 205 |
"stop_sequences": _STOP_SEQUENCES,
|
| 206 |
},
|
| 207 |
),
|
| 208 |
-
batch_size=
|
| 209 |
n_turns=num_turns,
|
| 210 |
num_rows=num_rows,
|
| 211 |
system_prompt=system_prompt,
|
| 212 |
output_mappings={"instruction": "prompt"},
|
| 213 |
-
only_instruction=True
|
| 214 |
)
|
| 215 |
-
|
| 216 |
generate_response = TextGeneration(
|
| 217 |
llm=InferenceEndpointsLLM(
|
| 218 |
model_id=MODEL,
|
| 219 |
tokenizer_id=MODEL,
|
| 220 |
api_key=os.environ["HF_TOKEN"],
|
| 221 |
-
generation_kwargs={
|
| 222 |
-
"temperature": 0.8,
|
| 223 |
-
"max_new_tokens": 1024
|
| 224 |
-
},
|
| 225 |
),
|
| 226 |
system_prompt=system_prompt,
|
| 227 |
output_mappings={"generation": "completion"},
|
| 228 |
-
input_mappings={"instruction": "prompt"}
|
| 229 |
)
|
| 230 |
-
|
| 231 |
keep_columns = KeepColumns(
|
| 232 |
columns=list(output_mappings.values()) + ["model_name"],
|
| 233 |
)
|
| 234 |
-
|
| 235 |
magpie.connect(generate_response)
|
| 236 |
generate_response.connect(keep_columns)
|
| 237 |
return pipeline
|
|
@@ -250,7 +248,7 @@ def get_pipeline(num_turns, num_rows, system_prompt):
|
|
| 250 |
"stop_sequences": _STOP_SEQUENCES,
|
| 251 |
},
|
| 252 |
),
|
| 253 |
-
batch_size=
|
| 254 |
n_turns=num_turns,
|
| 255 |
num_rows=num_rows,
|
| 256 |
system_prompt=system_prompt,
|
|
|
|
| 138 |
"assistant",
|
| 139 |
" \n\n",
|
| 140 |
]
|
| 141 |
+
DEFAULT_BATCH_SIZE = 1
|
| 142 |
|
| 143 |
|
| 144 |
def _get_output_mappings(num_turns):
|
|
|
|
| 206 |
"stop_sequences": _STOP_SEQUENCES,
|
| 207 |
},
|
| 208 |
),
|
| 209 |
+
batch_size=DEFAULT_BATCH_SIZE,
|
| 210 |
n_turns=num_turns,
|
| 211 |
num_rows=num_rows,
|
| 212 |
system_prompt=system_prompt,
|
| 213 |
output_mappings={"instruction": "prompt"},
|
| 214 |
+
only_instruction=True,
|
| 215 |
)
|
| 216 |
+
|
| 217 |
generate_response = TextGeneration(
|
| 218 |
llm=InferenceEndpointsLLM(
|
| 219 |
model_id=MODEL,
|
| 220 |
tokenizer_id=MODEL,
|
| 221 |
api_key=os.environ["HF_TOKEN"],
|
| 222 |
+
generation_kwargs={"temperature": 0.8, "max_new_tokens": 1024},
|
|
|
|
|
|
|
|
|
|
| 223 |
),
|
| 224 |
system_prompt=system_prompt,
|
| 225 |
output_mappings={"generation": "completion"},
|
| 226 |
+
input_mappings={"instruction": "prompt"},
|
| 227 |
)
|
| 228 |
+
|
| 229 |
keep_columns = KeepColumns(
|
| 230 |
columns=list(output_mappings.values()) + ["model_name"],
|
| 231 |
)
|
| 232 |
+
|
| 233 |
magpie.connect(generate_response)
|
| 234 |
generate_response.connect(keep_columns)
|
| 235 |
return pipeline
|
|
|
|
| 248 |
"stop_sequences": _STOP_SEQUENCES,
|
| 249 |
},
|
| 250 |
),
|
| 251 |
+
batch_size=DEFAULT_BATCH_SIZE,
|
| 252 |
n_turns=num_turns,
|
| 253 |
num_rows=num_rows,
|
| 254 |
system_prompt=system_prompt,
|
src/distilabel_dataset_generator/utils.py
CHANGED
|
@@ -30,7 +30,7 @@ def get_login_button():
|
|
| 30 |
return gr.LoginButton(
|
| 31 |
value="Sign in with Hugging Face!",
|
| 32 |
size="lg",
|
| 33 |
-
)
|
| 34 |
|
| 35 |
|
| 36 |
def get_duplicate_button():
|
|
|
|
| 30 |
return gr.LoginButton(
|
| 31 |
value="Sign in with Hugging Face!",
|
| 32 |
size="lg",
|
| 33 |
+
).activate()
|
| 34 |
|
| 35 |
|
| 36 |
def get_duplicate_button():
|