|  | import gradio as gr | 
					
						
						|  | import timm | 
					
						
						|  | import torch | 
					
						
						|  | from cods.classif.cp import ClassificationConformalizer | 
					
						
						|  | from cods.classif.data import ClassificationDataset | 
					
						
						|  | from cods.classif.data.predictions import ClassificationPredictions | 
					
						
						|  | from cods.classif.models import ClassificationModel | 
					
						
						|  | from datasets import load_dataset | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from dataset import DatasetWrapper | 
					
						
						|  |  | 
					
						
						|  | DATASETS = { | 
					
						
						|  | "miniimagenet": "timm/mini-imagenet", | 
					
						
						|  | "imagenette": "frgfm/imagenette", | 
					
						
						|  | "imagenet": "imagenet-1k", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | MODELS = { | 
					
						
						|  | "miniimagenet": [ | 
					
						
						|  | "QuentinJG/ResNet18-miniimagenet", | 
					
						
						|  | "shahrukhx01/vit-base-patch16-miniimagenet", | 
					
						
						|  | ], | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | classification_conformalizer = ClassificationConformalizer(method="lac", preprocess="softmax") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def calibrate(dataset_name, model_name): | 
					
						
						|  | global model | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_name = "resnet34" | 
					
						
						|  | global pretrained_resnet_34 | 
					
						
						|  | pretrained_resnet_34 = timm.create_model(model_name, pretrained=True) | 
					
						
						|  | classifier = ClassificationModel(model=pretrained_resnet_34, model_name=model_name) | 
					
						
						|  | global dataset | 
					
						
						|  | dataset = load_dataset(DATASETS[dataset_name], split="validation") | 
					
						
						|  | dataset = DatasetWrapper(dataset) | 
					
						
						|  |  | 
					
						
						|  | val_preds = classifier.build_predictions( | 
					
						
						|  | dataset, | 
					
						
						|  | dataset_name=dataset_name, | 
					
						
						|  | split_name="cal", | 
					
						
						|  | batch_size=512, | 
					
						
						|  | shuffle=False, | 
					
						
						|  | ) | 
					
						
						|  | classification_conformalizer.calibrate(val_preds, alpha=0.1) | 
					
						
						|  | return f"Calibrated on {dataset_name} with model {model_name}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def predict_image(img): | 
					
						
						|  | img_old = img.copy() | 
					
						
						|  | img = dataset.transforms(img).unsqueeze(0) | 
					
						
						|  | pred = pretrained_resnet_34(img) | 
					
						
						|  | inference_pred = ClassificationPredictions( | 
					
						
						|  | dataset_name="uploaded", | 
					
						
						|  | split_name="test", | 
					
						
						|  | image_paths=[None], | 
					
						
						|  | idx_to_cls=dataset.idx_to_cls, | 
					
						
						|  | true_cls=torch.tensor([-1]), | 
					
						
						|  | pred_cls=pred, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | result = classification_conformalizer.conformalize(inference_pred) | 
					
						
						|  | list_of_classes = [dataset.idx_to_cls[i] for i in result[0].detach().numpy()] | 
					
						
						|  | result = f"Predicted classes with 90% confidence: {list_of_classes}" | 
					
						
						|  | return img_old, result | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main_function(lbd, img): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | new_img = img | 
					
						
						|  | return new_img | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with gr.Blocks() as demo: | 
					
						
						|  | gr.Markdown("# Image Classification with Conformal Prediction") | 
					
						
						|  | gr.Markdown("## Upload an image and get conformalized classification predictions.") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | dataset_dropdown = gr.Dropdown( | 
					
						
						|  | choices=DATASETS.keys(), label="Select Dataset", value=list(DATASETS.keys())[0] | 
					
						
						|  | ) | 
					
						
						|  | model_dropdown = gr.Dropdown( | 
					
						
						|  | choices=MODELS[dataset_dropdown.value], | 
					
						
						|  | label="Select Model", | 
					
						
						|  | value=MODELS[dataset_dropdown.value][0], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | calibrate_btn = gr.Button("Calibrate") | 
					
						
						|  | status_text = gr.Textbox(label="Status", interactive=False) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("---") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | input_image = gr.Image(label="Upload Image", type="pil") | 
					
						
						|  | output_image = gr.Image(label="Processed Image") | 
					
						
						|  |  | 
					
						
						|  | predict_btn = gr.Button("Predict") | 
					
						
						|  | result_text = gr.Textbox(label="Prediction Result") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | calibrate_btn.click( | 
					
						
						|  | fn=calibrate, inputs=[dataset_dropdown, model_dropdown], outputs=status_text | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | predict_btn.click(fn=predict_image, inputs=input_image, outputs=[output_image, result_text]) | 
					
						
						|  |  | 
					
						
						|  | demo.launch() | 
					
						
						|  |  |