cods-live / app.py
Léo Andéol
YANF
996fc32 unverified
import gradio as gr
import timm
import torch
from cods.classif.cp import ClassificationConformalizer
from cods.classif.data import ClassificationDataset
from cods.classif.data.predictions import ClassificationPredictions
from cods.classif.models import ClassificationModel
from datasets import load_dataset
# from ultralytics import YOLO
from PIL import Image
# from transformers import AutoImageProcessor, AutoModelForImageClassification
from dataset import DatasetWrapper
# Hugging Face dataset ids, keyed by the name shown in the "Select Dataset"
# dropdown. Insertion order matters: the first key is the dropdown default.
DATASETS = dict(
    miniimagenet="timm/mini-imagenet",
    imagenette="frgfm/imagenette",
    imagenet="imagenet-1k",
)

# Model ids offered in the "Select Model" dropdown, keyed by dataset name.
# Only "miniimagenet" has entries so far.
MODELS = dict(
    miniimagenet=[
        "QuentinJG/ResNet18-miniimagenet",
        "shahrukhx01/vit-base-patch16-miniimagenet",
    ],
)
# Single shared conformalizer: calibrate() fits it on a validation split and
# predict_image() uses it to turn model outputs into prediction sets.
classification_conformalizer = ClassificationConformalizer(
    preprocess="softmax",
    method="lac",
)
def calibrate(dataset_name, model_name):
    """Load a backbone and a calibration set, then calibrate the conformalizer.

    On success this calibrates ``classification_conformalizer`` in place
    (``alpha=0.1``). It also rebinds the module globals
    ``pretrained_resnet_34`` and ``dataset``, which ``predict_image`` reads.

    Args:
        dataset_name: Key into ``DATASETS``. The ``validation`` split of that
            Hugging Face dataset is used as the calibration set.
        model_name: Model selected in the UI. Currently ignored (see HACK).

    Returns:
        A status string for the UI naming the dataset and backbone used.
    """
    global pretrained_resnet_34, dataset

    # HACK: a timm resnet34 is always used regardless of the UI selection.
    # NOTE(review): presumably because the MODELS entries are Hugging Face
    # repo ids that timm.create_model cannot load by plain name - confirm.
    # A separate local avoids silently shadowing the caller's argument.
    backbone_name = "resnet34"
    backbone = timm.create_model(backbone_name, pretrained=True)
    # nn.Module starts in training mode. Switch to inference mode so BatchNorm
    # uses (and does not update) its running statistics.
    backbone.eval()
    classifier = ClassificationModel(model=backbone, model_name=backbone_name)

    calibration_set = DatasetWrapper(
        load_dataset(DATASETS[dataset_name], split="validation")
    )
    val_preds = classifier.build_predictions(
        calibration_set,
        dataset_name=dataset_name,
        split_name="cal",
        batch_size=512,
        shuffle=False,
    )
    classification_conformalizer.calibrate(val_preds, alpha=0.1)

    # Publish the new state only after calibration succeeded. A failed
    # calibration then leaves the previous model/dataset pair intact.
    pretrained_resnet_34 = backbone
    dataset = calibration_set
    return f"Calibrated on {dataset_name} with model {backbone_name}"
def predict_image(img):
    """Classify an uploaded image and return its conformal prediction set.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when "Predict" is pressed without an uploaded image.

    Returns:
        An ``(image, message)`` tuple for the ``[output_image, result_text]``
        outputs. ``message`` lists the class names in the prediction set, or
        explains why no prediction could be made.
    """
    if img is None:
        return None, "Please upload an image first."

    # `dataset` and `pretrained_resnet_34` are module globals that only exist
    # after calibrate() has run. Look them up defensively instead of letting
    # the Gradio callback die with a NameError.
    wrapped_dataset = globals().get("dataset")
    model = globals().get("pretrained_resnet_34")
    if wrapped_dataset is None or model is None:
        return img, "Please calibrate the model first."

    img_old = img.copy()
    model.eval()  # inference mode: BatchNorm running stats, no dropout
    device = next(model.parameters()).device
    batch = wrapped_dataset.transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():  # no autograd graph needed for a single prediction
        pred = model(batch)

    inference_pred = ClassificationPredictions(
        dataset_name="uploaded",
        split_name="test",
        image_paths=[None],
        idx_to_cls=wrapped_dataset.idx_to_cls,
        true_cls=torch.tensor([-1]),  # placeholder: ground truth is unknown
        pred_cls=pred,  # raw model outputs; conformalizer uses preprocess="softmax"
    )
    result = classification_conformalizer.conformalize(inference_pred)
    # result[0] holds the class indices in the set for the single input.
    # .tolist() yields plain ints and also works if the tensor is not on CPU.
    list_of_classes = [wrapped_dataset.idx_to_cls[i] for i in result[0].tolist()]
    result = f"Predicted classes with 90% confidence: {list_of_classes}"
    return img_old, result
# Load a pretrained YOLOv8n model
# model = YOLO("yolov8n.pt")
def main_function(lbd, img):
    """Placeholder processing step: return *img* unchanged.

    No Gradio event in this app calls it. ``lbd`` is accepted only to keep
    the two-argument signature and is unused.

    Args:
        lbd: Unused.
        img: Image (or any object), returned as-is.

    Returns:
        ``img``, unmodified.
    """
    # TODO: replace with a real model. An earlier draft ran a YOLOv8 detector
    # here and returned the annotated image.
    return img
# Gradio UI: step 1 calibrates the conformalizer on a dataset; step 2
# predicts a conformal set for an uploaded image.
with gr.Blocks() as demo:
    gr.Markdown("# Image Classification with Conformal Prediction")
    gr.Markdown("## Upload an image and get conformalized classification predictions.")

    # The first DATASETS key is the default selection.
    default_dataset = next(iter(DATASETS))
    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=list(DATASETS),  # a real list rather than a dict_keys view
            label="Select Dataset",
            value=default_dataset,
        )
        model_dropdown = gr.Dropdown(
            choices=MODELS[default_dataset],
            label="Select Model",
            value=MODELS[default_dataset][0],
        )
    calibrate_btn = gr.Button("Calibrate")
    status_text = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")

    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Processed Image")
    predict_btn = gr.Button("Predict")
    result_text = gr.Textbox(label="Prediction Result")

    # Connect components
    calibrate_btn.click(
        fn=calibrate, inputs=[dataset_dropdown, model_dropdown], outputs=status_text
    )
    predict_btn.click(fn=predict_image, inputs=input_image, outputs=[output_image, result_text])

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server.
if __name__ == "__main__":
    demo.launch()