import torch
from scipy.stats import norm, truncnorm
from functools import reduce
from scipy.special import betainc
import numpy as np


class Gaussian_Shading:
    """Embed and recover a repeated binary watermark in a 1x4x64x64 latent.

    A random bit tensor of shape ``[1, 4 // ch, 64 // hw, 64 // hw]`` is tiled
    ``ch * hw * hw`` times to fill the latent, added modulo 2 to a fixed binary
    key, and encoded in the sign of a sampled latent.  Extraction thresholds
    the latent at zero, removes the key and majority-votes across the copies.
    """

    def __init__(self, ch_factor, hw_factor, fpr, user_number, device=None):
        """Configure the repetition factors and the detection thresholds.

        Args:
            ch_factor: Number of copies along the channel axis (must divide 4).
            hw_factor: Number of copies along each spatial axis (must divide 64).
            fpr: Target false-positive rate used to derive ``tau_onebit`` and
                ``tau_bits``.
            user_number: Multiplier applied to the per-threshold rate when
                deriving ``tau_bits``.
            device: Where generated tensors are placed.  ``None`` (default)
                selects CUDA when available and falls back to CPU otherwise.
        """
        self.ch = ch_factor
        self.hw = hw_factor
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.key = None
        self.watermark = None
        self.latentlength = 4 * 64 * 64
        copies = self.ch * self.hw * self.hw
        self.marklength = self.latentlength // copies

        # Majority vote: a bit is 1 when strictly more than half of its copies
        # are 1.  With a single copy this must be 0 (not 1), otherwise every
        # vote in {0, 1} falls at or below the threshold and decodes as 0.
        self.threshold = copies // 2

        self.tp_onebit_count = 0
        self.tp_bits_count = 0
        # Either value is None when no threshold can reach the requested rate.
        self.tau_onebit, self.tau_bits = self._detection_thresholds(
            fpr, user_number
        )

    def _detection_thresholds(self, fpr, user_number):
        """Return ``(tau_onebit, tau_bits)``, each a fraction in [0, 1) or None."""
        n = self.marklength
        idx = np.arange(n)
        # betainc(i + 1, n - i, 0.5) is the upper tail P(X > i) of a
        # Binomial(n, 0.5) variable; evaluate it once for every candidate i.
        tail = betainc(idx + 1, n - idx, 0.5)

        def first_fraction(rates):
            hits = np.flatnonzero(rates <= fpr)
            return int(hits[0]) / n if hits.size else None

        return first_fraction(tail), first_fraction(tail * user_number)

    def truncSampling(self, message):
        """Sample a latent whose element signs encode ``message``.

        Args:
            message: Flat sequence of at least ``latentlength`` symbols, each 0
                or 1.  Symbol 0 draws from the standard normal truncated to
                (-inf, 0); symbol 1 draws from (0, +inf).  Entries beyond
                ``latentlength`` are ignored.

        Returns:
            A float16 tensor of shape ``[1, 4, 64, 64]`` on ``self.device``.

        Raises:
            ValueError: If ``message`` is too short or holds other symbols.
        """
        denominator = 2.0
        buckets = int(denominator)
        # Bucket edges at the standard normal quantiles j / denominator.
        ppf = [norm.ppf(j / denominator) for j in range(buckets + 1)]

        symbols = np.asarray(message).reshape(-1)[: self.latentlength].astype(int)
        if symbols.size < self.latentlength:
            raise ValueError(
                f"message needs {self.latentlength} symbols, got {symbols.size}"
            )
        if symbols.min() < 0 or symbols.max() >= buckets:
            raise ValueError(f"message symbols must lie in [0, {buckets})")

        z = np.zeros(self.latentlength)
        # One sized draw per bucket instead of one scipy call per element.
        for value in range(buckets):
            mask = symbols == value
            count = int(mask.sum())
            if count:
                z[mask] = truncnorm.rvs(ppf[value], ppf[value + 1], size=count)

        z = torch.from_numpy(z).reshape(1, 4, 64, 64).half()
        return z.to(self.device)

    def create_watermark_and_return_w(self):
        """Draw a key and a watermark, and sample the watermarked latent.

        The key is generated under a fixed seed (42), so it is identical on
        every call; the global torch RNG state is restored afterwards so the
        watermark itself stays random.

        Returns:
            Tuple ``(w, key, watermark)``: the float16 latent of shape
            ``[1, 4, 64, 64]``, the binary key of the same shape, and the
            un-tiled binary watermark.
        """
        rng_state = torch.get_rng_state()
        torch.manual_seed(42)
        self.key = torch.randint(0, 2, [1, 4, 64, 64]).to(self.device)
        torch.set_rng_state(rng_state)

        self.watermark = torch.randint(
            0, 2, [1, 4 // self.ch, 64 // self.hw, 64 // self.hw]
        ).to(self.device)
        sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
        m = ((sd + self.key) % 2).flatten().cpu().numpy()
        w = self.truncSampling(m)
        return w, self.key, self.watermark

    def diffusion_inverse(self, watermark_sd):
        """Collapse the tiled copies of a binary tensor by majority vote.

        Args:
            watermark_sd: Binary tensor with 4 channels and 64x64 spatial size.
                The vote threshold assumes a batch size of 1, because all
                copies (and batch entries) are summed together.

        Returns:
            Tensor of shape ``[4 // ch, 64 // hw, 64 // hw]`` holding 0/1 in
            the dtype produced by the vote sum.
        """
        ch_stride = 4 // self.ch
        hw_stride = 64 // self.hw
        ch_list = [ch_stride] * self.ch
        hw_list = [hw_stride] * self.hw
        # Stack every copy along dim 0 so that one sum counts the votes.
        split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0)
        return (vote > self.threshold).to(vote.dtype)

    def sequence_binary_watermark(self, watermark):
        """Flatten ``watermark`` and join its elements into one string."""
        ls = watermark.view(-1).tolist()
        return ''.join(str(i) for i in ls)

    def eval_watermark(self, reversed_m):
        """Extract the watermark from a latent and score it against references.

        The key is read from ``key.pt`` and the reference watermarks from
        ``watermark.pt`` in the current working directory, not from
        ``self.key`` / ``self.watermark``.  Both are mapped onto the device of
        ``reversed_m``.

        Args:
            reversed_m: Recovered latent tensor; only the sign of each element
                is used.

        Returns:
            Tuple ``(bit_string, correct)`` where ``correct`` is the best
            bit-accuracy over the entries of the loaded watermark tensor.

        Side effects:
            Prints the extracted bit string and increments the detection
            counters when ``correct`` reaches the corresponding threshold.
        """
        # NOTE(security): torch.load unpickles its input; only load key.pt and
        # watermark.pt from trusted locations.
        key = torch.load('key.pt', map_location=reversed_m.device)
        reversed_m = (reversed_m > 0).int()
        reversed_sd = (reversed_m + key) % 2
        reversed_watermark = self.diffusion_inverse(reversed_sd)
        sequence = self.sequence_binary_watermark(reversed_watermark)
        print(f"The extracted watermark is {sequence}")

        watermark = torch.load('watermark.pt', map_location=reversed_watermark.device)
        correct = max(
            (reversed_watermark == candidate).float().mean().item()
            for candidate in watermark
        )

        # A threshold of None means the requested false-positive rate is
        # unattainable, so such a configuration never reports a detection.
        if self.tau_onebit is not None and correct >= self.tau_onebit:
            self.tp_onebit_count += 1
        if self.tau_bits is not None and correct >= self.tau_bits:
            self.tp_bits_count += 1
        return sequence, correct

    def get_tpr(self):
        """Return the detection counters ``(tp_onebit_count, tp_bits_count)``."""
        return self.tp_onebit_count, self.tp_bits_count