huzey commited on
Commit
48bc263
·
1 Parent(s): 112a402

update dataset UI

Browse files
app.py CHANGED
@@ -63,6 +63,7 @@ DATASETS = {
63
  'Pose': [
64
  ('sayakpaul/poses-controlnet-dataset', None),
65
  ('razdab/sign_pose_M', None),
 
66
  ('Fiacre/small-animal-poses-controlnet-dataset', None),
67
  ('junjuice0/vtuber-tachi-e', None),
68
  ],
@@ -77,11 +78,11 @@ DATASETS = {
77
  ('efoley/sar_tile_512', None),
78
  ],
79
  'Medical': [
80
- ('Mahadih534/Chest_CT-Scan_images-Dataset', 4),
81
- ('Falah/Alzheimer_MRI', 4),
82
- ('sartajbhuvaji/Brain-Tumor-Classification', 4),
83
  ('TrainingDataPro/chest-x-rays', None),
84
  ('hongrui/mimic_chest_xray_v_1', None),
 
 
85
  ('Leonardo6/path-vqa', None),
86
  ('Itsunori/path-vqa_jap', None),
87
  ('ruby-jrl/isic-2024-2', None),
@@ -95,7 +96,6 @@ DATASETS = {
95
  ('jlbaker361/dcgan-eval-creative_gan_256_256', None),
96
  ('Francesco/csgo-videogame', None),
97
  ('Francesco/apex-videogame', None),
98
- ('Marqo/deepfashion-multimodal', None),
99
  ('huggan/pokemon', None),
100
  ('huggan/few-shot-universe', None),
101
  ('huggan/flowers-102-categories', None),
@@ -1282,9 +1282,97 @@ def make_input_video_section():
1282
  return input_gallery, submit_button, clear_images_button, max_frames_number
1283
 
1284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1285
  def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False):
1286
  gr.Markdown('### Input Images')
1287
- input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False)
 
1288
 
1289
  submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
1290
  with gr.Row():
@@ -1311,11 +1399,30 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1311
  create_file_button, download_button = add_download_button(input_gallery, "input_images")
1312
 
1313
  gr.Markdown('### Load Datasets')
1314
- load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
1315
- advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
1316
  with gr.Column() as basic_block:
1317
- example_gallery = gr.Gallery(value=example_items, label="Example Images", show_label=True, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False, elem_id="example-gallery")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1318
  with gr.Column() as advanced_block:
 
1319
  # dataset_names = DATASET_NAMES
1320
  # dataset_classes = DATASET_CLASSES
1321
  dataset_categories = list(DATASETS.keys())
@@ -1344,7 +1451,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1344
  random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
1345
 
1346
  # add functionality, save and load images to profile
1347
- with gr.Accordion("Saved Image Profiles", open=True) as profile_accordion:
1348
  with gr.Row():
1349
  profile_text = gr.Textbox(label="Profile name", placeholder="Type here: Profile name to save/load/delete", elem_id="profile-name", scale=6, show_label=False)
1350
  list_profiles_button = gr.Button("📋 List", elem_id="list-profile-button", variant='secondary', scale=3)
@@ -1445,87 +1552,7 @@ def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_
1445
  return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
1446
  is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)
1447
 
1448
-
1449
- def load_dataset_images(is_advanced, dataset_name, num_images=10,
1450
- is_filter=True, filter_by_class_text="0,1,2",
1451
- is_random=False, seed=1):
1452
- progress = gr.Progress()
1453
- progress(0, desc="Loading Images")
1454
- if is_advanced == "Basic":
1455
- gr.Info("Loaded images from Ego-Exo4D")
1456
- return default_images
1457
- try:
1458
- progress(0.5, desc="Downloading Dataset")
1459
- if 'EgoThink' in dataset_name:
1460
- dataset = load_dataset(dataset_name, 'Activity', trust_remote_code=True)
1461
- else:
1462
- dataset = load_dataset(dataset_name, trust_remote_code=True)
1463
- key = list(dataset.keys())[0]
1464
- dataset = dataset[key]
1465
- except Exception as e:
1466
- raise gr.Error(f"Error loading dataset {dataset_name}: {e}")
1467
- if num_images > len(dataset):
1468
- num_images = len(dataset)
1469
-
1470
- if is_filter:
1471
- progress(0.8, desc="Filtering Images")
1472
- classes = [int(i) for i in filter_by_class_text.split(",")]
1473
- labels = np.array(dataset['label'])
1474
- unique_labels = np.unique(labels)
1475
- valid_classes = [i for i in classes if i in unique_labels]
1476
- invalid_classes = [i for i in classes if i not in unique_labels]
1477
- if len(invalid_classes) > 0:
1478
- gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
1479
- if len(valid_classes) == 0:
1480
- gr.Error(f"Classes {classes} not found in the dataset.")
1481
- return None
1482
- # shuffle each class
1483
- chunk_size = num_images // len(valid_classes)
1484
- image_idx = []
1485
- for i in valid_classes:
1486
- idx = np.where(labels == i)[0]
1487
- if is_random:
1488
- idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
1489
- else:
1490
- idx = idx[:chunk_size]
1491
- image_idx.extend(idx.tolist())
1492
- if not is_filter:
1493
- if is_random:
1494
- image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
1495
- else:
1496
- image_idx = list(range(num_images))
1497
- key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
1498
- images = [dataset[i][key] for i in image_idx]
1499
- gr.Info(f"Loaded {len(images)} images from {dataset_name}")
1500
- del dataset
1501
-
1502
- if dataset_name in CENTER_CROP_DATASETS:
1503
- def center_crop_image(img):
1504
- # image: PIL image
1505
- w, h = img.size
1506
- min_hw = min(h, w)
1507
- # center crop
1508
- left = (w - min_hw) // 2
1509
- top = (h - min_hw) // 2
1510
- right = left + min_hw
1511
- bottom = top + min_hw
1512
- img = img.crop((left, top, right, bottom))
1513
- return img
1514
- images = [center_crop_image(image) for image in images]
1515
-
1516
- return images
1517
-
1518
- def load_and_append(existing_images, *args, **kwargs):
1519
- new_images = load_dataset_images(*args, **kwargs)
1520
- if new_images is None:
1521
- return existing_images
1522
- if len(new_images) == 0:
1523
- return existing_images
1524
- if existing_images is None:
1525
- existing_images = []
1526
- existing_images += new_images
1527
- gr.Info(f"Total images: {len(existing_images)}")
1528
- return existing_images
1529
 
1530
  load_images_button.click(load_and_append,
1531
  inputs=[input_gallery, advanced_radio, dataset_dropdown, num_images_slider,
@@ -1864,7 +1891,7 @@ with demo:
1864
  with gr.Row():
1865
  with gr.Column(scale=5, min_width=200):
1866
  input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
1867
- num_images_slider.value = 30
1868
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=20)
1869
 
1870
  with gr.Column(scale=5, min_width=200):
@@ -1881,7 +1908,7 @@ with demo:
1881
  perplexity_slider, n_neighbors_slider, min_dist_slider,
1882
  sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
1883
  ] = make_parameters_section()
1884
- num_eig_slider.value = 30
1885
 
1886
  false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
1887
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
 
63
  'Pose': [
64
  ('sayakpaul/poses-controlnet-dataset', None),
65
  ('razdab/sign_pose_M', None),
66
+ ('Marqo/deepfashion-multimodal', None),
67
  ('Fiacre/small-animal-poses-controlnet-dataset', None),
68
  ('junjuice0/vtuber-tachi-e', None),
69
  ],
 
78
  ('efoley/sar_tile_512', None),
79
  ],
80
  'Medical': [
81
+ ('Mahadih534/Chest_CT-Scan_images-Dataset', None),
 
 
82
  ('TrainingDataPro/chest-x-rays', None),
83
  ('hongrui/mimic_chest_xray_v_1', None),
84
+ ('sartajbhuvaji/Brain-Tumor-Classification', 4),
85
+ ('Falah/Alzheimer_MRI', 4),
86
  ('Leonardo6/path-vqa', None),
87
  ('Itsunori/path-vqa_jap', None),
88
  ('ruby-jrl/isic-2024-2', None),
 
96
  ('jlbaker361/dcgan-eval-creative_gan_256_256', None),
97
  ('Francesco/csgo-videogame', None),
98
  ('Francesco/apex-videogame', None),
 
99
  ('huggan/pokemon', None),
100
  ('huggan/few-shot-universe', None),
101
  ('huggan/flowers-102-categories', None),
 
1282
  return input_gallery, submit_button, clear_images_button, max_frames_number
1283
 
1284
 
1285
+ def load_dataset_images(is_advanced, dataset_name, num_images=10,
1286
+ is_filter=False, filter_by_class_text="0,1,2",
1287
+ is_random=False, seed=1):
1288
+ progress = gr.Progress()
1289
+ progress(0, desc="Loading Images")
1290
+
1291
+ if dataset_name == "EgoExo":
1292
+ is_advanced = "Basic"
1293
+
1294
+ if is_advanced == "Basic":
1295
+ gr.Info(f"Loaded images from EgoExo")
1296
+ return default_images
1297
+ try:
1298
+ progress(0.5, desc="Downloading Dataset")
1299
+ if 'EgoThink' in dataset_name:
1300
+ dataset = load_dataset(dataset_name, 'Activity', trust_remote_code=True)
1301
+ else:
1302
+ dataset = load_dataset(dataset_name, trust_remote_code=True)
1303
+ key = list(dataset.keys())[0]
1304
+ dataset = dataset[key]
1305
+ except Exception as e:
1306
+ raise gr.Error(f"Error loading dataset {dataset_name}: {e}")
1307
+ if num_images > len(dataset):
1308
+ num_images = len(dataset)
1309
+
1310
+ if len(filter_by_class_text) == 0:
1311
+ is_filter = False
1312
+
1313
+ if is_filter:
1314
+ progress(0.8, desc="Filtering Images")
1315
+ classes = [int(i) for i in filter_by_class_text.split(",")]
1316
+ labels = np.array(dataset['label'])
1317
+ unique_labels = np.unique(labels)
1318
+ valid_classes = [i for i in classes if i in unique_labels]
1319
+ invalid_classes = [i for i in classes if i not in unique_labels]
1320
+ if len(invalid_classes) > 0:
1321
+ gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
1322
+ if len(valid_classes) == 0:
1323
+ raise gr.Error(f"Classes {classes} not found in the dataset.")
1324
+ # shuffle each class
1325
+ chunk_size = num_images // len(valid_classes)
1326
+ image_idx = []
1327
+ for i in valid_classes:
1328
+ idx = np.where(labels == i)[0]
1329
+ if is_random:
1330
+ idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
1331
+ else:
1332
+ idx = idx[:chunk_size]
1333
+ image_idx.extend(idx.tolist())
1334
+ if not is_filter:
1335
+ if is_random:
1336
+ image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
1337
+ else:
1338
+ image_idx = list(range(num_images))
1339
+ key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
1340
+ images = [dataset[i][key] for i in image_idx]
1341
+ gr.Info(f"Loaded {len(images)} images from {dataset_name}")
1342
+ del dataset
1343
+
1344
+ if dataset_name in CENTER_CROP_DATASETS:
1345
+ def center_crop_image(img):
1346
+ # image: PIL image
1347
+ w, h = img.size
1348
+ min_hw = min(h, w)
1349
+ # center crop
1350
+ left = (w - min_hw) // 2
1351
+ top = (h - min_hw) // 2
1352
+ right = left + min_hw
1353
+ bottom = top + min_hw
1354
+ img = img.crop((left, top, right, bottom))
1355
+ return img
1356
+ images = [center_crop_image(image) for image in images]
1357
+
1358
+ return images
1359
+
1360
+ def load_and_append(existing_images, *args, **kwargs):
1361
+ new_images = load_dataset_images(*args, **kwargs)
1362
+ if new_images is None:
1363
+ return existing_images
1364
+ if len(new_images) == 0:
1365
+ return existing_images
1366
+ if existing_images is None:
1367
+ existing_images = []
1368
+ existing_images += new_images
1369
+ gr.Info(f"Total images: {len(existing_images)}")
1370
+ return existing_images
1371
+
1372
  def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False):
1373
  gr.Markdown('### Input Images')
1374
+ input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
1375
+ format="webp")
1376
 
1377
  submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
1378
  with gr.Row():
 
1399
  create_file_button, download_button = add_download_button(input_gallery, "input_images")
1400
 
1401
  gr.Markdown('### Load Datasets')
1402
+ advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets Menu", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)
 
1403
  with gr.Column() as basic_block:
1404
+ # gr.Markdown('### Example Image Sets')
1405
+ def make_example(name, images, dataset_name):
1406
+ with gr.Row():
1407
+ button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
1408
+ gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140)
1409
+ button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, 100, is_random=True, seed=42)), outputs=[input_gallery])
1410
+ return gallery, button
1411
+ example_items = [
1412
+ ("EgoExo", ['./images/egoexo1.jpg', './images/egoexo3.jpg', './images/egoexo2.jpg'], "EgoExo"),
1413
+ ("Ego", ['./images/egothink1.jpg', './images/egothink2.jpg', './images/egothink3.jpg'], "EgoThink/EgoThink"),
1414
+ ("Face", ['./images/face1.jpg', './images/face2.jpg', './images/face3.jpg'], "nielsr/CelebA-faces"),
1415
+ ("Pose", ['./images/pose1.jpg', './images/pose2.jpg', './images/pose3.jpg'], "sayakpaul/poses-controlnet-dataset"),
1416
+ # ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
1417
+ # ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
1418
+ # ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
1419
+ ("BrainMRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
1420
+ ("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
1421
+ ]
1422
+ for name, images, dataset_name in example_items:
1423
+ make_example(name, images, dataset_name)
1424
  with gr.Column() as advanced_block:
1425
+ load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
1426
  # dataset_names = DATASET_NAMES
1427
  # dataset_classes = DATASET_CLASSES
1428
  dataset_categories = list(DATASETS.keys())
 
1451
  random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)
1452
 
1453
  # add functionality, save and load images to profile
1454
+ with gr.Accordion("Saved Image Profiles", open=False) as profile_accordion:
1455
  with gr.Row():
1456
  profile_text = gr.Textbox(label="Profile name", placeholder="Type here: Profile name to save/load/delete", elem_id="profile-name", scale=6, show_label=False)
1457
  list_profiles_button = gr.Button("📋 List", elem_id="list-profile-button", variant='secondary', scale=3)
 
1552
  return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
1553
  is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)
1554
 
1555
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1556
 
1557
  load_images_button.click(load_and_append,
1558
  inputs=[input_gallery, advanced_radio, dataset_dropdown, num_images_slider,
 
1891
  with gr.Row():
1892
  with gr.Column(scale=5, min_width=200):
1893
  input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
1894
+ num_images_slider.value = 100
1895
  logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=20)
1896
 
1897
  with gr.Column(scale=5, min_width=200):
 
1908
  perplexity_slider, n_neighbors_slider, min_dist_slider,
1909
  sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
1910
  ] = make_parameters_section()
1911
+ num_eig_slider.value = 100
1912
 
1913
  false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
1914
  no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
images/bird1.jpg ADDED
images/bird2.jpg ADDED
images/bird3.jpg ADDED
images/brain1.jpg ADDED
images/brain2.jpg ADDED
images/brain3.jpg ADDED
images/catdog1.jpg ADDED
images/catdog2.jpg ADDED
images/catdog3.jpg ADDED
images/chestxray1.jpg ADDED
images/chestxray2.jpg ADDED
images/chestxray3.jpg ADDED
images/egoexo1.jpg ADDED
images/egoexo2.jpg ADDED
images/egoexo3.jpg ADDED
images/egothink1.jpg ADDED
images/egothink2.jpg ADDED
images/egothink3.jpg ADDED
images/face1.jpg ADDED
images/face2.jpg ADDED
images/face3.jpg ADDED
images/image(1).jpg ADDED
images/image(2).jpg ADDED
images/image(3).jpg ADDED
images/imagenet1.jpg ADDED
images/imagenet2.jpg ADDED
images/imagenet3.jpg ADDED
images/kanji1.jpg ADDED
images/kanji2.jpg ADDED
images/kanji3.jpg ADDED
images/pose1.jpg ADDED
images/pose2.jpg ADDED
images/pose3.jpg ADDED