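# Gradio demo for the 'Generative Augmented Image Classifiers' Space: lets the
# user pick a dataset, architecture, and training regime, classify a random
# validation image, and download the corresponding trained checkpoint.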
import random

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision.models import mobilenet_v2, resnet18
from torchvision.transforms.functional import InterpolationMode
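
# Dataset and model configuration: number of classes per dataset, the training
# regimes available for each dataset, and the supported architectures.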
datasets_n_classes = {
    "Imagenette": 10,
    "Imagewoof": 10,
    "Stanford_dogs": 120,
}
datasets_model_types = {
    "Imagenette": [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
        "augment_clean_200",
    ],
    "Imagewoof": [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
        "augment_clean_200",
    ],
    "Stanford_dogs": [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
    ],
}
model_arch = ["resnet18", "mobilenet_v2"]

list_200 = [
    "Original",
    "Synthetic",
    "Original + Synthetic (Noisy)",
    "Original + Synthetic (Clean)",
]
list_200_100 = ["Base+100", "AugmentNoisy+100"]

methods_map = {
    "200 Epochs": list_200,
    "200 Epochs on Original + 100": list_200_100,
}
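
# Map the human-readable UI labels to the internal identifiers used in
# checkpoint paths and model keys.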
label_map = {
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}
dataset_models = dict()
for dataset, n_classes in datasets_n_classes.items():
    models = dict()
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            if arch == "resnet18":
                model = resnet18(weights=None, num_classes=n_classes)
            elif arch == "mobilenet_v2":
                model = mobilenet_v2(weights=None, num_classes=n_classes)
            else:
                raise ValueError(f"Model architecture unavailable: {arch}")
            models[f"{arch}_{model_type}"] = (
                model,
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models
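
# Sample a random validation image from the selected dataset and return it as
# a 256x256 PIL image for display in the UI.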
def get_random_image(dataset, label_map=label_map) -> Image.Image:
    dataset_root = f"./data/{label_map[dataset]}/val"
    dataset_img = torchvision.datasets.ImageFolder(
        dataset_root,
        transforms.Compose([transforms.PILToTensor()]),
    )
    random_idx = random.randint(0, len(dataset_img) - 1)
    image, _ = dataset_img[random_idx]
    image = transforms.ToPILImage()(image)
    image = image.resize((256, 256))
    return image
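
# Load a checkpoint into the matching untrained model. Checkpoints may be
# either a plain state_dict or a dict with "model" and "setup" entries; for
# checkpoints saved from distributed training, the "module." prefix is
# stripped before loading.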
def load_model(model_dict, model_name: str) -> nn.Module:
    model_name_lower = model_name.lower()
    if model_name_lower in model_dict:
        model, model_path = model_dict[model_name_lower]
        if torch.cuda.is_available():
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location="cpu")
        if "setup" in checkpoint:
            if checkpoint["setup"]["distributed"]:
                torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                    checkpoint["model"], "module."
                )
            model.load_state_dict(checkpoint["model"])
        else:
            model.load_state_dict(checkpoint)
        return model
    raise ValueError(
        f"Model {model_name} is not available for image prediction. "
        f"Please choose from {[name.capitalize() for name in model_dict.keys()]}."
    )
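
# Convert raw logits into a {label: probability} dict of the top-5 classes,
# the format expected by gr.Label.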
def postprocess_default(labels, output) -> dict:
    probabilities = nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.topk(probabilities, 5)
    confidences = {
        labels[top_catid.tolist()[i]]: top_prob.tolist()[i]
        for i in range(top_prob.shape[0])
    }
    return confidences
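
# Classify an input image with the selected model. The image is preprocessed
# with standard ImageNet-style evaluation transforms (resize to 256,
# center-crop to 224, normalize) before inference.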
def classify(
    input_image: Image.Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    for i in [dataset_type, arch_type, methods, training_ds]:
        if i is None:
            raise ValueError("Please select all options.")
    if input_image is None:
        raise ValueError("No image was provided.")
    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]
    preprocess_input = transforms.Compose(
        [
            transforms.Resize(
                256,
                interpolation=InterpolationMode.BILINEAR,
                antialias=True,
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    input_tensor: torch.Tensor = preprocess_input(input_image)
    input_batch = input_tensor.unsqueeze(0)
    model = load_model(
        dataset_models[dataset_type], f"{arch_type}_{training_ds}_{methods}"
    )
    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")
    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)
    with open(f"./data/{dataset_type}.txt", "r") as f:
        labels = {i: line.strip() for i, line in enumerate(f)}
    return postprocess_default(labels, output)
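
# Example (sketch, not executed by the app): calling classify() directly,
# assuming a local file "sample.jpg" exists and using the UI label strings
# defined in label_map.
#
#   img = Image.open("sample.jpg").convert("RGB")
#   preds = classify(img, "Imagenette (10 classes)", "ResNet-18",
#                    "200 Epochs", "Original")
#   print(preds)  # top-5 {class name: probability}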
def update_methods(method, ds_type):
    # Stanford Dogs has no clean-augmented model trained for 200 epochs, so that
    # choice is hidden for this dataset.
    if ds_type == "Stanford Dogs (120 classes)" and method == "200 Epochs":
        methods = list_200[:-1]
    else:
        methods = methods_map[method]
    return gr.update(choices=methods, value=None)
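
# Update the download button to point at the checkpoint for the currently
# selected dataset/architecture/method combination, clearing it if that
# combination has no trained model.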
def downloadModel(
    dataset_type, arch_type, methods, training_ds, dataset_models=dataset_models
):
    for i in [dataset_type, arch_type, methods, training_ds]:
        if i is None:
            return gr.update(label="Select Model", value=None)
    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]
    if f"{arch_type}_{training_ds}_{methods}" not in dataset_models[dataset_type]:
        return gr.update(label="Select Model", value=None)
    model_path = dataset_models[dataset_type][f"{arch_type}_{training_ds}_{methods}"][1]
    return gr.update(
        label=f"Download Model: '{dataset_type}_{arch_type}_{training_ds}_{methods}'",
        value=model_path,
    )
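
# Gradio UI: selection controls and a random-image sampler on the left,
# classification results and a model download button on the right.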
if __name__ == "__main__":
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        gr.Markdown(
            """
# Generative Augmented Image Classifiers

Main GitHub Repo: [Generative Data Augmentation](https://github.com/zhulinchng/generative-data-augmentation) | Generative Data Augmentation Demo: [Generative Data Augmented](https://huggingface.co/spaces/czl/generative-data-augmentation-demo).
"""
        )
        with gr.Row():
            with gr.Column():
                dataset_type = gr.Radio(
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    label="Dataset",
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    choices=["ResNet-18", "MobileNetV2"],
                    label="Model Architecture",
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    interactive=True,
                    value="200 Epochs",
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    interactive=True,
                    value="Original",
                )
                dataset_type.change(
                    fn=update_methods,
                    inputs=[methods, dataset_type],
                    outputs=[training_ds],
                )
                methods.change(
                    fn=update_methods,
                    inputs=[methods, dataset_type],
                    outputs=[training_ds],
                )
                random_image_output = gr.Image(type="pil", label="Image to Classify")
                with gr.Row():
                    generate_button = gr.Button("Sample Random Image")
                    classify_button_random = gr.Button("Classify")
            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{label_map[dataset_type.value]}_{label_map[arch_type.value]}_{label_map[training_ds.value]}_{label_map[methods.value]}'",
                    value=dataset_models[label_map[dataset_type.value]][
                        f"{label_map[arch_type.value]}_{label_map[training_ds.value]}_{label_map[methods.value]}"
                    ][1],
                )
                dataset_type.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                arch_type.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                methods.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
                training_ds.change(
                    fn=downloadModel,
                    inputs=[dataset_type, arch_type, methods, training_ds],
                    outputs=[download_model],
                )
        gr.Markdown(
            """
This demo showcases the performance of image classifiers trained on various datasets as part of the dissertation project 'Improving Fine-Grained Image Classification Using Diffusion-Based Generated Synthetic Images'.

View the models and files used in this demo [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/tree/main).

Usage instructions and documentation are available [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/blob/main/README.md).
"""
        )
        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )
    demo.launch(show_error=True)
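
# To run locally (sketch): install gradio, torch, and torchvision, place the
# checkpoints under ./models and the validation images and label files under
# ./data as referenced above, then start this script (on Hugging Face Spaces
# the entry point is conventionally named app.py).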