# fgsm-project / app.py
# root
# update default eps value
# b9f4d50
# raw
# history blame contribute delete
# No virus
# 11.4 kB
import io
import random
from io import BytesIO
from typing import List, Tuple
import aiohttp
import panel as pn
import torch
from bokeh.themes import Theme
# import torchvision.transforms.functional as TVF
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
from transformers.image_transforms import to_pil_image
# Device name for computation. NOTE(review): not referenced elsewhere in the
# visible code; tensors and the model stay on PyTorch's default (CPU) device.
DEVICE = "cpu"
# Enable MathJax (needed for the LaTeX epsilon label on the slider) and set the
# app-wide Bootstrap design with stretch-width sizing.
pn.extension("mathjax", sizing_mode="stretch_width", design="bootstrap")
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[AutoImageProcessor, ResNetForImageClassification]:
    """Load an image processor and a ResNet image classifier.

    Results are memoized by ``pn.cache``. Repeated calls with the same
    arguments therefore return the same (shared) objects.

    Args:
        processor_name: hub id or local path passed to
            ``AutoImageProcessor.from_pretrained``.
        model_name: hub id or local path passed to
            ``ResNetForImageClassification.from_pretrained``.

    Returns:
        ``(processor, model)``.
    """
    return (
        AutoImageProcessor.from_pretrained(processor_name),
        ResNetForImageClassification.from_pretrained(model_name),
    )
def denormalize(image, mean, std):
    """Undo per-channel normalization, returning ``image * std + mean``.

    Args:
        image: normalized tensor with channels on dim 1, e.g. ``(N, C, H, W)``.
        mean: per-channel means (a sequence of floats or a tensor).
        std: per-channel standard deviations (a sequence of floats or a tensor).

    Returns:
        The denormalized tensor. ``mean``/``std`` are reshaped to
        ``(1, C, 1, 1)`` for broadcasting.
    """
    # as_tensor does not copy (and so does not trigger PyTorch's
    # "copy construct from a tensor" warning) when mean/std are already
    # tensors. It also places them on image's device, so non-CPU inputs work.
    mean = torch.as_tensor(mean, device=image.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, device=image.device).view(1, -1, 1, 1)
    return image * std + mean
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    """Take one Fast Gradient Sign Method step from ``image``.

    Every element moves by ``epsilon`` in the direction of the sign of
    ``data_grad``; elements whose gradient is exactly zero do not move.
    The result is then clamped back into ``[0, 1]``.

    Args:
        image: input tensor, expected to be in [0, 1] pixel space.
        epsilon: step size.
        data_grad: gradient of the loss w.r.t. the input; must broadcast
            against ``image``.

    Returns:
        The perturbed tensor, detached from the autograd graph.
    """
    step = epsilon * torch.sign(data_grad)
    return (image + step).clamp(min=0, max=1).detach()
def run_forward_backward(image: Image.Image, epsilon: float):
    """Classify ``image`` with ResNet-18 and build an FGSM adversarial version.

    Steps:
      1. Preprocess the image and run a forward pass.
      2. Use the model's own top-1 prediction as the label. Compute the
         gradient of the cross-entropy loss w.r.t. the input pixels.
      3. Map the input back to pixel space and add a random +/-0.02 sign
         perturbation. Then apply one FGSM step of size ``epsilon``.
      4. Renormalize the adversarial pixels and classify them again.

    Args:
        image: PIL image to attack. Non-RGB images (RGBA/palette PNGs,
            grayscale, CMYK) are converted to RGB first.
        epsilon: FGSM step size in [0, 1] pixel units.

    Returns:
        ``(clean_logits, adv_logits, clean_pixels, adv_pixels)``:
          - the logits, shaped ``(1, num_classes)``, for the clean and
            adversarial inputs;
          - the clean and adversarial images in pixel space with the batch
            dimension squeezed out. The adversarial image is clamped to
            [0, 1]; the clean one is approximately in [0, 1].

        All returned tensors are detached from autograd.
    """
    processor, model = load_processor_model(
        "microsoft/resnet-18", "microsoft/resnet-18"
    )
    # crop_pct=1 makes the resize + center-crop keep more of the image
    # (presumably a ConvNext-style processor; verify if the checkpoint changes).
    # NOTE: this mutates the shared cached processor. That is harmless because
    # the value is constant.
    processor.crop_pct = 1

    # The processor normalizes with 3-channel mean/std. RGBA/L/P/CMYK images
    # would fail there, so convert them to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = processor(image, return_tensors="pt")["pixel_values"]
    input_tensor.requires_grad_(True)

    # Forward pass on the clean input.
    logits = model(input_tensor).logits
    # Top-1 prediction is the label the attack pushes away from.
    top_pred = logits.argmax(dim=1)
    loss = F.cross_entropy(logits, top_pred)
    # Gradient w.r.t. the input only. Unlike loss.backward(), this does not
    # write .grad onto the shared (cached) model's parameters.
    (input_grad,) = torch.autograd.grad(loss, input_tensor)

    # Map the normalized input back to pixel space: x = x_norm * std + mean.
    mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
    std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
    clean_pixels = input_tensor.detach() * std + mean

    # Small random sign perturbation applied before the FGSM step.
    random_noise = torch.sign(torch.randn_like(clean_pixels)) * 0.02
    noised_pixels = torch.clamp(clean_pixels + random_noise, 0, 1)

    # The gradient is w.r.t. the *normalized* input. Assuming every std > 0,
    # its sign equals the sign of the gradient w.r.t. pixel space.
    adv_pixels = fgsm_attack(
        image=noised_pixels,
        epsilon=epsilon,
        data_grad=input_grad,
    )

    # Renormalize and classify the adversarial image; no gradients needed here.
    with torch.no_grad():
        adv_logits = model((adv_pixels - mean) / std).logits

    return (
        logits.detach(),
        adv_logits,
        clean_pixels.squeeze(),
        adv_pixels.squeeze(),
    )
def _topk_label_bars(likelihood_tensor, k):
    """Build a column of (class label, progress bar) rows for the top-``k`` classes.

    Args:
        likelihood_tensor: 1-D tensor of class probabilities for one image.
        k: number of classes to show.

    Returns:
        A ``pn.Column`` with one label/bar pair per class, most likely first.
    """
    vals_topk, idx_topk = torch.topk(likelihood_tensor, k=k)
    label_bars = pn.Column()
    for idx, prob in zip(idx_topk.tolist(), vals_topk.tolist()):
        row_label = pn.widgets.StaticText(
            name=f"{classes[idx]}", value=f"{prob:.2%}", align="center"
        )
        row_bar = pn.indicators.Progress(
            value=int(prob * 100),
            sizing_mode="stretch_width",
            # "success" for confident (>70%) predictions, "warning" otherwise.
            bar_color="success" if prob > 0.7 else "warning",
            margin=(0, 10),
            design=pn.theme.Material,
        )
        label_bars.append(pn.Column(row_label, row_bar))
    return label_bars


async def process_inputs(button_event, image_data: bytes, epsilon: float):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    This async generator is bound (via ``pn.bind``) to the Regenerate button,
    the file upload and the epsilon slider. It yields markdown status strings,
    then a ``pn.Column`` containing the uploaded and adversarial images and
    their top-5 class predictions. The module-level ``main`` layout is
    disabled while it runs and re-enabled in ``finally``.

    Args:
        button_event: Regenerate button value; only used to re-trigger.
        image_data: raw bytes of the uploaded image, or a non-bytes value
            (e.g. ``None``) when nothing has been uploaded yet.
        epsilon: FGSM step size from the slider.
    """
    k_val = 5
    try:
        main.disabled = True
        if not isinstance(image_data, bytes):
            yield "##### 👋 Upload an image to proceed"
            return
        yield "##### ⚙ Fetching image and running model..."
        try:
            pil_img = Image.open(BytesIO(image_data))
            clean_logits, adv_logits, input_tensor, adv_input_tensor = (
                run_forward_backward(image=pil_img, epsilon=epsilon)
            )
        except Exception as e:
            # Unreadable/unsupported uploads become a message, not a crash.
            yield f"##### Something went wrong, please try a different image! \n {e}"
            return

        img = pn.pane.Image(
            to_pil_image(input_tensor, do_rescale=True),
            height=300,
            align="center",
        )
        adv_img = pn.pane.Image(
            to_pil_image(adv_input_tensor, do_rescale=True),
            height=300,
            align="center",
        )
        results = pn.Column(
            pn.Row("###### Uploaded", "###### Adversarial"),
            pn.Row(img, adv_img),
            f" ###### Top {k_val} class predictions",
        )

        # Side-by-side top-k bars: clean prediction first, adversarial second.
        likelihoods = [
            F.softmax(clean_logits, dim=1).squeeze(),
            F.softmax(adv_logits, dim=1).squeeze(),
        ]
        results.append(
            pn.Row(*(_topk_label_bars(lik, k_val) for lik in likelihoods))
        )
        yield results
    except Exception as e:
        yield f"##### Something went wrong! \n {e}"
        return
    finally:
        main.disabled = False
####################################################################################################################################
# Get classes
# classes.txt holds one class name per line. process_inputs looks names up as
# classes[idx], where idx is the model's predicted class index.
with open("classes.txt", "r", encoding="utf-8") as file:
    # splitlines() also strips "\r" from CRLF files, unlike split("\n").
    # It also does not append an empty entry after the final newline.
    classes = file.read().splitlines()
# Create widgets
############################################
# File upload widget. ".jpeg" is accepted alongside ".jpg" because both are
# common JPEG extensions; the intro text advertises png/jpeg.
file_input = pn.widgets.FileInput(
    name="Upload a PNG or JPEG image", accept=".png,.jpg,.jpeg"
)
# FGSM step size epsilon. The attack applies it in [0, 1] pixel units, so the
# 0-0.1 range moves each pixel by at most a tenth of the full intensity range.
epsilon_slider = pn.widgets.FloatSlider(
    name=r"$$\epsilon$$ parameter for FGSM",
    start=0,
    end=0.1,
    step=0.005,
    value=0.005,
    format="1[.]000",
    align="center",
    max_width=500,
    width_policy="max",
)
# Regenerate button: re-runs classification and the attack on the current
# upload. The random pre-attack noise is resampled each run.
regenerate = pn.widgets.Button(
    name="Regenerate",
    button_type="primary",
    width_policy="min",
    max_width=105,
)
############################################
# Layout: intro text and controls on top, live results below, footer last.
_INTRO_MARKDOWN = """
###### Classify an image (png/jpeg) with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
Wondering where the class names come from? Find the list of ImageNet-1K classes [here.](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/)
*Please be patient with the application, it is running on a low-resource device.*
"""
_controls_row = pn.Row(
    epsilon_slider,
    pn.Spacer(width_policy="min", max_width=25),
    regenerate,
)
input_widgets = pn.Column(_INTRO_MARKDOWN, file_input, _controls_row)

# Re-evaluate process_inputs whenever any of these changes:
#   - the Regenerate button,
#   - the uploaded file bytes,
#   - the epsilon value.
# Its yielded values are rendered in this pane.
_bound_inputs = pn.bind(
    process_inputs,
    regenerate,
    file_input.param.value,
    epsilon_slider.param.value,
)
interactive_result = pn.panel(_bound_inputs, height=600)

_FOOTER_MARKDOWN = """
<br><br>
If the application is too slow for you, head over to the README to get this running locally.
"""
footer = pn.pane.Markdown(_FOOTER_MARKDOWN)

# Create dashboard. NOTE: process_inputs toggles `main.disabled`, so this
# must remain a module-level global named `main`.
main = pn.WidgetBox(input_widgets, interactive_result, footer)

title = "Adversarial Sample Generation"
template = pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(75%, 698px)",
    header_background="#101820",
)
template.servable(title=title)
# Unused code kept from the original Panel demo this app was adapted from (commented out; not executed)
# ICON_URLS = {
# "brand-github": "https://github.com/holoviz/panel",
# "brand-twitter": "https://twitter.com/Panel_Org",
# "brand-linkedin": "https://www.linkedin.com/company/panel-org",
# "message-circle": "https://discourse.holoviz.org/",
# "brand-discord": "https://discord.gg/AXRHnJU6sP",
# }
# async def random_url(_):
# pet = random.choice(["cat", "dog"])
# api_url = f"https://api.the{pet}api.com/v1/images/search"
# async with aiohttp.ClientSession() as session:
# async with session.get(api_url) as resp:
# return (await resp.json())[0]["url"]
# @pn.cache
# def load_processor_model(
# processor_name: str, model_name: str
# ) -> Tuple[CLIPProcessor, CLIPModel]:
# processor = CLIPProcessor.from_pretrained(processor_name)
# model = CLIPModel.from_pretrained(model_name)
# return processor, model
# async def open_image_url(image_url: str) -> Image:
# async with aiohttp.ClientSession() as session:
# async with session.get(image_url) as resp:
# return Image.open(io.BytesIO(await resp.read()))
# def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
# processor, model = load_processor_model(
# "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
# )
# inputs = processor(
# text=class_items,
# images=[image],
# return_tensors="pt", # pytorch tensors
# )
# print(inputs)
# outputs = model(**inputs)
# logits_per_image = outputs.logits_per_image
# class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
# return class_likelihoods[0]