import io import random from io import BytesIO from typing import List, Tuple import aiohttp import panel as pn import torch from bokeh.themes import Theme # import torchvision.transforms.functional as TVF import torch.nn.functional as F from PIL import Image from transformers import AutoImageProcessor, ResNetForImageClassification from transformers.image_transforms import to_pil_image DEVICE = "cpu" pn.extension("mathjax", design="bootstrap", sizing_mode="stretch_width") @pn.cache def load_processor_model( processor_name: str, model_name: str ) -> Tuple[AutoImageProcessor, ResNetForImageClassification]: processor = AutoImageProcessor.from_pretrained(processor_name) model = ResNetForImageClassification.from_pretrained(model_name) return processor, model def denormalize(image, mean, std): mean = torch.tensor(mean).view(1, -1, 1, 1) # Reshape for broadcasting std = torch.tensor(std).view(1, -1, 1, 1) return image * std + mean # FGSM attack code def fgsm_attack(image, epsilon, data_grad): # Collect the element-wise sign of the data gradient sign_data_grad = data_grad.sign() # Create the perturbed image by adjusting each pixel of the input image perturbed_image = image + epsilon * sign_data_grad # Adding clipping to maintain [0,1] range perturbed_image = torch.clamp(perturbed_image, 0, 1) # Return the perturbed image return perturbed_image.detach() def run_forward_backward(image: Image, epsilon): processor, model = load_processor_model( "microsoft/resnet-18", "microsoft/resnet-18" ) # Grab input processor.crop_pct = 1 input_tensor = processor(image, return_tensors="pt")["pixel_values"] input_tensor.requires_grad_(True) # Run inference output = model(input_tensor) output = output.logits # Top target top_pred = output.max(1, keepdim=False)[1] # Get NLL loss and backward loss = F.cross_entropy(output, top_pred) model.zero_grad() loss.backward() # Denormalize input mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1) std = torch.tensor(processor.image_std).view(1, -1, 1, 1) input_tensor_denorm = input_tensor.clone().detach() * std + mean # Add noise to input random_noise = torch.sign(torch.randn_like(input_tensor)) * 0.02 input_tensor_denorm_noised = torch.clamp(input_tensor_denorm + random_noise, 0, 1) # input_tensor_denorm_noised = input_tensor_denorm # FGSM attack adv_input_tensor_denorm = fgsm_attack( image=input_tensor_denorm_noised, epsilon=epsilon, data_grad=input_tensor.grad.data, ) # Normalize adversarial input tensor back to the input range adv_input_tensor = (adv_input_tensor_denorm - mean) / std # Inference on adversarial image adv_output = model(adv_input_tensor) adv_output = adv_output.logits return ( output, adv_output, input_tensor_denorm.squeeze(), adv_input_tensor_denorm.squeeze(), ) async def process_inputs(button_event, image_data: bytes, epsilon: float): """ High level function that takes in the user inputs and returns the classification results as panel objects. """ try: main.disabled = True # if not button_event or (button_event and not isinstance(image_data, bytes)): if not isinstance(image_data, bytes): yield "##### 👋 Upload an image to proceed" return yield "##### ⚙ Fetching image and running model..." try: # Open the image using PIL pil_img = Image.open(BytesIO(image_data)) # Run forward + FGSM clean_logits, adv_logits, input_tensor, adv_input_tensor = ( run_forward_backward(image=pil_img, epsilon=epsilon) ) except Exception as e: yield f"##### Something went wrong, please try a different image! \n {e}" return img = pn.pane.Image( to_pil_image(input_tensor, do_rescale=True), height=300, align="center", ) # Convert image for visualizing adv_img_pil = to_pil_image(adv_input_tensor, do_rescale=True) adv_img = pn.pane.Image( adv_img_pil, height=300, align="center", ) # Download image button adv_img_bytes = io.BytesIO() adv_img_pil.save(adv_img_bytes, format="PNG") # download = pn.widgets.FileDownload( # to_pil_image(adv_img_bytes, do_rescale=True), # embed=True, # filename="adv_img.png", # button_type="primary", # button_style="outline", # width_policy="min", # ) # Build the results column k_val = 5 results = pn.Column( pn.Row("###### Uploaded", "###### Adversarial"), pn.Row(img, adv_img), # pn.Row(pn.Spacer(), download), f" ###### Top {k_val} class predictions", ) # Get likelihoods likelihoods = [ F.softmax(clean_logits, dim=1).squeeze(), F.softmax(adv_logits, dim=1).squeeze(), ] label_bars_rows = pn.Row() for likelihood_tensor in likelihoods: # Get top k values and indices vals_topk_clean, idx_topk_clean = torch.topk(likelihood_tensor, k=k_val) label_bars = pn.Column() for idx, val in zip(idx_topk_clean, vals_topk_clean): prob = val.item() row_label = pn.widgets.StaticText( name=f"{classes[idx]}", value=f"{prob:.2%}", align="center" ) row_bar = pn.indicators.Progress( value=int(prob * 100), sizing_mode="stretch_width", bar_color="success" if prob > 0.7 else "warning", # Dynamic color based on value margin=(0, 10), design=pn.theme.Material, ) label_bars.append(pn.Column(row_label, row_bar)) # for likelihood_tensor in likelihoods: # # Get top # vals_topk_clean, idx_topk_clean = torch.topk(likelihood_tensor, k=k_val) # label_bars = pn.Column() # for idx, val in zip(idx_topk_clean, vals_topk_clean): # prob = val.item() # row_label = pn.widgets.StaticText( # name=f"{classes[idx]}", value=f"{prob:.2%}", align="center" # ) # row_bar = pn.indicators.Progress( # value=int(prob * 100), # sizing_mode="stretch_width", # bar_color="secondary", # margin=(0, 10), # design=pn.theme.Material, # ) # label_bars.append(pn.Column(row_label, row_bar)) label_bars_rows.append(label_bars) results.append(label_bars_rows) yield results except Exception as e: yield f"##### Something went wrong! \n {e}" return finally: main.disabled = False #################################################################################################################################### # Get classes classes = [] with open("classes.txt", "r") as file: classes = file.read() classes = classes.split("\n") # Create widgets ############################################ # Fil upload widget file_input = pn.widgets.FileInput(name="Upload a PNG image", accept=".png,.jpg") # Epsilon epsilon_slider = pn.widgets.FloatSlider( name=r"$$\epsilon$$ parameter for FGSM", start=0, end=0.1, step=0.005, value=0.005, format="1[.]000", align="center", max_width=500, width_policy="max", ) # alpha_slider = pn.widgets.FloatSlider( # name=r"$$\alpha$$ parameter for Gaussian noise", # start=0, # end=0.1, # step=0.005, # value=0.000, # format="1[.]000", # align="center", # max_width=500, # width_policy="max" # ) # Regenerate button regenerate = pn.widgets.Button( name="Regenerate", button_type="primary", width_policy="min", max_width=105, ) ############################################ # Organize widgets in a column input_widgets = pn.Column( """ ###### Classify an image (png/jpeg) with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n Wondering where the class names come from? Find the list of ImageNet-1K classes [here.](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/) *Please be patient with the application, it is running on a low-resource device.* """, file_input, pn.Row(epsilon_slider, pn.Spacer(width_policy="min", max_width=25), regenerate), ) # Add interactivity interactive_result = pn.panel( pn.bind( process_inputs, regenerate, file_input.param.value, epsilon_slider.param.value, ), height=600, ) footer = pn.pane.Markdown( """

If the application is too slow for you, head over to the README to get this running locally. """ ) # Create dashboard main = pn.WidgetBox( input_widgets, interactive_result, footer, ) title = "Adversarial Sample Generation" pn.template.BootstrapTemplate( title=title, main=main, main_max_width="min(75%, 698px)", header_background="#101820", ).servable(title=title) # Functions from original demo # ICON_URLS = { # "brand-github": "https://github.com/holoviz/panel", # "brand-twitter": "https://twitter.com/Panel_Org", # "brand-linkedin": "https://www.linkedin.com/company/panel-org", # "message-circle": "https://discourse.holoviz.org/", # "brand-discord": "https://discord.gg/AXRHnJU6sP", # } # async def random_url(_): # pet = random.choice(["cat", "dog"]) # api_url = f"https://api.the{pet}api.com/v1/images/search" # async with aiohttp.ClientSession() as session: # async with session.get(api_url) as resp: # return (await resp.json())[0]["url"] # @pn.cache # def load_processor_model( # processor_name: str, model_name: str # ) -> Tuple[CLIPProcessor, CLIPModel]: # processor = CLIPProcessor.from_pretrained(processor_name) # model = CLIPModel.from_pretrained(model_name) # return processor, model # async def open_image_url(image_url: str) -> Image: # async with aiohttp.ClientSession() as session: # async with session.get(image_url) as resp: # return Image.open(io.BytesIO(await resp.read())) # def get_similarity_scores(class_items: List[str], image: Image) -> List[float]: # processor, model = load_processor_model( # "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32" # ) # inputs = processor( # text=class_items, # images=[image], # return_tensors="pt", # pytorch tensors # ) # print(inputs) # outputs = model(**inputs) # logits_per_image = outputs.logits_per_image # class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy() # return class_likelihoods[0]