Spaces:
Runtime error
Runtime error
Add support for image classification
Browse files
app.py
CHANGED
|
@@ -31,11 +31,12 @@ DATASETS_PREVIEW_API = os.getenv("DATASETS_PREVIEW_API")
|
|
| 31 |
TASK_TO_ID = {
|
| 32 |
"binary_classification": 1,
|
| 33 |
"multi_class_classification": 2,
|
| 34 |
-
# "multi_label_classification": 3, # Not fully supported in AutoTrain
|
| 35 |
"entity_extraction": 4,
|
| 36 |
"extractive_question_answering": 5,
|
| 37 |
"translation": 6,
|
| 38 |
"summarization": 8,
|
|
|
|
|
|
|
| 39 |
}
|
| 40 |
|
| 41 |
TASK_TO_DEFAULT_METRICS = {
|
|
@@ -50,8 +51,22 @@ TASK_TO_DEFAULT_METRICS = {
|
|
| 50 |
"extractive_question_answering": [],
|
| 51 |
"translation": ["sacrebleu"],
|
| 52 |
"summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
}
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
SUPPORTED_TASKS = list(TASK_TO_ID.keys())
|
| 56 |
|
| 57 |
# Extracted from utils.get_supported_metrics
|
|
@@ -355,6 +370,27 @@ with st.expander("Advanced configuration"):
|
|
| 355 |
col_mapping[question_col] = "question"
|
| 356 |
col_mapping[answers_text_col] = "answers.text"
|
| 357 |
col_mapping[answers_start_col] = "answers.answer_start"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
# Select metrics
|
| 360 |
st.markdown("**Select metrics**")
|
|
@@ -408,9 +444,9 @@ with st.form(key="form"):
|
|
| 408 |
"proj_name": f"eval-project-{project_id}",
|
| 409 |
"task": TASK_TO_ID[selected_task],
|
| 410 |
"config": {
|
| 411 |
-
"language":
|
| 412 |
-
if selected_task
|
| 413 |
-
else "
|
| 414 |
"max_models": 5,
|
| 415 |
"instance": {
|
| 416 |
"provider": "aws",
|
|
|
|
| 31 |
TASK_TO_ID = {
|
| 32 |
"binary_classification": 1,
|
| 33 |
"multi_class_classification": 2,
|
|
|
|
| 34 |
"entity_extraction": 4,
|
| 35 |
"extractive_question_answering": 5,
|
| 36 |
"translation": 6,
|
| 37 |
"summarization": 8,
|
| 38 |
+
"image_binary_classification": 17,
|
| 39 |
+
"image_multi_class_classification": 18,
|
| 40 |
}
|
| 41 |
|
| 42 |
TASK_TO_DEFAULT_METRICS = {
|
|
|
|
| 51 |
"extractive_question_answering": [],
|
| 52 |
"translation": ["sacrebleu"],
|
| 53 |
"summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
|
| 54 |
+
"image_binary_classification": ["f1", "precision", "recall", "auc", "accuracy"],
|
| 55 |
+
"image_multi_class_classification": [
|
| 56 |
+
"f1",
|
| 57 |
+
"precision",
|
| 58 |
+
"recall",
|
| 59 |
+
"accuracy",
|
| 60 |
+
],
|
| 61 |
}
|
| 62 |
|
| 63 |
+
AUTOTRAIN_TASK_TO_LANG = {
|
| 64 |
+
"translation": "en2de",
|
| 65 |
+
"image_binary_classification": "unk",
|
| 66 |
+
"image_multi_class_classification": "unk",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
SUPPORTED_TASKS = list(TASK_TO_ID.keys())
|
| 71 |
|
| 72 |
# Extracted from utils.get_supported_metrics
|
|
|
|
| 370 |
col_mapping[question_col] = "question"
|
| 371 |
col_mapping[answers_text_col] = "answers.text"
|
| 372 |
col_mapping[answers_start_col] = "answers.answer_start"
|
| 373 |
+
elif selected_task in ["image_binary_classification", "image_multi_class_classification"]:
|
| 374 |
+
with col1:
|
| 375 |
+
st.markdown("`image` column")
|
| 376 |
+
st.text("")
|
| 377 |
+
st.text("")
|
| 378 |
+
st.text("")
|
| 379 |
+
st.text("")
|
| 380 |
+
st.markdown("`target` column")
|
| 381 |
+
with col2:
|
| 382 |
+
image_col = st.selectbox(
|
| 383 |
+
"This column should contain the images to be classified",
|
| 384 |
+
col_names,
|
| 385 |
+
index=col_names.index(get_key(metadata[0]["col_mapping"], "image")) if metadata is not None else 0,
|
| 386 |
+
)
|
| 387 |
+
target_col = st.selectbox(
|
| 388 |
+
"This column should contain the labels associated with the images",
|
| 389 |
+
col_names,
|
| 390 |
+
index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
|
| 391 |
+
)
|
| 392 |
+
col_mapping[image_col] = "image"
|
| 393 |
+
col_mapping[target_col] = "target"
|
| 394 |
|
| 395 |
# Select metrics
|
| 396 |
st.markdown("**Select metrics**")
|
|
|
|
| 444 |
"proj_name": f"eval-project-{project_id}",
|
| 445 |
"task": TASK_TO_ID[selected_task],
|
| 446 |
"config": {
|
| 447 |
+
"language": AUTOTRAIN_TASK_TO_LANG[selected_task]
|
| 448 |
+
if selected_task in AUTOTRAIN_TASK_TO_LANG
|
| 449 |
+
else "en",
|
| 450 |
"max_models": 5,
|
| 451 |
"instance": {
|
| 452 |
"provider": "aws",
|
utils.py
CHANGED
|
@@ -11,14 +11,15 @@ from tqdm import tqdm
|
|
| 11 |
AUTOTRAIN_TASK_TO_HUB_TASK = {
|
| 12 |
"binary_classification": "text-classification",
|
| 13 |
"multi_class_classification": "text-classification",
|
| 14 |
-
# "multi_label_classification": "text-classification", # Not fully supported in AutoTrain
|
| 15 |
"entity_extraction": "token-classification",
|
| 16 |
"extractive_question_answering": "question-answering",
|
| 17 |
"translation": "translation",
|
| 18 |
"summarization": "summarization",
|
| 19 |
-
|
|
|
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
|
| 23 |
LOGS_REPO = "evaluation-job-logs"
|
| 24 |
|
|
|
|
| 11 |
AUTOTRAIN_TASK_TO_HUB_TASK = {
|
| 12 |
"binary_classification": "text-classification",
|
| 13 |
"multi_class_classification": "text-classification",
|
|
|
|
| 14 |
"entity_extraction": "token-classification",
|
| 15 |
"extractive_question_answering": "question-answering",
|
| 16 |
"translation": "translation",
|
| 17 |
"summarization": "summarization",
|
| 18 |
+
"image_binary_classification": "image-classification",
|
| 19 |
+
"image_multi_class_classification": "image-classification",
|
| 20 |
}
|
| 21 |
|
| 22 |
+
|
| 23 |
HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
|
| 24 |
LOGS_REPO = "evaluation-job-logs"
|
| 25 |
|