# Hugging Face upload metadata (scraped page header, not code):
# uploaded by TruongScotl — "Upload 7 files", commit f961e67 (verified)
import torch
from scipy.stats import norm,truncnorm
from functools import reduce
from scipy.special import betainc
import numpy as np
class Gaussian_Shading:
    """Embed a repeated bit pattern in the signs of a (1, 4, 64, 64) latent.

    The watermark bits are tiled ``ch_factor`` times along the channel axis and
    ``hw_factor`` times along each spatial axis, XOR-ed with a fixed key, and
    encoded by sampling each latent entry from the negative (bit 0) or positive
    (bit 1) half of a standard normal. Extraction thresholds a recovered latent
    at zero, removes the key, and takes a majority vote over the tiled copies.

    Args:
        ch_factor: channel repetition factor. Must divide 4, otherwise
            ``diffusion_inverse`` cannot split the channel axis evenly.
        hw_factor: spatial repetition factor. Must divide 64 for the same reason.
        fpr: target false-positive rate used to choose the detection thresholds.
        user_number: multiplier applied to the per-image FPR bound when choosing
            the multi-user ("bits") threshold.
    """

    def __init__(self, ch_factor, hw_factor, fpr, user_number):
        self.ch = ch_factor
        self.hw = hw_factor
        self.key = None
        self.watermark = None
        self.latentlength = 4 * 64 * 64
        # Number of distinct watermark bits; each one is repeated ch * hw * hw times.
        self.marklength = self.latentlength // (self.ch * self.hw * self.hw)
        # Majority-vote threshold: a bit decodes to 1 only when strictly more than
        # half of its copies are 1. For a single copy this is 0, so a lone 1 is
        # kept (a threshold of 1 would force every decoded bit to 0).
        self.threshold = self.ch * self.hw * self.hw // 2
        self.tp_onebit_count = 0
        self.tp_bits_count = 0
        self.tau_onebit = None
        self.tau_bits = None
        # betainc(i + 1, n - i, 0.5) equals P(X >= i + 1) for X ~ Binomial(n, 0.5):
        # the chance that a random bit string matches at least i + 1 watermark bits.
        # NOTE(review): tau is stored as i / n, and eval_watermark accepts
        # correct >= tau (i.e. X >= i), whose false-positive rate is P(X >= i)
        # rather than P(X >= i + 1) — confirm this off-by-one is intended.
        # If no i meets the bound, tau stays None and eval_watermark will raise.
        for i in range(self.marklength):
            fpr_onebit = betainc(i + 1, self.marklength - i, 0.5)
            fpr_bits = fpr_onebit * user_number
            if fpr_onebit <= fpr and self.tau_onebit is None:
                self.tau_onebit = i / self.marklength
            if fpr_bits <= fpr and self.tau_bits is None:
                self.tau_bits = i / self.marklength
            if self.tau_onebit is not None and self.tau_bits is not None:
                break  # both thresholds found; later i cannot change them

    def truncSampling(self, message):
        """Sample a latent whose signs encode ``message``.

        Entry ``i`` is drawn from a standard normal truncated to ``(-inf, 0]``
        when ``message[i]`` is 0 and to ``[0, inf)`` when it is 1.

        Args:
            message: sequence of at least ``latentlength`` 0/1 values; only the
                first ``latentlength`` are used.

        Returns:
            A ``(1, 4, 64, 64)`` float16 tensor moved to CUDA.
        """
        # Quantiles splitting N(0, 1) into two equal-mass halves: [-inf, 0, inf].
        ppf = norm.ppf(np.array([0.0, 0.5, 1.0]))
        bits = np.asarray(message[: self.latentlength]).astype(np.int64)
        # One broadcast draw instead of latentlength scalar truncnorm calls.
        z = truncnorm.rvs(ppf[bits], ppf[bits + 1])
        z = torch.from_numpy(np.asarray(z, dtype=np.float64)).reshape(1, 4, 64, 64).half()
        return z.cuda()

    def create_watermark_and_return_w(self):
        """Generate a key and a random watermark, then sample a watermarked latent.

        The key is drawn with a fixed seed (42), and the global CPU RNG state is
        saved and restored around that draw, so the key is reproducible without
        disturbing other random draws. The watermark itself uses the caller's
        RNG state.

        This method does not write ``key.pt`` / ``watermark.pt``; ``eval_watermark``
        reads those files, so presumably the caller saves them — verify.

        Returns:
            Tuple ``(w, key, watermark)``:

            * ``w``: sampled latent, shape ``(1, 4, 64, 64)``, float16, on CUDA.
            * ``key``: shape ``(1, 4, 64, 64)``, on CUDA.
            * ``watermark``: shape ``(1, 4 // ch, 64 // hw, 64 // hw)``, on CUDA.
        """
        rng_state = torch.get_rng_state()
        torch.manual_seed(42)
        self.key = torch.randint(0, 2, [1, 4, 64, 64]).cuda()
        torch.set_rng_state(rng_state)
        self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.hw, 64 // self.hw]).cuda()
        # Tile the watermark to full latent size, then XOR it with the key.
        sd = self.watermark.repeat(1, self.ch, self.hw, self.hw)
        m = ((sd + self.key) % 2).flatten().cpu().numpy()
        w = self.truncSampling(m)
        return w, self.key, self.watermark

    def diffusion_inverse(self, watermark_sd):
        """Collapse a tiled 0/1 tensor back to one watermark by majority vote.

        Args:
            watermark_sd: 0/1 tensor of shape ``(1, 4, 64, 64)``, tiled as in
                ``create_watermark_and_return_w``.

        Returns:
            0/1 tensor of shape ``(4 // ch, 64 // hw, 64 // hw)``.
        """
        ch_stride = 4 // self.ch
        hw_stride = 64 // self.hw
        ch_list = [ch_stride] * self.ch
        hw_list = [hw_stride] * self.hw
        # Stack every tiled copy along dim 0, so the result has shape
        # (ch * hw * hw, 4 // ch, 64 // hw, 64 // hw).
        split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0).clone()
        vote[vote <= self.threshold] = 0
        vote[vote > self.threshold] = 1
        return vote

    def sequence_binary_watermark(self, watermark):
        """Return the flattened elements of ``watermark`` joined into one string, e.g. ``"1001"``."""
        ls = watermark.view(-1).tolist()
        sequence = ''.join(str(i) for i in ls)
        return sequence

    def eval_watermark(self, reversed_m):
        """Decode a recovered latent and update the true-positive counters.

        The key and the reference watermark are loaded from ``key.pt`` and
        ``watermark.pt`` in the working directory (``self.key`` is not used).

        Args:
            reversed_m: recovered latent whose signs carry the key-XOR-ed bits;
                presumably shape ``(1, 4, 64, 64)`` like the sampled latent — confirm.

        Returns:
            Tuple ``(bit_string, correct)``. ``bit_string`` is the extracted
            watermark as a string. ``correct`` is the highest per-bit accuracy
            against the entries of the first dimension of ``watermark.pt``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        key = torch.load('key.pt')
        reversed_m = (reversed_m > 0).int()
        reversed_sd = (reversed_m + key) % 2  # undo the key XOR
        reversed_watermark = self.diffusion_inverse(reversed_sd)
        print(f"The extracted watermark is {self.sequence_binary_watermark(reversed_watermark)}")
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        watermark = torch.load('watermark.pt')
        # Compare against every entry along the first dimension and keep the best match.
        ls_accurate = [(reversed_watermark == i).float().mean().item() for i in watermark]
        correct = max(ls_accurate)
        if correct >= self.tau_onebit:
            self.tp_onebit_count = self.tp_onebit_count + 1
        if correct >= self.tau_bits:
            self.tp_bits_count = self.tp_bits_count + 1
        return self.sequence_binary_watermark(reversed_watermark), correct

    def get_tpr(self):
        """Return the raw detection counts ``(tp_onebit_count, tp_bits_count)``.

        These are counts, not rates.
        """
        return self.tp_onebit_count, self.tp_bits_count