# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from math import sqrt

import scipy.stats
import torch
from torch import Tensor
from tokenizers import Tokenizer
from transformers import LogitsProcessor

from nltk.util import ngrams

from normalizers import normalization_strategy_lookup


class WatermarkBase:
    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "markov_1",  # mostly unused/always default
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        select_green_tokens: bool = True,
    ):
        # watermarking parameters
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.gamma = gamma
        self.delta = delta
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        # can optionally override the seeding scheme,
        # but uses the instance attr by default
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        if seeding_scheme == "markov_1":
            assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token)
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        # seed the rng using the previous tokens/prefix
        # according to the seeding_scheme
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:  # directly
            greenlist_ids = vocab_permutation[:greenlist_size]  # new
        else:  # select green via red
            greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size):]  # legacy behavior
        return greenlist_ids
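

# Worked example of the greenlist partition (illustrative numbers only, not
# tied to any particular tokenizer): with vocab_size = 50,000 and gamma = 0.25,
# each call to _get_greenlist_ids reseeds the rng from the last prefix token,
# draws a fresh permutation of the vocabulary, and takes the first
# int(50_000 * 0.25) = 12,500 ids as the greenlist; the remaining 37,500 ids
# form the implicit redlist.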


class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    # FIXME maybe make this explicit instead of args/kwargs
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        # TODO let's see if we can lose this loop
        green_tokens_mask = torch.zeros_like(scores)
        for b_idx in range(len(greenlist_token_ids)):
            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
        final_mask = green_tokens_mask.bool()
        return final_mask

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # this is lazy to allow us to colocate on the watermarked model's device
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        # NOTE, it would be nice to get rid of this batch loop, but currently,
        # the seed and partition operations are not tensor/vectorized, thus
        # each sequence in the batch needs to be treated separately.
        batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]

        for b_idx in range(input_ids.shape[0]):
            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
            batched_greenlist_ids[b_idx] = greenlist_ids

        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)

        scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
        return scores
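

# Minimal generation-side usage sketch (a sketch, not part of the original
# module): assumes a Hugging Face causal LM; the checkpoint name
# "facebook/opt-1.3b" is illustrative only.
#
#   from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
#   watermark_processor = WatermarkLogitsProcessor(
#       vocab=list(tokenizer.get_vocab().values()), gamma=0.25, delta=2.0
#   )
#   inputs = tokenizer("The meaning of life is", return_tensors="pt")
#   output_ids = model.generate(
#       **inputs,
#       max_new_tokens=100,
#       logits_processor=LogitsProcessorList([watermark_processor]),
#   )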


class WatermarkDetector(WatermarkBase):
    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] = ["unicode"],  # or also: ["unicode", "homoglyphs", "truecase"]
        ignore_repeated_bigrams: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # also configure the metrics returned/preprocessing options
        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "markov_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

        self.normalizers = []
        for normalization_strategy in normalizers:
            self.normalizers.append(normalization_strategy_lookup(normalization_strategy))

        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "markov_1", "No repeated bigram credit variant assumes the single token seeding scheme."

    def _compute_z_score(self, observed_count, T):
        # count refers to number of green tokens, T is total number of tokens
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        z = numer / denom
        return z

    def _compute_p_value(self, z):
        p_value = scipy.stats.norm.sf(z)
        return p_value
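
    # Worked example for the one-proportion z-test above (illustrative numbers):
    # with gamma = 0.5 and T = 200 scored tokens, observing 130 green tokens gives
    #   z = (130 - 0.5 * 200) / sqrt(200 * 0.5 * 0.5) = 30 / sqrt(50) ≈ 4.24,
    # for a one-sided p-value of roughly 1.1e-5, comfortably past the default
    # z_threshold of 4.0.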

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        if self.ignore_repeated_bigrams:
            # Method that only counts a green/red hit once per unique bigram.
            # The new total number of tokens scored (T) becomes the number of unique bigrams.
            # We iterate over all unique token bigrams in the input, computing the greenlist
            # induced by the first token in each, and then checking whether the second
            # token falls in that greenlist.
            assert not return_green_token_mask, "Can't return the green/red mask when ignoring repeats."
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for idx, bigram in enumerate(freq.keys()):
                prefix = torch.tensor([bigram[0]], device=self.device)  # expects a 1-d prefix tensor on the randperm device
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = bigram[1] in greenlist_ids
            green_token_count = sum(bigram_table.values())
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError(
                    (f"Must have at least 1 token to score after "
                     f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme.")
                )
            # Standard method.
            # Since we generally need at least 1 token (for the simplest scheme)
            # we start the iteration over the token sequence with a minimum
            # num tokens as the first prefix for the seeding scheme,
            # and at each step, compute the greenlist induced by the
            # current prefix and check if the current token falls in the greenlist.
            green_token_count, green_token_mask = 0, []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)

        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=num_tokens_scored))
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=green_token_count))
        if return_z_score:
            score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
        if return_p_value:
            z_score = score_dict.get("z_score")
            if z_score is None:
                z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            score_dict.update(dict(p_value=self._compute_p_value(z_score)))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=green_token_mask))

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            kwargs["return_p_value"] = True  # to return the "confidence" := 1 - p of positive detections

        # run optional normalizers on text (only applicable when raw text was passed)
        if text is not None:
            for normalizer in self.normalizers:
                text = normalizer(text)
            if len(self.normalizers) > 0:
                print(f"Text after normalization:\n\n{text}\n")

        if tokenized_text is None:
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            # try to remove the bos_tok at beginning if it's there
            if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
                tokenized_text = tokenized_text[1:]

        # call score method
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        # if passed return_prediction then perform the hypothesis test and return the outcome
        if return_prediction:
            z_threshold = z_threshold if z_threshold else self.z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        return output_dict
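

# Detection-side usage sketch (again a sketch, continuing the hypothetical
# generation example above): the detector must be constructed with the same
# vocab, gamma, and seeding_scheme used at generation time; delta is accepted
# via kwargs but is not used during detection.
#
#   output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
#   detector = WatermarkDetector(
#       vocab=list(tokenizer.get_vocab().values()),
#       gamma=0.25,
#       seeding_scheme="markov_1",
#       device=model.device,
#       tokenizer=tokenizer,
#       z_threshold=4.0,
#       normalizers=[],
#   )
#   score_dict = detector.detect(output_text)
#   # score_dict includes num_tokens_scored, num_green_tokens, green_fraction,
#   # z_score, p_value, prediction, and, for positive detections, confidence.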