"""Panel app: classify an uploaded image with ResNet-18 and build an FGSM adversarial example.

The user uploads an image and picks an epsilon. The app shows the clean and
the adversarial image side by side together with the top-k class
probabilities predicted for each of them.
"""

import io
import random
from io import BytesIO
from typing import List, Tuple

import aiohttp
import panel as pn
import torch
import torch.nn.functional as F
from bokeh.themes import Theme
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
from transformers.image_transforms import to_pil_image

# import torchvision.transforms.functional as TVF

DEVICE = "cpu"

pn.extension("mathjax", design="bootstrap", sizing_mode="stretch_width")


@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[AutoImageProcessor, ResNetForImageClassification]:
    """Load the image processor and the ResNet classifier.

    The result is memoized by ``pn.cache``, so repeated calls with the same
    names return the same objects instead of reloading the weights.
    """
    processor = AutoImageProcessor.from_pretrained(processor_name)
    model = ResNetForImageClassification.from_pretrained(model_name)
    return processor, model


def denormalize(image, mean, std):
    """Undo per-channel normalization: return ``image * std + mean``.

    ``mean`` and ``std`` are per-channel sequences; they are reshaped to
    ``(1, C, 1, 1)`` so that they broadcast over a batched image tensor.
    """
    mean = torch.tensor(mean).view(1, -1, 1, 1)  # Reshape for broadcasting
    std = torch.tensor(std).view(1, -1, 1, 1)
    return image * std + mean


def _normalize(image, mean, std):
    """Inverse of :func:`denormalize`: return ``(image - mean) / std``."""
    mean = torch.tensor(mean).view(1, -1, 1, 1)
    std = torch.tensor(std).view(1, -1, 1, 1)
    return (image - mean) / std


# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    """Return ``image`` moved by ``epsilon`` along the sign of ``data_grad``.

    The result is clamped to ``[0, 1]`` and detached from the autograd graph.
    """
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon * sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image.detach()


def run_forward_backward(image: Image.Image, epsilon):
    """Classify ``image`` and build an FGSM adversarial version of it.

    Args:
        image: The PIL image to classify.
        epsilon: FGSM step size, applied in denormalized pixel space.

    Returns:
        A 4-tuple ``(clean_logits, adv_logits, clean_image, adv_image)``.
        Both images are denormalized tensors with the batch axis squeezed out.
    """
    processor, model = load_processor_model(
        "microsoft/resnet-18", "microsoft/resnet-18"
    )

    # Grab input
    processor.crop_pct = 1
    input_tensor = processor(image, return_tensors="pt")["pixel_values"]
    input_tensor.requires_grad_(True)

    # Run inference
    output = model(input_tensor).logits

    # Top target: the model's own prediction is used as the label to attack
    top_pred = output.max(1, keepdim=False)[1]

    # Get NLL loss and backward
    loss = F.cross_entropy(output, top_pred)
    model.zero_grad()
    loss.backward()

    # Denormalize input
    mean, std = processor.image_mean, processor.image_std
    input_tensor_denorm = denormalize(input_tensor.clone().detach(), mean, std)

    # Add noise to input
    random_noise = torch.sign(torch.randn_like(input_tensor)) * 0.02
    input_tensor_denorm_noised = torch.clamp(input_tensor_denorm + random_noise, 0, 1)
    # input_tensor_denorm_noised = input_tensor_denorm

    # FGSM attack. The gradient was taken w.r.t. the normalized input; since
    # FGSM only uses its sign, it can be applied in denormalized space as long
    # as every std entry is positive.
    adv_input_tensor_denorm = fgsm_attack(
        image=input_tensor_denorm_noised,
        epsilon=epsilon,
        data_grad=input_tensor.grad,
    )

    # Normalize adversarial input tensor back to the input range
    adv_input_tensor = _normalize(adv_input_tensor_denorm, mean, std)

    # Inference on adversarial image (no gradients needed any more)
    with torch.no_grad():
        adv_output = model(adv_input_tensor).logits

    return (
        output.detach(),
        adv_output,
        input_tensor_denorm.squeeze(),
        adv_input_tensor_denorm.squeeze(),
    )


async def process_inputs(button_event, image_data: bytes, epsilon: float):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    This is an async generator: it first yields status messages (Markdown
    strings) and finally the results column. The ``main`` widget box is
    disabled while the generator runs and re-enabled when it finishes.

    Args:
        button_event: Value of the "Regenerate" button; it is not read, it is
            only bound so that a click re-runs the function.
        image_data: Raw bytes of the uploaded file (anything else means that
            no file has been uploaded yet).
        epsilon: FGSM step size selected by the user.
    """
    try:
        main.disabled = True

        # if not button_event or (button_event and not isinstance(image_data, bytes)):
        if not isinstance(image_data, bytes):
            yield "##### 👋 Upload an image to proceed"
            return

        yield "##### ⚙ Fetching image and running model..."
        try:
            # Open the image using PIL. Converting to RGB makes sure the
            # tensor has three channels, matching the processor's mean/std.
            pil_img = Image.open(BytesIO(image_data)).convert("RGB")

            # Run forward + FGSM
            clean_logits, adv_logits, input_tensor, adv_input_tensor = (
                run_forward_backward(image=pil_img, epsilon=epsilon)
            )
        except Exception as e:
            yield f"##### Something went wrong, please try a different image! \n {e}"
            return

        img = pn.pane.Image(
            to_pil_image(input_tensor, do_rescale=True),
            height=300,
            align="center",
        )

        # Convert image for visualizing
        adv_img_pil = to_pil_image(adv_input_tensor, do_rescale=True)
        adv_img = pn.pane.Image(
            adv_img_pil,
            height=300,
            align="center",
        )

        # Download image button (the button itself is currently disabled; the
        # PNG bytes are still produced so it can be re-enabled easily)
        adv_img_bytes = io.BytesIO()
        adv_img_pil.save(adv_img_bytes, format="PNG")
        # download = pn.widgets.FileDownload(
        #     to_pil_image(adv_img_bytes, do_rescale=True),
        #     embed=True,
        #     filename="adv_img.png",
        #     button_type="primary",
        #     button_style="outline",
        #     width_policy="min",
        # )

        # Build the results column
        k_val = 5
        results = pn.Column(
            pn.Row("###### Uploaded", "###### Adversarial"),
            pn.Row(img, adv_img),
            # pn.Row(pn.Spacer(), download),
            f" ###### Top {k_val} class predictions",
        )

        # Get likelihoods
        likelihoods = [
            F.softmax(clean_logits, dim=1).squeeze(),
            F.softmax(adv_logits, dim=1).squeeze(),
        ]

        # One column of label + progress bar pairs per image (clean, adversarial).
        # (An earlier variant of this loop used a fixed bar_color="secondary".)
        label_bars_rows = pn.Row()
        for likelihood_tensor in likelihoods:
            # Get top k values and indices
            vals_topk, idx_topk = torch.topk(likelihood_tensor, k=k_val)
            label_bars = pn.Column()
            for idx, prob in zip(idx_topk.tolist(), vals_topk.tolist()):
                row_label = pn.widgets.StaticText(
                    name=f"{classes[idx]}", value=f"{prob:.2%}", align="center"
                )
                row_bar = pn.indicators.Progress(
                    value=int(prob * 100),
                    sizing_mode="stretch_width",
                    # Dynamic color based on value
                    bar_color="success" if prob > 0.7 else "warning",
                    margin=(0, 10),
                    design=pn.theme.Material,
                )
                label_bars.append(pn.Column(row_label, row_bar))
            label_bars_rows.append(label_bars)

        results.append(label_bars_rows)
        yield results
    except Exception as e:
        yield f"##### Something went wrong! \n {e}"
        return
    finally:
        main.disabled = False


####################################################################################################################################
# Get classes: one label per line, the line index is the class id
with open("classes.txt", "r", encoding="utf-8") as file:
    classes = file.read().splitlines()

# Create widgets
############################################
# File upload widget
file_input = pn.widgets.FileInput(name="Upload a PNG image", accept=".png,.jpg")

# Epsilon
epsilon_slider = pn.widgets.FloatSlider(
    name=r"$$\epsilon$$ parameter for FGSM",
    start=0,
    end=0.1,
    step=0.005,
    value=0.005,
    format="1[.]000",
    align="center",
    max_width=500,
    width_policy="max",
)

# alpha_slider = pn.widgets.FloatSlider(
#     name=r"$$\alpha$$ parameter for Gaussian noise",
#     start=0,
#     end=0.1,
#     step=0.005,
#     value=0.000,
#     format="1[.]000",
#     align="center",
#     max_width=500,
#     width_policy="max"
# )

# Regenerate button
regenerate = pn.widgets.Button(
    name="Regenerate",
    button_type="primary",
    width_policy="min",
    max_width=105,
)

############################################

# The Markdown texts are written flush-left so that their rendering does not
# depend on how the surrounding code is indented.
# TODO(review): the link target for the ImageNet-1K class list was lost;
# restore it and turn "here." back into a link.
INTRO_MARKDOWN = """\
###### Classify an image (png/jpeg) with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.

Wondering where the class names come from? Find the list of ImageNet-1K classes here.

*Please be patient with the application, it is running on a low-resource device.*
"""

FOOTER_MARKDOWN = """\
If the application is too slow for you, head over to the README to get this running locally.
"""

# Organize widgets in a column
input_widgets = pn.Column(
    INTRO_MARKDOWN,
    file_input,
    pn.Row(epsilon_slider, pn.Spacer(width_policy="min", max_width=25), regenerate),
)

# Add interactivity
interactive_result = pn.panel(
    pn.bind(
        process_inputs,
        regenerate,
        file_input.param.value,
        epsilon_slider.param.value,
    ),
    height=600,
)

footer = pn.pane.Markdown(FOOTER_MARKDOWN)

# Create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
    footer,
)

title = "Adversarial Sample Generation"
pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(75%, 698px)",
    header_background="#101820",
).servable(title=title)


# Functions from original demo

# ICON_URLS = {
#     "brand-github": "",
#     "brand-twitter": "",
#     "brand-linkedin": "",
#     "message-circle": "",
#     "brand-discord": "",
# }

# async def random_url(_):
#     pet = random.choice(["cat", "dog"])
#     api_url = f"https://api.the{pet}"
#     async with aiohttp.ClientSession() as session:
#         async with session.get(api_url) as resp:
#             return (await resp.json())[0]["url"]

# @pn.cache
# def load_processor_model(
#     processor_name: str, model_name: str
# ) -> Tuple[CLIPProcessor, CLIPModel]:
#     processor = CLIPProcessor.from_pretrained(processor_name)
#     model = CLIPModel.from_pretrained(model_name)
#     return processor, model

# async def open_image_url(image_url: str) -> Image:
#     async with aiohttp.ClientSession() as session:
#         async with session.get(image_url) as resp:
#             return

# def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
#     processor, model = load_processor_model(
#         "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
#     )
#     inputs = processor(
#         text=class_items,
#         images=[image],
#         return_tensors="pt",  # pytorch tensors
#     )
#     print(inputs)
#     outputs = model(**inputs)
#     logits_per_image = outputs.logits_per_image
#     class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
#     return class_likelihoods[0]