pierrefdz committed
Commit 8e6cbe9 · 1 Parent(s): 1fdb165

initial commit

Files changed (47)
  1. .dockerignore +20 -0
  2. .gitattributes +6 -1
  3. .gitignore +2 -0
  4. Dockerfile +24 -0
  5. README.md +8 -6
  6. data/prompts.json +52 -0
  7. requirements.txt +3 -0
  8. run.py +12 -0
  9. sandbox.ipynb +81 -0
  10. tests/__init__.py +0 -0
  11. wm_interactive/__init__.py +0 -0
  12. wm_interactive/core/__init__.py +0 -0
  13. wm_interactive/core/detector.py +263 -0
  14. wm_interactive/core/generator.py +211 -0
  15. wm_interactive/core/hashing.py +13 -0
  16. wm_interactive/core/main.py +256 -0
  17. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/0ad5ecc2035b7031b88afb544ee95e2d49baa484.lock +0 -0
  18. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/36293b6099200eb8aeb55ae2c01bca2ba46d80d0.lock +0 -0
  19. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/44719d2e365acac0637fd25a3acf46494ca45940.lock +0 -0
  20. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c.lock +0 -0
  21. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/69503b13f727ba3812b6803e97442a6de05ef5eb.lock +0 -0
  22. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/8c7b22013909450429303ed10be4398bd63f5457.lock +0 -0
  23. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa.lock +0 -0
  24. wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/f922b1797f0c88e71addc8393787831f2477a4bd.lock +0 -0
  25. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/.no_exist/e2c3f7557efbdec707ae3a336371d169783f1da1/added_tokens.json +0 -0
  26. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/0ad5ecc2035b7031b88afb544ee95e2d49baa484 +3 -0
  27. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/36293b6099200eb8aeb55ae2c01bca2ba46d80d0 +3 -0
  28. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/44719d2e365acac0637fd25a3acf46494ca45940 +3 -0
  29. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c +3 -0
  30. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/69503b13f727ba3812b6803e97442a6de05ef5eb +3 -0
  31. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/8c7b22013909450429303ed10be4398bd63f5457 +3 -0
  32. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa +3 -0
  33. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/f922b1797f0c88e71addc8393787831f2477a4bd +3 -0
  34. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/refs/main +3 -0
  35. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/config.json +3 -0
  36. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/generation_config.json +3 -0
  37. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/merges.txt +3 -0
  38. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/model.safetensors +3 -0
  39. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/special_tokens_map.json +3 -0
  40. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer.json +3 -0
  41. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer_config.json +3 -0
  42. wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/vocab.json +3 -0
  43. wm_interactive/static/styles.css +357 -0
  44. wm_interactive/templates/index.html +459 -0
  45. wm_interactive/web/__init__.py +0 -0
  46. wm_interactive/web/app.py +241 -0
  47. wm_interactive/web/utils.py +83 -0
.dockerignore ADDED
@@ -0,0 +1,20 @@
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env
+ pip-log.txt
+ pip-delete-this-directory.txt
+ .tox
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.log
+ .pytest_cache
+ .env
+ .venv
+ .DS_Store
.gitattributes CHANGED
@@ -1,6 +1,8 @@
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.txt filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +35,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ static/ia_gen_droits_auteur.pdf filter=lfs diff=lfs merge=lfs -text
+ wm_interactive/static/hf_cache/** filter=lfs diff=lfs merge=lfs -text
+ wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.9-slim
+
+ WORKDIR /app
+
+ # Copy only the requirements first to leverage Docker cache
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application
+ COPY wm_interactive/ ./wm_interactive/
+ COPY run.py .
+
+ # Create necessary directories
+ RUN mkdir -p wm_interactive/static/hf_cache
+
+ # Set environment variables
+ ENV PYTHONPATH=/app
+ ENV FLASK_APP=run.py
+
+ # Expose the port the app runs on
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["python", "run.py"]
README.md CHANGED
@@ -1,12 +1,14 @@
  ---
- title: Interactive Llm Wm
- emoji: 😻
- colorFrom: green
- colorTo: blue
+ title: Interactive Text Watermark Detection
+ emoji: 📝
+ colorFrom: blue
+ colorTo: pink
  sdk: docker
  pinned: false
  license: apache-2.0
- short_description: An interactive demo for LLM watermarking
+ short_description: An interactive demo for detection of text watermarks
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Interactive Text Watermark Detection
+
+ This repository contains the code for an interactive demo for detecting watermarked text generated by LLMs (Large Language Models).
data/prompts.json ADDED
@@ -0,0 +1,52 @@
+ [
2
+ {
3
+ "instruction": "Write a short story about a robot learning to paint.",
4
+ "input": "",
5
+ "output": ""
6
+ },
7
+ {
8
+ "instruction": "Explain how photosynthesis works in simple terms.",
9
+ "input": "",
10
+ "output": ""
11
+ },
12
+ {
13
+ "instruction": "Write a recipe for chocolate chip cookies.",
14
+ "input": "",
15
+ "output": ""
16
+ },
17
+ {
18
+ "instruction": "Describe the main differences between classical and quantum computing.",
19
+ "input": "",
20
+ "output": ""
21
+ },
22
+ {
23
+ "instruction": "Write a haiku about the changing seasons.",
24
+ "input": "",
25
+ "output": ""
26
+ },
27
+ {
28
+ "instruction": "Explain why the sky appears blue during the day.",
29
+ "input": "",
30
+ "output": ""
31
+ },
32
+ {
33
+ "instruction": "Write a short dialogue between two friends discussing their favorite books.",
34
+ "input": "",
35
+ "output": ""
36
+ },
37
+ {
38
+ "instruction": "Describe three ways to reduce your carbon footprint.",
39
+ "input": "",
40
+ "output": ""
41
+ },
42
+ {
43
+ "instruction": "Write a brief explanation of how the internet works.",
44
+ "input": "",
45
+ "output": ""
46
+ },
47
+ {
48
+ "instruction": "Create a short motivational speech about perseverance.",
49
+ "input": "",
50
+ "output": ""
51
+ }
52
+ ]
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59ef79cf6a8a998de982ccc64e93ca2b7602aa989b38b2c264d385acf728ef80
+ size 75
run.py ADDED
@@ -0,0 +1,12 @@
+ """
+ Main entry point for the watermark detection application.
+ Run with: python run.py
+
+ docker build -t wm-interactive .
+ docker run -p 7860:7860 wm-interactive
+ """
+
+ from wm_interactive.web.app import app
+
+ if __name__ == "__main__":
+ app.run(host='0.0.0.0', port=7860)
sandbox.ipynb ADDED
@@ -0,0 +1,81 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": []
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 2,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "from transformers import AutoTokenizer, LlamaForCausalLM\n",
17
+ "\n",
18
+ "model_id = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
19
+ "tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=\"wm_detector/static/hf_cache\")"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 6,
25
+ "metadata": {},
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "[4438, 311, 1304, 264, 19692]\n",
32
+ "['How', 'Ġto', 'Ġmake', 'Ġa', 'Ġcake']\n",
33
+ "['How', ' to', ' make', ' a', ' cake']\n"
34
+ ]
35
+ }
36
+ ],
37
+ "source": [
38
+ "def tokenize_text(text):\n",
39
+ " return tokenizer.encode(text, add_special_tokens=False)\n",
40
+ "\n",
41
+ "text = \"How to make a cake\"\n",
42
+ "token_ids = tokenize_text(text)\n",
43
+ "tokens = tokenizer.convert_ids_to_tokens(token_ids)\n",
44
+ "token_strs = [tokenizer.convert_tokens_to_string([token]) for token in tokens]\n",
45
+ "decoded = tokenizer.decode(tokenize_text(text))\n",
46
+ "\n",
47
+ "print(token_ids)\n",
48
+ "print(tokens)\n",
49
+ "print(token_strs)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": []
58
+ }
59
+ ],
60
+ "metadata": {
61
+ "kernelspec": {
62
+ "display_name": "base",
63
+ "language": "python",
64
+ "name": "python3"
65
+ },
66
+ "language_info": {
67
+ "codemirror_mode": {
68
+ "name": "ipython",
69
+ "version": 3
70
+ },
71
+ "file_extension": ".py",
72
+ "mimetype": "text/x-python",
73
+ "name": "python",
74
+ "nbconvert_exporter": "python",
75
+ "pygments_lexer": "ipython3",
76
+ "version": "3.12.2"
77
+ }
78
+ },
79
+ "nbformat": 4,
80
+ "nbformat_minor": 2
81
+ }
tests/__init__.py ADDED
File without changes
wm_interactive/__init__.py ADDED
File without changes
wm_interactive/core/__init__.py ADDED
File without changes
wm_interactive/core/detector.py ADDED
@@ -0,0 +1,263 @@
1
+
2
+ import numpy as np
3
+ from scipy import special
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+
8
+ from .hashing import get_seed_rng
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ class WmDetector():
13
+ def __init__(self,
14
+ tokenizer: AutoTokenizer,
15
+ ngram: int = 1,
16
+ seed: int = 0
17
+ ):
18
+ # model config
19
+ self.tokenizer = tokenizer
20
+ self.vocab_size = self.tokenizer.vocab_size
21
+ # watermark config
22
+ self.ngram = ngram
23
+ self.seed = seed
24
+ self.rng = torch.Generator()
25
+ self.rng.manual_seed(self.seed)
26
+
27
+ def aggregate_scores(
28
+ self,
29
+ scores: list[np.array],
30
+ aggregation: str = 'mean'
31
+ ) -> float:
32
+ """Aggregate scores along a text."""
33
+ if aggregation == 'sum':
34
+ return scores.sum(axis=0)
35
+ elif aggregation == 'mean':
36
+ return scores.mean(axis=0)
37
+ elif aggregation == 'max':
38
+ return scores.max(axis=0)
39
+ else:
40
+ raise ValueError(f'Aggregation {aggregation} not supported.')
41
+
42
+ def get_details(
43
+ self,
44
+ text: str,
45
+ scoring_method: str="v2",
46
+ ntoks_max: int = None,
47
+ ) -> list[dict]:
48
+ """
49
+ Get score increment for each token in text.
50
+ Args:
51
+ text: input text
52
+ scoring_method:
53
+ 'none': score all ngrams
54
+ 'v1': only score tokens for which wm window is unique
55
+ 'v2': only score tokens for which the {wm window + tok} pair is unique
56
+ ntoks_max: maximum number of tokens
57
+ Output:
58
+ token_details: list of dicts containing token info and scores
59
+ """
60
+ tokens_id = self.tokenizer.encode(text, add_special_tokens=False)
61
+ if ntoks_max is not None:
62
+ tokens_id = tokens_id[:ntoks_max]
63
+
64
+ total_len = len(tokens_id)
65
+ token_details = []
66
+ seen_grams = set()
67
+
68
+ # Add initial tokens that can't be scored (not enough context)
69
+ num_start = min(self.ngram, total_len)
70
+ for i in range(num_start):
71
+ token_details.append({
72
+ 'token_id': tokens_id[i],
73
+ 'is_scored': False,
74
+ 'score': float('nan'),
75
+ 'token_text': self.tokenizer.decode([tokens_id[i]])
76
+ })
77
+
78
+ # Score remaining tokens
79
+ for cur_pos in range(self.ngram, total_len):
80
+ ngram_tokens = tokens_id[cur_pos-self.ngram:cur_pos]
81
+ is_scored = True
82
+
83
+ if scoring_method == 'v1':
84
+ tup_for_unique = tuple(ngram_tokens)
85
+ is_scored = tup_for_unique not in seen_grams
86
+ if is_scored:
87
+ seen_grams.add(tup_for_unique)
88
+ elif scoring_method == 'v2':
89
+ tup_for_unique = tuple(ngram_tokens + [tokens_id[cur_pos]])
90
+ is_scored = tup_for_unique not in seen_grams
91
+ if is_scored:
92
+ seen_grams.add(tup_for_unique)
93
+
94
+ score = float('nan')
95
+ if is_scored:
96
+ score = self.score_tok(ngram_tokens, tokens_id[cur_pos])
97
+ score = float(score)
98
+
99
+ token_details.append({
100
+ 'token_id': tokens_id[cur_pos],
101
+ 'is_scored': is_scored,
102
+ 'score': score,
103
+ 'token_text': self.tokenizer.decode([tokens_id[cur_pos]])
104
+ })
105
+
106
+ return token_details
107
+
108
+ def get_pvalues_by_tok(
109
+ self,
110
+ token_details: list[dict]
111
+ ) -> tuple[list[float], dict]:
112
+ """
113
+ Get p-value for each token so far.
114
+ Args:
115
+ token_details: list of dicts containing token info and scores from get_details()
116
+ Returns:
117
+ tuple containing:
118
+ - list of p-values, with nan for unscored tokens
119
+ - dict with auxiliary information:
120
+ - final_score: final running score
121
+ - ntoks_scored: final number of scored tokens
122
+ - final_pvalue: last non-nan pvalue (0.5 if none available)
123
+ """
124
+ pvalues = []
125
+ running_score = 0
126
+ ntoks_scored = 0
127
+ eps = 1e-10 # small constant to avoid numerical issues
128
+ last_valid_pvalue = 0.5 # default value if no tokens are scored
129
+
130
+ for token in token_details:
131
+ if token['is_scored']:
132
+ running_score += token['score']
133
+ ntoks_scored += 1
134
+ pvalue = self.get_pvalue(running_score, ntoks_scored, eps)
135
+ last_valid_pvalue = pvalue
136
+ pvalues.append(pvalue)
137
+ else:
138
+ pvalues.append(float('nan'))
139
+
140
+ aux_info = {
141
+ 'final_score': running_score,
142
+ 'ntoks_scored': ntoks_scored,
143
+ 'final_pvalue': last_valid_pvalue
144
+ }
145
+
146
+ return pvalues, aux_info
147
+
148
+ def score_tok(self, ngram_tokens: list[int], token_id: int):
149
+ """ for each token in the text, compute the score increment """
150
+ raise NotImplementedError
151
+
152
+ def get_pvalue(self, score: float, ntoks: int, eps: float):
153
+ """ compute the p-value for a couple of score and number of tokens """
154
+ raise NotImplementedError
155
+
156
+
157
+ class MarylandDetector(WmDetector):
158
+
159
+ def __init__(self,
160
+ tokenizer: AutoTokenizer,
161
+ ngram: int = 1,
162
+ seed: int = 0,
163
+ gamma: float = 0.5,
164
+ delta: float = 1.0,
165
+ **kwargs):
166
+ super().__init__(tokenizer, ngram, seed, **kwargs)
167
+ self.gamma = gamma
168
+ self.delta = delta
169
+
170
+ def score_tok(self, ngram_tokens, token_id):
171
+ """
172
+ score_t = 1 if token_id in greenlist else 0
173
+ """
174
+ seed = get_seed_rng(self.seed, ngram_tokens)
175
+ self.rng.manual_seed(seed)
176
+ scores = torch.zeros(self.vocab_size)
177
+ vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
178
+ greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n toks in the greenlist
179
+ scores[greenlist] = 1
180
+ return scores[token_id]
181
+
182
+ def get_pvalue(self, score: int, ntoks: int, eps: float):
183
+ """ from cdf of a binomial distribution """
184
+ pvalue = special.betainc(score, 1 + ntoks - score, self.gamma)
185
+ return max(pvalue, eps)
186
+
187
+ class MarylandDetectorZ(WmDetector):
188
+
189
+ def __init__(self,
190
+ tokenizer: AutoTokenizer,
191
+ ngram: int = 1,
192
+ seed: int = 0,
193
+ gamma: float = 0.5,
194
+ delta: float = 1.0,
195
+ **kwargs):
196
+ super().__init__(tokenizer, ngram, seed, **kwargs)
197
+ self.gamma = gamma
198
+ self.delta = delta
199
+
200
+ def score_tok(self, ngram_tokens, token_id):
201
+ """ same as MarylandDetector but using zscore """
202
+ seed = get_seed_rng(self.seed, ngram_tokens)
203
+ self.rng.manual_seed(seed)
204
+ scores = torch.zeros(self.vocab_size)
205
+ vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
206
+ greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n
207
+ scores[greenlist] = 1
208
+ return scores[token_id]
209
+
210
+ def get_pvalue(self, score: int, ntoks: int, eps: float):
211
+ """ from cdf of a normal distribution """
212
+ zscore = (score - self.gamma * ntoks) / np.sqrt(self.gamma * (1 - self.gamma) * ntoks)
213
+ pvalue = 0.5 * special.erfc(zscore / np.sqrt(2))
214
+ return max(pvalue, eps)
215
+
216
+ class OpenaiDetector(WmDetector):
217
+
218
+ def __init__(self,
219
+ tokenizer: AutoTokenizer,
220
+ ngram: int = 1,
221
+ seed: int = 0,
222
+ **kwargs):
223
+ super().__init__(tokenizer, ngram, seed, **kwargs)
224
+
225
+ def score_tok(self, ngram_tokens, token_id):
226
+ """
227
+ score_t = -log(1 - rt[token_id])
228
+ """
229
+ seed = get_seed_rng(self.seed, ngram_tokens)
230
+ self.rng.manual_seed(seed)
231
+ rs = torch.rand(self.vocab_size, generator=self.rng) # n
232
+ scores = -(1 - rs).log()
233
+ return scores[token_id]
234
+
235
+ def get_pvalue(self, score: float, ntoks: int, eps: float):
236
+ """ from cdf of a gamma distribution """
237
+ pvalue = special.gammaincc(ntoks, score)
238
+ return max(pvalue, eps)
239
+
240
+ class OpenaiDetectorZ(WmDetector):
241
+
242
+ def __init__(self,
243
+ tokenizer: AutoTokenizer,
244
+ ngram: int = 1,
245
+ seed: int = 0,
246
+ **kwargs):
247
+ super().__init__(tokenizer, ngram, seed, **kwargs)
248
+
249
+ def score_tok(self, ngram_tokens, token_id):
250
+ """ same as OpenaiDetector but using zscore """
251
+ seed = get_seed_rng(self.seed, ngram_tokens)
252
+ self.rng.manual_seed(seed)
253
+ rs = torch.rand(self.vocab_size, generator=self.rng) # n
254
+ scores = -(1 - rs).log()
255
+ return scores[token_id]
256
+
257
+ def get_pvalue(self, score: float, ntoks: int, eps: float):
258
+ """ from cdf of a normal distribution """
259
+ mu0 = 1
260
+ sigma0 = np.pi / np.sqrt(6)
261
+ zscore = (score/ntoks - mu0) / (sigma0 / np.sqrt(ntoks))
262
+ pvalue = 0.5 * special.erfc(zscore / np.sqrt(2))
263
+ return max(pvalue, eps)
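A minimal usage sketch for the detector classes above, mirroring how wm_interactive/core/main.py wires them together; the model id, cache directory and hyperparameters are illustrative and assume the SmolLM2 tokenizer can be loaded:

from transformers import AutoTokenizer
from wm_interactive.core.detector import MarylandDetector

tokenizer = AutoTokenizer.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    cache_dir="wm_interactive/static/hf_cache",
)
detector = MarylandDetector(tokenizer, ngram=1, seed=0, gamma=0.5, delta=2.0)

text = "Some text to test for a watermark."
# Per-token score increments (1 if the token falls in the seeded greenlist, 0 otherwise),
# deduplicated according to scoring_method="v2".
token_details = detector.get_details(text, scoring_method="v2")
# Running p-values under the binomial null hypothesis; the last valid one is the overall verdict.
pvalues, aux = detector.get_pvalues_by_tok(token_details)
print(aux["final_score"], aux["ntoks_scored"], aux["final_pvalue"])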
wm_interactive/core/generator.py ADDED
@@ -0,0 +1,211 @@
1
+
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ from .hashing import get_seed_rng
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ class WmGenerator():
10
+ def __init__(self,
11
+ model: AutoModelForCausalLM,
12
+ tokenizer: AutoTokenizer,
13
+ ngram: int = 1,
14
+ seed: int = 0,
15
+ **kwargs
16
+ ):
17
+ # model config
18
+ self.tokenizer = tokenizer
19
+ self.vocab_size = self.tokenizer.vocab_size
20
+ self.model = model
21
+ self.max_seq_len = model.config.max_sequence_length if 'max_sequence_length' in model.config.to_dict() else 2048
22
+ self.pad_id = model.config.pad_token_id if model.config.pad_token_id is not None else -1
23
+ self.eos_id = model.config.eos_token_id
24
+ # watermark config
25
+ self.ngram = ngram
26
+ self.seed = seed
27
+ self.rng = torch.Generator()
28
+ self.rng.manual_seed(self.seed)
29
+
30
+ @torch.no_grad()
31
+ def generate(
32
+ self,
33
+ prompt: str,
34
+ max_gen_len: int,
35
+ temperature: float = 0.8,
36
+ top_p: float = 0.95,
37
+ return_aux: bool = False,
38
+ ) -> str:
39
+
40
+ prompt_tokens = self.tokenizer.encode(prompt)
41
+ prompt_size = len(prompt_tokens)
42
+ total_len = min(self.max_seq_len, max_gen_len + prompt_size)
43
+ tokens = torch.full((1, total_len), self.pad_id).to(device).long()
44
+ if total_len < prompt_size:
45
+ print("prompt is bigger than max sequence length")
46
+ prompt_tokens = prompt_tokens[:total_len]
47
+ tokens[0, :len(prompt_tokens)] = torch.tensor(prompt_tokens).long()
48
+ input_text_mask = tokens != self.pad_id
49
+
50
+ start_pos = prompt_size
51
+ prev_pos = 0
52
+ for cur_pos in range(start_pos, total_len):
53
+ past_key_values = outputs.past_key_values if prev_pos > 0 else None
54
+ outputs = self.model.forward(
55
+ tokens[:, prev_pos:cur_pos],
56
+ use_cache=True,
57
+ past_key_values=past_key_values
58
+ )
59
+ ngram_tokens = tokens[0, cur_pos-self.ngram:cur_pos].tolist()
60
+ aux = {
61
+ 'ngram_tokens': ngram_tokens,
62
+ 'cur_pos': cur_pos,
63
+ }
64
+ next_tok = self.sample_next(outputs.logits[:, -1, :], aux, temperature, top_p)
65
+ tokens[0, cur_pos] = torch.where(input_text_mask[0, cur_pos], tokens[0, cur_pos], next_tok)
66
+ prev_pos = cur_pos
67
+ if next_tok == self.eos_id:
68
+ break
69
+
70
+ # cut to max gen len
71
+ t = tokens[0, :prompt_size + max_gen_len].tolist()
72
+ # cut to eos tok if any
73
+ finish_reason = 'length'
74
+ try:
75
+ find_eos = t[prompt_size:].index(self.eos_id)
76
+ if find_eos:
77
+ t = t[: prompt_size+find_eos]
78
+ finish_reason = 'eos'
79
+ except ValueError:
80
+ pass
81
+ aux_info = {
82
+ 't': t,
83
+ 'finish_reason': finish_reason,
84
+ 'n_toks_gen': len(t) - prompt_size,
85
+ 'n_toks_tot': len(t),
86
+ }
87
+ decoded = self.tokenizer.decode(t)
88
+
89
+ if return_aux:
90
+ return decoded, aux_info
91
+ return decoded
92
+
93
+ def sample_next(
94
+ self,
95
+ logits: torch.FloatTensor, # (1, vocab_size): logits for last token
96
+ aux: dict, # ngram_tokens (1, ngram): tokens to consider when seeding
97
+ temperature: float = 0.8, # temperature for sampling
98
+ top_p: float = 0.95, # top p for sampling
99
+ ):
100
+ """Vanilla sampling with temperature and top p."""
101
+ if temperature > 0:
102
+ probs = torch.softmax(logits / temperature, dim=-1)
103
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
104
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
105
+ mask = probs_sum - probs_sort > top_p
106
+ probs_sort[mask] = 0.0
107
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
108
+ next_token = torch.multinomial(probs_sort, num_samples=1) # one hot of next token, ordered by original probs
109
+ next_token = torch.gather(probs_idx, -1, next_token) # one hot of next token, ordered by vocab
110
+ else:
111
+ next_token = torch.argmax(logits, dim=-1)
112
+ next_token = next_token.reshape(-1)[0] # Get the single token value
113
+ return next_token
114
+
115
+
116
+ class OpenaiGenerator(WmGenerator):
117
+ """
118
+ Generate text using LLaMA and Aaronson's watermarking method.
119
+ From ngram tokens, select the next token based on the following:
120
+ - hash the ngram tokens and get a seed
121
+ - use the seed to generate V random numbers r in [0, 1]
122
+ - select argmax ( r^(1/p) )
123
+ """
124
+ def __init__(self, *args, **kwargs):
125
+ super().__init__(*args, **kwargs)
126
+
127
+ def sample_next(
128
+ self,
129
+ logits: torch.FloatTensor, # (1, vocab_size): logits for last token
130
+ aux: dict, # (1, ngram): tokens to consider when seeding
131
+ temperature: float = 0.8, # temperature for sampling
132
+ top_p: float = 0.95, # top p for sampling
133
+ ):
134
+ ngram_tokens = aux['ngram_tokens']
135
+ if temperature > 0:
136
+ probs = torch.softmax(logits / temperature, dim=-1)
137
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
138
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
139
+ mask = probs_sum - probs_sort > top_p
140
+ probs_sort[mask] = 0.0
141
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
142
+ # seed with hash of ngram tokens
143
+ seed = get_seed_rng(self.seed, ngram_tokens)
144
+ self.rng.manual_seed(seed)
145
+ # generate rs randomly between [0,1]
146
+ rs = torch.rand(self.vocab_size, generator=self.rng) # n
147
+ rs = torch.Tensor(rs).to(probs_sort.device)
148
+ rs = rs[probs_idx[0]]
149
+ # compute r^(1/p)
150
+ probs_sort[0] = torch.pow(rs, 1/probs_sort[0])
151
+ # select argmax ( r^(1/p) )
152
+ next_token = torch.argmax(probs_sort, dim=-1, keepdim=True)
153
+ next_token = torch.gather(probs_idx, -1, next_token)
154
+ else:
155
+ next_token = torch.argmax(logits, dim=-1)
156
+ next_token = next_token.reshape(-1)[0] # Get the single token value
157
+ return next_token
158
+
159
+
160
+ class MarylandGenerator(WmGenerator):
161
+ """
162
+ Generate text using LLaMA and Maryland's watermarking method.
163
+ From ngram tokens, select the next token based on the following:
164
+ - hash the ngram tokens and get a seed
165
+ - use the seed to partition the vocabulary into greenlist (gamma*V words) and blacklist
166
+ - add delta to greenlist words' logits
167
+ """
168
+ def __init__(self,
169
+ *args,
170
+ gamma: float = 0.5,
171
+ delta: float = 1.0,
172
+ **kwargs
173
+ ):
174
+ super().__init__(*args, **kwargs)
175
+ self.gamma = gamma
176
+ self.delta = delta
177
+
178
+ def sample_next(
179
+ self,
180
+ logits: torch.FloatTensor, # (1, vocab_size): logits for last token
181
+ aux: dict, # ngram_tokens (1, ngram): tokens to consider when seeding
182
+ temperature: float = 0.8, # temperature for sampling
183
+ top_p: float = 0.95, # top p for sampling
184
+ ):
185
+ ngram_tokens = aux['ngram_tokens']
186
+ logits = self.logits_processor(logits, ngram_tokens)
187
+ if temperature > 0:
188
+ probs = torch.softmax(logits / temperature, dim=-1)
189
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
190
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
191
+ mask = probs_sum - probs_sort > top_p
192
+ probs_sort[mask] = 0.0
193
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
194
+ next_token = torch.multinomial(probs_sort, num_samples=1) # one hot of next token, ordered by original probs
195
+ next_token = torch.gather(probs_idx, -1, next_token) # one hot of next token, ordered by vocab
196
+ else:
197
+ next_token = torch.argmax(logits, dim=-1)
198
+ next_token = next_token.reshape(-1)[0] # Get the single token value
199
+ return next_token
200
+
201
+ def logits_processor(self, logits, ngram_tokens):
202
+ """Process logits to mask out words in greenlist."""
203
+ logits = logits.clone()
204
+ seed = get_seed_rng(self.seed, ngram_tokens)
205
+ self.rng.manual_seed(seed)
206
+ vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
207
+ greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n
208
+ bias = torch.zeros(self.vocab_size).to(logits.device)
209
+ bias[greenlist] = self.delta
210
+ logits[0] += bias # add bias to greenlist words
211
+ return logits
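The two sampling rules described in the OpenaiGenerator and MarylandGenerator docstrings can be illustrated in isolation. The sketch below is not the code above: it uses a toy vocabulary and a fixed seed where the real generators derive the seed from the ngram context via get_seed_rng.

import torch

vocab_size = 8
probs = torch.softmax(torch.randn(vocab_size), dim=-1)  # model's next-token distribution p
rng = torch.Generator().manual_seed(42)                  # stands in for get_seed_rng(seed, ngram_tokens)

# Aaronson-style rule: draw r ~ U[0,1]^V from the seeded RNG and keep argmax r^(1/p).
r = torch.rand(vocab_size, generator=rng)
openai_token = int(torch.argmax(r ** (1.0 / probs)))

# Maryland-style rule: seed a permutation, take the first gamma*V tokens as the greenlist,
# add delta to their logits, then sample from the biased distribution.
gamma, delta = 0.5, 2.0
rng.manual_seed(42)
greenlist = torch.randperm(vocab_size, generator=rng)[: int(gamma * vocab_size)]
logits = probs.log()
logits[greenlist] += delta
maryland_token = int(torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1))

print(openai_token, maryland_token)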
wm_interactive/core/hashing.py ADDED
@@ -0,0 +1,13 @@
+
+ def get_seed_rng(
+ start,
+ input_ids: list[int],
+ salt = 35317
+ ) -> int:
+ """
+ Seed RNG with hash of input_ids.
+ Adapted from https://github.com/jwkirchenbauer/lm-watermarking
+ """
+ for ii in input_ids:
+ start = (start * salt + ii) % (2 ** 64 - 1)
+ return int(start)
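A small check, with arbitrary example token ids, of the property the generator and detector rely on: hashing the same seed and ngram context always yields the same RNG seed, so both sides reconstruct identical greenlists or random draws.

from wm_interactive.core.hashing import get_seed_rng

ngram_context = [4438, 311]  # the last `ngram` token ids before the current position (example values)
assert get_seed_rng(0, ngram_context) == get_seed_rng(0, ngram_context)
print(get_seed_rng(0, ngram_context))  # same value at generation time and at detection time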
wm_interactive/core/main.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ Main script for watermark detection.
3
+ Test with:
4
+ python -m wm_interactive.core.main --model_name smollm2-135m --prompt_path data/prompts.json --method maryland --delta 4.0 --ngram 1
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import time
10
+ import tqdm
11
+ import torch
12
+ import numpy as np
13
+ import pandas as pd
14
+ import argparse
15
+
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+
18
+ from wm_interactive.core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator
19
+ from wm_interactive.core.detector import WmDetector, OpenaiDetector, OpenaiDetectorZ, MarylandDetector, MarylandDetectorZ
20
+
21
+ # model names mapping
22
+ model_names = {
23
+ # 'llama-3.2-1b': 'meta-llama/Llama-3.2-1B-Instruct',
24
+ 'smollm2-135m': 'HuggingFaceTB/SmolLM2-135M-Instruct',
25
+ 'smollm2-360m': 'HuggingFaceTB/SmolLM2-360M-Instruct',
26
+ }
27
+
28
+ CACHE_DIR = "wm_interactive/static/hf_cache"
29
+
30
+
31
+ def load_prompts(json_path: str, prompt_type: str = "smollm", nsamples: int = None) -> list[dict]:
32
+ """Load prompts from a JSON file.
33
+
34
+ Args:
35
+ json_path: Path to the JSON file
36
+ prompt_type: Type of prompt dataset (alpaca, smollm)
37
+ nsamples: Number of samples to load (if None, load all)
38
+
39
+ Returns:
40
+ List of prompts
41
+ """
42
+ if not os.path.exists(json_path):
43
+ raise FileNotFoundError(f"File {json_path} not found")
44
+
45
+ with open(json_path, 'r') as f:
46
+ data = json.load(f)
47
+
48
+ if prompt_type == "alpaca":
49
+ prompts = [{"instruction": item["instruction"]} for item in data]
50
+ elif prompt_type == "smollm":
51
+ prompts = []
52
+ for item in data:
53
+ prompt = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n"
54
+ prompt += f"<|im_start|>user\n{item['instruction']}<|im_end|>\n<|im_start|>assistant\n"
55
+ prompts.append({"instruction": prompt})
56
+ else:
57
+ raise ValueError(f"Prompt type {prompt_type} not supported")
58
+
59
+ if nsamples is not None:
60
+ prompts = prompts[:nsamples]
61
+
62
+ return prompts
63
+
64
+ def load_results(json_path: str, result_key: str = "result", nsamples: int = None) -> list[str]:
65
+ """Load results from a JSONL file.
66
+
67
+ Args:
68
+ json_path: Path to the JSONL file
69
+ result_key: Key to extract from each JSON line
70
+ nsamples: Number of samples to load (if None, load all)
71
+
72
+ Returns:
73
+ List of results
74
+ """
75
+ if not os.path.exists(json_path):
76
+ raise FileNotFoundError(f"File {json_path} not found")
77
+
78
+ results = []
79
+ with open(json_path, 'r') as f:
80
+ for line in f:
81
+ if line.strip(): # Skip empty lines
82
+ data = json.loads(line)
83
+ results.append(data[result_key])
84
+ if nsamples is not None and len(results) >= nsamples:
85
+ break
86
+
87
+ return results
88
+
89
+ def get_args_parser():
90
+ parser = argparse.ArgumentParser('Args', add_help=False)
91
+
92
+ # model parameters
93
+ parser.add_argument('--model_name', type=str, required=True,
94
+ help='Name of the model to use. Choose from: smollm2-135m, smollm2-360m')
95
+
96
+ # prompts parameters
97
+ parser.add_argument('--prompt_path', type=str, default=None,
98
+ help='Path to the prompt dataset. Required if --prompt is not provided')
99
+ parser.add_argument('--prompt_type', type=str, default="smollm",
100
+ help='Type of prompt dataset. Only used if --prompt_path is provided')
101
+ parser.add_argument('--prompt', type=str, nargs='+', default=None,
102
+ help='List of prompts to use. If not provided, prompts will be loaded from --prompt_path')
103
+
104
+ # generation parameters
105
+ parser.add_argument('--temperature', type=float, default=0.8,
106
+ help='Temperature for sampling (higher = more random)')
107
+ parser.add_argument('--top_p', type=float, default=0.95,
108
+ help='Top p for nucleus sampling (lower = more focused)')
109
+ parser.add_argument('--max_gen_len', type=int, default=256,
110
+ help='Maximum length of generated text')
111
+
112
+ # watermark parameters
113
+ parser.add_argument('--method', type=str, default='none',
114
+ help='Watermarking method. Choose from: none (no watermarking), openai (Aaronson et al.), maryland (Kirchenbauer et al.)')
115
+ parser.add_argument('--method_detect', type=str, default='same',
116
+ help='Statistical test to detect watermark. Choose from: same (same as method), openai, openaiz, maryland, marylandz')
117
+ parser.add_argument('--seed', type=int, default=0,
118
+ help='Random seed for reproducibility')
119
+ parser.add_argument('--ngram', type=int, default=1,
120
+ help='n-gram size for rng key generation')
121
+ parser.add_argument('--gamma', type=float, default=0.5,
122
+ help='For maryland method: proportion of greenlist tokens')
123
+ parser.add_argument('--delta', type=float, default=2.0,
124
+ help='For maryland method: bias to add to greenlist tokens')
125
+ parser.add_argument('--scoring_method', type=str, default='v2',
126
+ help='Method for scoring. Choose from: none (score every token), v1 (score when context unique), v2 (score when context+token unique)')
127
+
128
+ # experiment parameters
129
+ parser.add_argument('--nsamples', type=int, default=None,
130
+ help='Number of samples to generate from the prompt dataset')
131
+ parser.add_argument('--do_eval', type=bool, default=True,
132
+ help='Whether to evaluate the generated text')
133
+ parser.add_argument('--output_dir', type=str, default='output',
134
+ help='Directory to save results')
135
+
136
+ return parser
137
+
138
+ def main(args):
139
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
140
+ print("{}".format(args).replace(', ', ',\n'))
141
+
142
+ torch.manual_seed(args.seed)
143
+ np.random.seed(args.seed)
144
+
145
+ # build model
146
+ model_name = args.model_name.lower()
147
+ if model_name not in model_names:
148
+ raise ValueError(f"Model {model_name} not supported. Choose from: {list(model_names.keys())}")
149
+ model_name = model_names[model_name]
150
+
151
+ # Load tokenizer and model
152
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
153
+ device = "cuda" if torch.cuda.is_available() else "cpu"
154
+
155
+ model = AutoModelForCausalLM.from_pretrained(
156
+ model_name,
157
+ cache_dir=CACHE_DIR
158
+ ).to(device)
159
+
160
+ # build watermark generator
161
+ if args.method == "none":
162
+ generator = WmGenerator(model, tokenizer)
163
+ elif args.method == "openai":
164
+ generator = OpenaiGenerator(model, tokenizer, args.ngram, args.seed)
165
+ elif args.method == "maryland":
166
+ generator = MarylandGenerator(model, tokenizer, args.ngram, args.seed, gamma=args.gamma, delta=args.delta)
167
+ else:
168
+ raise NotImplementedError("method {} not implemented".format(args.method))
169
+
170
+ # load prompts
171
+ if args.prompt is not None:
172
+ prompts = args.prompt
173
+ prompts = [{"instruction": prompt} for prompt in prompts]
174
+ elif args.prompt_path is not None:
175
+ prompts = load_prompts(json_path=args.prompt_path, prompt_type=args.prompt_type, nsamples=args.nsamples)
176
+ else:
177
+ raise ValueError("Either --prompt or --prompt_path must be provided")
178
+
179
+ # (re)start experiment
180
+ os.makedirs(args.output_dir, exist_ok=True)
181
+ start_point = 0 # if resuming, start from the last line of the file
182
+ if os.path.exists(os.path.join(args.output_dir, f"results.jsonl")):
183
+ with open(os.path.join(args.output_dir, f"results.jsonl"), "r") as f:
184
+ for _ in f:
185
+ start_point += 1
186
+ print(f"Starting from {start_point}")
187
+
188
+ # generate
189
+ all_times = []
190
+ with open(os.path.join(args.output_dir, f"results.jsonl"), "a") as f:
191
+ for ii in range(start_point, len(prompts)):
192
+ # generate text
193
+ time0 = time.time()
194
+ prompt = prompts[ii]["instruction"]
195
+ result = generator.generate(
196
+ prompt,
197
+ max_gen_len=args.max_gen_len,
198
+ temperature=args.temperature,
199
+ top_p=args.top_p
200
+ )
201
+ time1 = time.time()
202
+ # time chunk
203
+ speed = 1 / (time1 - time0)
204
+ eta = (len(prompts) - ii) / speed
205
+ eta = time.strftime("%Hh%Mm%Ss", time.gmtime(eta))
206
+ all_times.append(time1 - time0)
207
+ print(f"Generated {ii:5d} - Speed {speed:.2f} prompts/s - ETA {eta}")
208
+ # log
209
+ f.write(json.dumps({
210
+ "prompt": prompt,
211
+ "result": result[len(prompt):],
212
+ "speed": speed,
213
+ "eta": eta}) + "\n")
214
+ f.flush()
215
+ print(f"Average time per prompt: {np.sum(all_times) / (len(prompts) - start_point) :.2f}")
216
+
217
+ if args.method_detect == 'same':
218
+ args.method_detect = args.method
219
+ if (not args.do_eval) or (args.method_detect not in ["openai", "maryland", "marylandz", "openaiz"]):
220
+ return
221
+
222
+ # build watermark detector
223
+ if args.method_detect == "openai":
224
+ detector = OpenaiDetector(tokenizer, args.ngram, args.seed)
225
+ elif args.method_detect == "openaiz":
226
+ detector = OpenaiDetectorZ(tokenizer, args.ngram, args.seed)
227
+ elif args.method_detect == "maryland":
228
+ detector = MarylandDetector(tokenizer, args.ngram, args.seed, gamma=args.gamma, delta=args.delta)
229
+ elif args.method_detect == "marylandz":
230
+ detector = MarylandDetectorZ(tokenizer, args.ngram, args.seed, gamma=args.gamma, delta=args.delta)
231
+
232
+ # evaluate
233
+ results = load_results(json_path=os.path.join(args.output_dir, f"results.jsonl"), result_key="result", nsamples=args.nsamples)
234
+ log_stats = []
235
+ with open(os.path.join(args.output_dir, 'scores.jsonl'), 'w') as f:
236
+ for text in tqdm.tqdm(results):
237
+ # get token details and pvalues
238
+ token_details = detector.get_details(text, scoring_method=args.scoring_method)
239
+ pvalues, aux_info = detector.get_pvalues_by_tok(token_details)
240
+ # log stats
241
+ log_stat = {
242
+ 'num_token': aux_info['ntoks_scored'],
243
+ 'score': aux_info['final_score'],
244
+ 'pvalue': aux_info['final_pvalue'],
245
+ 'log10_pvalue': np.log10(aux_info['final_pvalue']),
246
+ }
247
+ log_stats.append(log_stat)
248
+ f.write('\n' + json.dumps({k: float(v) for k, v in log_stat.items()}))
249
+ df = pd.DataFrame(log_stats)
250
+ print(f">>> Scores: \n{df.describe(percentiles=[])}")
251
+ print(f"Saved scores to {os.path.join(args.output_dir, 'scores.csv')}")
252
+
253
+
254
+ if __name__ == "__main__":
255
+ args = get_args_parser().parse_args()
256
+ main(args)
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/0ad5ecc2035b7031b88afb544ee95e2d49baa484.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/36293b6099200eb8aeb55ae2c01bca2ba46d80d0.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/44719d2e365acac0637fd25a3acf46494ca45940.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/69503b13f727ba3812b6803e97442a6de05ef5eb.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/8c7b22013909450429303ed10be4398bd63f5457.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa.lock ADDED
File without changes
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/f922b1797f0c88e71addc8393787831f2477a4bd.lock ADDED
File without changes
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/.no_exist/e2c3f7557efbdec707ae3a336371d169783f1da1/added_tokens.json ADDED
File without changes
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/0ad5ecc2035b7031b88afb544ee95e2d49baa484 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82b84012e3add4d01d12ba14442026e49b8cbbaead1f79ecf3d919784f82dc79
+ size 800662
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/36293b6099200eb8aeb55ae2c01bca2ba46d80d0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8eb740e8bbe4cff95ea7b4588d17a2432deb16e8075bc5828ff7ba9be94d982a
+ size 861
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/44719d2e365acac0637fd25a3acf46494ca45940 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b7379f3ae813529281a5c602bc5a11c1d4e0a99107aaa597fe936c1e813ca52
+ size 655
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c
+ size 269060552
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/69503b13f727ba3812b6803e97442a6de05ef5eb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b54e8aa4e53d5383e2e4bc635a56b43f9647f7b13832d5d9ecd8f82dac4f510
+ size 466391
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/8c7b22013909450429303ed10be4398bd63f5457 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ec77d44f62efeb38d7e044a1db318f6a939438425312dfa333b8382dbad98df
+ size 3764
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:87b916edaaab66b3899b9d0dd0752727dff6666686da0504d89ae0a6e055a013
+ size 132
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/f922b1797f0c88e71addc8393787831f2477a4bd ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ca9acddb6525a194ec8ac7a87f24fbba7232a9a15ffa1af0c1224fcd888e47c
+ size 2104556
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/refs/main ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71a184f20b0fe5c1a9407ed75fa9633b681779c7f1a5ca478f22fdff69a6c7ab
+ size 40
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8eb740e8bbe4cff95ea7b4588d17a2432deb16e8075bc5828ff7ba9be94d982a
+ size 861
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/generation_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:87b916edaaab66b3899b9d0dd0752727dff6666686da0504d89ae0a6e055a013
+ size 132
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/merges.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b54e8aa4e53d5383e2e4bc635a56b43f9647f7b13832d5d9ecd8f82dac4f510
+ size 466391
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c
+ size 269060552
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b7379f3ae813529281a5c602bc5a11c1d4e0a99107aaa597fe936c1e813ca52
+ size 655
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ca9acddb6525a194ec8ac7a87f24fbba7232a9a15ffa1af0c1224fcd888e47c
+ size 2104556
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ec77d44f62efeb38d7e044a1db318f6a939438425312dfa333b8382dbad98df
+ size 3764
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/vocab.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82b84012e3add4d01d12ba14442026e49b8cbbaead1f79ecf3d919784f82dc79
+ size 800662
wm_interactive/static/styles.css ADDED
@@ -0,0 +1,357 @@
1
+ body {
2
+ background-color: #f7f7f8;
3
+ color: #1a1a1a;
4
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
5
+ line-height: 1.5;
6
+ padding: 0;
7
+ margin: 0;
8
+ }
9
+
10
+ .container {
11
+ background-color: transparent;
12
+ box-shadow: none;
13
+ max-width: 1000px;
14
+ padding: 20px;
15
+ margin: 0 auto;
16
+ }
17
+
18
+ h1 {
19
+ color: #1a1a1a;
20
+ font-size: 24px;
21
+ font-weight: 600;
22
+ margin-bottom: 30px;
23
+ }
24
+
25
+ .input-section {
26
+ display: flex;
27
+ flex-direction: column;
28
+ gap: 24px;
29
+ margin-bottom: 30px;
30
+ }
31
+
32
+ .input-section textarea {
33
+ width: 100%;
34
+ padding: 16px;
35
+ background-color: #ffffff;
36
+ border: 1px solid #e5e5e5;
37
+ border-radius: 12px;
38
+ resize: none;
39
+ font-size: 16px;
40
+ line-height: 1.5;
41
+ color: #1a1a1a;
42
+ transition: border-color 0.2s;
43
+ }
44
+
45
+ .input-section textarea:focus {
46
+ outline: none;
47
+ border-color: #10a37f;
48
+ box-shadow: 0 0 0 2px rgba(16, 163, 127, 0.2);
49
+ }
50
+
51
+ .input-section #prompt_text {
52
+ height: 80px;
53
+ padding-right: 52px;
54
+ }
55
+
56
+ .input-section #user_text {
57
+ height: 160px;
58
+ }
59
+
60
+ .button-container {
61
+ display: flex;
62
+ gap: 12px;
63
+ justify-content: center;
64
+ }
65
+
66
+ .btn {
67
+ padding: 8px 16px;
68
+ font-size: 14px;
69
+ font-weight: 500;
70
+ border-radius: 6px;
71
+ transition: all 0.2s;
72
+ }
73
+
74
+ .btn-primary {
75
+ background-color: #10a37f;
76
+ border-color: #10a37f;
77
+ }
78
+
79
+ .btn-primary:hover:not(:disabled) {
80
+ background-color: #0e8d6e;
81
+ border-color: #0e8d6e;
82
+ }
83
+
84
+ .btn-secondary {
85
+ background-color: #40414f;
86
+ border-color: #565869;
87
+ color: #ececf1;
88
+ }
89
+
90
+ .btn-secondary:hover:not(:disabled) {
91
+ background-color: #4a4b5a;
92
+ border-color: #6b6c7b;
93
+ }
94
+
95
+ .token-display {
96
+ margin: 24px 0;
97
+ padding: 16px;
98
+ background-color: #ffffff;
99
+ border: 1px solid #e5e5e5;
100
+ border-radius: 12px;
101
+ min-height: 100px;
102
+ font-size: 15px;
103
+ line-height: 1.6;
104
+ }
105
+
106
+ .stats-container {
107
+ display: grid;
108
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
109
+ gap: 20px;
110
+ margin-top: 30px;
111
+ padding: 20px;
112
+ background-color: #ffffff;
113
+ border-radius: 12px;
114
+ border: 1px solid #e5e5e5;
115
+ }
116
+
117
+ .stats-container > div {
118
+ text-align: center;
119
+ padding: 16px;
120
+ border-radius: 8px;
121
+ background-color: #f7f7f8;
122
+ }
123
+
124
+ .stat-value {
125
+ font-size: 28px;
126
+ font-weight: 600;
127
+ color: #1a1a1a;
128
+ margin-bottom: 8px;
129
+ }
130
+
131
+ .stat-label {
132
+ position: relative;
133
+ color: #6e6e80;
134
+ font-size: 14px;
135
+ display: inline-flex;
136
+ align-items: center;
137
+ justify-content: center;
138
+ gap: 6px;
139
+ }
140
+
141
+ .help-icon {
142
+ color: #6e6e80;
143
+ font-size: 12px;
144
+ opacity: 0.8;
145
+ transition: opacity 0.2s;
146
+ cursor: help;
147
+ }
148
+
149
+ .help-tooltip {
150
+ visibility: hidden;
151
+ position: absolute;
152
+ z-index: 1000;
153
+ bottom: 125%;
154
+ left: 50%;
155
+ transform: translateX(-50%);
156
+ background-color: #1a1a1a;
157
+ color: #ffffff;
158
+ padding: 8px 12px;
159
+ border-radius: 6px;
160
+ font-size: 12px;
161
+ width: max-content;
162
+ max-width: 200px;
163
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
164
+ pointer-events: none;
165
+ opacity: 0;
166
+ transition: opacity 0.2s;
167
+ }
168
+
169
+ .help-tooltip::after {
170
+ content: "";
171
+ position: absolute;
172
+ top: 100%;
173
+ left: 50%;
174
+ margin-left: -5px;
175
+ border-width: 5px;
176
+ border-style: solid;
177
+ border-color: #1a1a1a transparent transparent transparent;
178
+ }
179
+
180
+ .help-icon:hover + .help-tooltip {
181
+ visibility: visible;
182
+ opacity: 1;
183
+ }
184
+
185
+ .token {
186
+ padding: 2px 4px;
187
+ margin: 1px;
188
+ border-radius: 4px;
189
+ font-family: 'SF Mono', 'Menlo', 'Monaco', Courier, monospace;
190
+ transition: background-color 0.2s;
191
+ position: relative;
192
+ cursor: pointer;
193
+ }
194
+
195
+ .token:hover {
196
+ filter: brightness(1.1);
197
+ }
198
+
199
+ .token-tooltip {
200
+ visibility: hidden;
201
+ position: absolute;
202
+ z-index: 1000;
203
+ bottom: 125%;
204
+ left: 50%;
205
+ transform: translateX(-50%);
206
+ background-color: #1a1a1a;
207
+ color: #ffffff;
208
+ padding: 8px 12px;
209
+ border-radius: 6px;
210
+ font-size: 12px;
211
+ white-space: nowrap;
212
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
213
+ pointer-events: none;
214
+ }
215
+
216
+ .token-tooltip::after {
217
+ content: "";
218
+ position: absolute;
219
+ top: 100%;
220
+ left: 50%;
221
+ margin-left: -5px;
222
+ border-width: 5px;
223
+ border-style: solid;
224
+ border-color: #1a1a1a transparent transparent transparent;
225
+ }
226
+
227
+ .token:hover .token-tooltip {
228
+ visibility: visible;
229
+ }
230
+
231
+ /* Modal styling */
232
+ .modal-content {
233
+ background-color: #ffffff;
234
+ color: #1a1a1a;
235
+ border: 1px solid #e5e5e5;
236
+ }
237
+
238
+ .modal-header {
239
+ border-bottom-color: #e5e5e5;
240
+ }
241
+
242
+ .modal-footer {
243
+ border-top-color: #e5e5e5;
244
+ }
245
+
246
+ .form-control, .form-select {
247
+ background-color: #f7f7f8;
248
+ border-color: #e5e5e5;
249
+ color: #1a1a1a;
250
+ }
251
+
252
+ .form-control:focus, .form-select:focus {
253
+ background-color: #f7f7f8;
254
+ border-color: #10a37f;
255
+ color: #1a1a1a;
256
+ box-shadow: 0 0 0 2px rgba(16, 163, 127, 0.2);
257
+ }
258
+
259
+ .form-text {
260
+ color: #6e6e80;
261
+ }
262
+
263
+ .btn-close {
264
+ filter: none;
265
+ }
266
+
267
+ /* Mobile-specific styles */
268
+ @media (max-width: 768px) {
269
+ .container {
270
+ padding: 15px;
271
+ }
272
+
273
+ .stats-container {
274
+ grid-template-columns: repeat(2, 1fr);
275
+ gap: 12px;
276
+ padding: 12px;
277
+ }
278
+
279
+ .stat-value {
280
+ font-size: 24px;
281
+ }
282
+
283
+ .stat-label {
284
+ font-size: 12px;
285
+ }
286
+
287
+ .help-tooltip {
288
+ max-width: 160px;
289
+ }
290
+ }
291
+
292
+ /* Light scrollbar */
293
+ ::-webkit-scrollbar {
294
+ width: 8px;
295
+ height: 8px;
296
+ }
297
+
298
+ ::-webkit-scrollbar-track {
299
+ background: #f7f7f8;
300
+ }
301
+
302
+ ::-webkit-scrollbar-thumb {
303
+ background: #d1d1d1;
304
+ border-radius: 4px;
305
+ }
306
+
307
+ ::-webkit-scrollbar-thumb:hover {
308
+ background: #a8a8a8;
309
+ }
310
+
311
+ .prompt-container {
312
+ position: relative;
313
+ width: 100%;
314
+ }
315
+
316
+ .floating-btn {
317
+ position: absolute;
318
+ bottom: 16px;
319
+ right: 16px;
320
+ width: 36px;
321
+ height: 36px;
322
+ border-radius: 12px;
323
+ display: flex;
324
+ align-items: center;
325
+ justify-content: center;
326
+ border: none;
327
+ background-color: #10a37f;
328
+ color: #ffffff;
329
+ cursor: pointer;
330
+ transition: all 0.2s ease;
331
+ padding: 0;
332
+ }
333
+
334
+ .floating-btn:hover {
335
+ background-color: #0e8d6e;
336
+ }
337
+
338
+ .floating-btn:disabled {
339
+ background-color: #565869;
340
+ cursor: not-allowed;
341
+ }
342
+
343
+ .floating-btn i {
344
+ font-size: 16px;
345
+ }
346
+
347
+ .floating-btn .stop-icon {
348
+ display: none;
349
+ }
350
+
351
+ .floating-btn.generating .send-icon {
352
+ display: none;
353
+ }
354
+
355
+ .floating-btn.generating .stop-icon {
356
+ display: block;
357
+ }
wm_interactive/templates/index.html ADDED
@@ -0,0 +1,459 @@
1
+ <!DOCTYPE html>
+ <html>
+ <head>
+ <title>Watermark Detector</title>
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/font/bootstrap-icons.css">
+ <link rel="stylesheet" href="{{ url_for('static', filename='styles.css') }}">
+ </head>
+ <body>
+ <div class="container">
+ <div class="d-flex justify-content-between align-items-center">
+ <h1>Interactive watermark detector</h1>
+ <button class="btn btn-outline-secondary" data-bs-toggle="modal" data-bs-target="#paramsModal">
+ <i class="bi bi-gear"></i>
+ </button>
+ </div>
+
+ <!-- Advanced Parameters Modal -->
+ <div class="modal fade" id="paramsModal" tabindex="-1">
+ <div class="modal-dialog">
+ <div class="modal-content">
+ <div class="modal-header">
+ <h5 class="modal-title">Advanced Parameters</h5>
+ <button type="button" class="btn-close" data-bs-dismiss="modal"></button>
+ </div>
+ <div class="modal-body">
+ <div class="mb-3">
+ <label for="detectorType" class="form-label">Detector Type</label>
+ <select class="form-select" id="detectorType">
+ <option value="maryland">Maryland</option>
+ <option value="marylandz">Maryland Z-score</option>
+ <option value="openai">OpenAI</option>
+ <option value="openaiz">OpenAI Z-score</option>
+ </select>
+ <div class="form-text">Type of watermark detection algorithm</div>
+ </div>
+ <div class="mb-3">
+ <label for="seed" class="form-label">Seed</label>
+ <input type="number" class="form-control" id="seed" value="0">
+ <div class="form-text">Random seed for the watermark detector</div>
+ </div>
+ <div class="mb-3">
+ <label for="ngram" class="form-label">N-gram Size</label>
+ <input type="number" class="form-control" id="ngram" value="1">
+ <div class="form-text">Size of the n-gram window used for detection</div>
+ </div>
+ <div class="mb-3">
+ <label for="delta" class="form-label">Delta</label>
+ <input type="number" step="0.1" class="form-control" id="delta" value="2.0">
+ <div class="form-text">Bias added to greenlist tokens (for Maryland method)</div>
+ </div>
+ <div class="mb-3">
+ <label for="temperature" class="form-label">Temperature</label>
+ <input type="number" step="0.1" class="form-control" id="temperature" value="0.8">
+ <div class="form-text">Temperature for sampling (higher = more random)</div>
+ </div>
+ </div>
+ <div class="modal-footer">
+ <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
+ <button type="button" class="btn btn-primary" id="applyParams">Apply</button>
+ </div>
+ </div>
+ </div>
+ </div>
+
+ <!-- Input Form -->
+ <div class="input-section">
+ <div class="prompt-container">
+ <textarea id="prompt_text"
+ placeholder="Enter your prompt here to generate text with the model..."></textarea>
+ <button class="floating-btn" id="generateBtn">
+ <i class="bi bi-send-fill send-icon"></i>
+ <i class="bi bi-stop-fill stop-icon"></i>
+ </button>
+ </div>
+ <textarea id="user_text"
+ placeholder="Generated text will appear here. Replace or edit this text to see how watermark detection works."></textarea>
+ </div>
+
+ <!-- Token Display -->
+ <div class="token-display" id="tokenDisplay"></div>
+
+ <!-- Statistics -->
+ <div class="stats-container">
+ <div>
+ <div class="stat-value" id="tokenCount">0</div>
+ <div class="stat-label">
+ Tokens
+ <i class="bi bi-question-circle help-icon"></i>
+ <span class="help-tooltip">Total number of tokens in the text</span>
+ </div>
+ </div>
+ <div>
+ <div class="stat-value" id="scoredTokens">0</div>
+ <div class="stat-label">
+ Scored Tokens
+ <i class="bi bi-question-circle help-icon"></i>
+ <span class="help-tooltip">Number of tokens that were actually scored by the detector (excludes first n-gram tokens and duplicates)</span>
+ </div>
+ </div>
+ <div>
+ <div class="stat-value" id="finalScore">0.00</div>
+ <div class="stat-label">
+ Final Score
+ <i class="bi bi-question-circle help-icon"></i>
+ <span class="help-tooltip">Cumulative score from all scored tokens. Higher values indicate more likely watermarked text</span>
+ </div>
+ </div>
+ <div>
+ <div class="stat-value" id="pValue">0.500</div>
+ <div class="stat-label">
+ P-value
+ <i class="bi bi-question-circle help-icon"></i>
+ <span class="help-tooltip">Statistical significance of the score. Lower values indicate stronger evidence of watermarking (p < 0.05 is typically considered significant)</span>
+ </div>
+ </div>
+ </div>
+ </div>
+
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
+ <script>
+ let debounceTimeout = null;
+ let abortController = null;
+ const textarea = document.getElementById('user_text');
+ const promptArea = document.getElementById('prompt_text');
+ const generateBtn = document.getElementById('generateBtn');
+ const tokenDisplay = document.getElementById('tokenDisplay');
+ const tokenCount = document.getElementById('tokenCount');
+ const scoredTokens = document.getElementById('scoredTokens');
+ const finalScore = document.getElementById('finalScore');
+ const pValue = document.getElementById('pValue');
+ const applyParamsBtn = document.getElementById('applyParams');
+ const seedInput = document.getElementById('seed');
+ const ngramInput = document.getElementById('ngram');
+ const detectorTypeSelect = document.getElementById('detectorType');
+ const deltaInput = document.getElementById('delta');
+ const temperatureInput = document.getElementById('temperature');
+
+ function startGeneration() {
+ const prompt = promptArea.value.trim();
+ if (!prompt) {
+ alert('Please enter a prompt first.');
+ return;
+ }
+
+ generateBtn.classList.add('generating');
+ textarea.value = '';
+
+ // Create new AbortController for this request
+ abortController = new AbortController();
+
+ // Get current parameters
+ const params = {
+ detector_type: detectorTypeSelect.value,
+ seed: parseInt(seedInput.value) || 0,
+ ngram: parseInt(ngramInput.value) || 1,
+ delta: parseFloat(deltaInput.value) || 2.0,
+ temperature: parseFloat(temperatureInput.value) || 0.8
+ };
+
+ // Create headers for SSE
+ const headers = new Headers({
+ 'Content-Type': 'application/json',
+ 'Accept': 'text/event-stream',
+ });
+
+ // Start fetch request with signal
+ fetch('/generate', {
+ method: 'POST',
+ headers: headers,
+ body: JSON.stringify({
+ prompt: prompt,
+ params: params
+ }),
+ signal: abortController.signal // Add the abort signal
+ }).then(response => {
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ let buffer = '';
+
+ function processText(text) {
+ const lines = text.split('\n');
+
+ for (const line of lines) {
+ if (line.startsWith('data: ')) {
+ try {
+ const data = JSON.parse(line.slice(6));
+
+ if (data.error) {
+ alert('Error: ' + data.error);
+ stopGeneration();
+ return;
+ }
+
+ if (data.token) {
+ // Append new token to existing text
+ textarea.value += data.token;
+ updateTokenization();
+ }
+
+ if (data.text) {
+ // Final text (only used if something went wrong with streaming)
+ textarea.value = data.text;
+ updateTokenization();
+ }
+
+ if (data.done) {
+ stopGeneration();
+ }
+ } catch (e) {
+ console.error('Error parsing SSE data:', e);
+ }
+ }
+ }
+ }
+
+ function pump() {
+ return reader.read().then(({value, done}) => {
+ if (done) {
+ if (buffer.length > 0) {
+ processText(buffer);
+ }
+ return;
+ }
+
+ buffer += decoder.decode(value, {stream: true});
+ const lines = buffer.split('\n\n');
+ buffer = lines.pop();
+
+ for (const line of lines) {
+ processText(line);
+ }
+
+ return pump();
+ });
+ }
+
+ return pump();
+ })
+ .catch(error => {
+ if (error.name === 'AbortError') {
+ console.log('Generation stopped by user');
+ } else {
+ console.error('Error:', error);
+ alert('Error: Failed to generate text');
+ }
+ })
+ .finally(() => {
+ generateBtn.classList.remove('generating');
+ abortController = null;
+ });
+ }
+
+ function stopGeneration() {
+ if (abortController) {
+ abortController.abort();
+ abortController = null;
+ }
+ generateBtn.classList.remove('generating');
+ }
+
+ // Single click handler: toggle between starting and stopping generation
+ generateBtn.addEventListener('click', function(e) {
+ e.preventDefault(); // Prevent any double triggers
+ if (generateBtn.classList.contains('generating')) {
+ stopGeneration();
+ } else {
+ startGeneration();
+ }
+ });
+
+ async function updateTokenization() {
+ const text = textarea.value;
+ try {
+ // Validate parameters before sending
+ const seed = parseInt(seedInput.value);
+ const ngram = parseInt(ngramInput.value);
+ const delta = parseFloat(deltaInput.value);
+ const temperature = parseFloat(temperatureInput.value);
+
+ const response = await fetch('/tokenize', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ text: text,
+ params: {
+ detector_type: detectorTypeSelect.value,
+ seed: isNaN(seed) ? 0 : seed,
+ ngram: isNaN(ngram) ? 1 : ngram,
+ delta: isNaN(delta) ? 2.0 : delta,
+ temperature: isNaN(temperature) ? 0.8 : temperature
+ }
+ })
+ });
+
+ if (!response.ok) {
+ const errorData = await response.json();
+ throw new Error(errorData.error || `HTTP error! status: ${response.status}`);
+ }
+
+ const data = await response.json();
+
+ if (data.error) {
+ throw new Error(data.error);
+ }
+
+ // Update token display
+ tokenDisplay.innerHTML = data.tokens.map((token, i) => {
+ const score = data.scores[i];
+ const pvalue = data.pvalues[i];
+ const scoreDisplay = (score !== null && !isNaN(score)) ? score.toFixed(3) : 'N/A';
+ const pvalueDisplay = (pvalue !== null && !isNaN(pvalue)) ? formatPValue(pvalue) : 'N/A';
+
+ return `<span class="token" style="background-color: ${data.colors[i]}">
+ ${token}
+ <div class="token-tooltip">
+ Score: ${scoreDisplay}<br>
+ P-value: ${pvalueDisplay}
+ </div>
+ </span>`;
+ }).join('');
+
+ // Update counts and stats - safely handle null values
+ tokenCount.textContent = data.token_count || 0;
+ scoredTokens.textContent = data.ntoks_scored || 0;
+ finalScore.textContent = (data.final_score !== null && !isNaN(data.final_score)) ?
+ data.final_score.toFixed(2) : '0.00';
+ pValue.textContent = (data.final_pvalue !== null && !isNaN(data.final_pvalue)) ?
+ formatPValue(data.final_pvalue) : '0.500';
+
+ // Clear any previous error
+ const existingError = tokenDisplay.querySelector('.alert-danger');
+ if (existingError) {
+ existingError.remove();
+ }
+ } catch (error) {
+ console.error('Error updating tokenization:', error);
+ // Show detailed error to user
+ tokenDisplay.innerHTML = `<div class="alert alert-danger">
+ <strong>Error:</strong> ${error.message || 'Error updating results. Please try again.'}
+ </div>`;
+
+ // Reset stats on error
+ tokenCount.textContent = '0';
+ scoredTokens.textContent = '0';
+ finalScore.textContent = '0.00';
+ pValue.textContent = '0.500';
+ }
+ }
+
+ // Debounce tokenization updates while the user is typing
+ textarea.addEventListener('input', function() {
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+ debounceTimeout = setTimeout(updateTokenization, 500);
+ });
+
+ // Parameter fields trigger a debounced update on input
+ seedInput.addEventListener('input', function() {
+ const value = this.value === '' ? '' : parseInt(this.value);
+ if (isNaN(value) && this.value !== '') {
+ this.value = "0";
+ }
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+ debounceTimeout = setTimeout(updateTokenization, 500);
+ });
+
+ ngramInput.addEventListener('input', function() {
+ const value = this.value === '' ? '' : parseInt(this.value);
+ if (isNaN(value) && this.value !== '') {
+ this.value = "1";
+ }
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+ debounceTimeout = setTimeout(updateTokenization, 500);
+ });
+
+ deltaInput.addEventListener('input', function() {
+ const value = this.value === '' ? '' : parseFloat(this.value);
+ if (isNaN(value) && this.value !== '') {
+ this.value = "2.0";
+ }
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+ debounceTimeout = setTimeout(updateTokenization, 500);
+ });
+
+ temperatureInput.addEventListener('input', function() {
+ const value = this.value === '' ? '' : parseFloat(this.value);
+ if (isNaN(value) && this.value !== '') {
+ this.value = "0.8";
+ }
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+ debounceTimeout = setTimeout(updateTokenization, 500);
+ });
+
+ // Ctrl/Cmd+Enter: generate from the prompt box, otherwise apply parameter changes
+ document.addEventListener('keydown', function(e) {
+ if ((e.metaKey || e.ctrlKey) && e.key === 'Enter') {
+ e.preventDefault();
+ if (document.activeElement === promptArea) {
+ if (generateBtn.classList.contains('generating')) {
+ stopGeneration();
+ } else {
+ startGeneration();
+ }
+ } else {
+ applyParamsBtn.click();
+ }
+ }
+ });
+
+ detectorTypeSelect.addEventListener('change', function() {
+ if (debounceTimeout) {
+ clearTimeout(debounceTimeout);
+ }
+ debounceTimeout = setTimeout(updateTokenization, 500);
+ });
+
+ // Ensure the modal apply button properly triggers an update
+ applyParamsBtn.addEventListener('click', function() {
+ updateTokenization().then(() => {
+ const modal = bootstrap.Modal.getInstance(document.getElementById('paramsModal'));
+ if (modal) {
+ modal.hide();
+ }
+ }).catch(error => {
+ console.error('Error applying parameters:', error);
+ });
+ });
+
+ // Initial tokenization with error handling
+ document.addEventListener('DOMContentLoaded', function() {
+ updateTokenization().catch(error => {
+ console.error('Error during initial tokenization:', error);
+ });
+ });
+
+ // Helper for formatting p-values
+ function formatPValue(value) {
+ if (value >= 0.001) {
+ return value.toFixed(3);
+ } else {
+ return value.toExponential(2);
+ }
+ }
+ </script>
+ </body>
+ </html>
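
For reference, the `/generate` endpoint consumed by the script above streams Server-Sent Events: each event is a single `data: <json>` line followed by a blank line, carrying `token`, `text`, `error`, and `done` fields. A minimal sketch of a Python client for this stream is shown below; the `requests` library, the local URL `http://localhost:7860`, and the sample prompt are illustrative assumptions, not part of the commit.

```python
# Minimal sketch of a client for the /generate SSE stream (assumes the app
# is running locally on port 7860 and `requests` is installed).
import json
import requests

payload = {
    "prompt": "Write a short poem about watermarks.",
    "params": {"detector_type": "maryland", "seed": 0, "ngram": 1,
               "delta": 2.0, "temperature": 0.8},
}

with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    for raw_line in resp.iter_lines(decode_unicode=True):
        if not raw_line or not raw_line.startswith("data: "):
            continue  # skip blank separators between events
        event = json.loads(raw_line[len("data: "):])
        if event.get("error"):
            raise RuntimeError(event["error"])
        if event.get("token"):
            print(event["token"], end="", flush=True)  # tokens arrive incrementally
        if event.get("done"):
            break
```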
wm_interactive/web/__init__.py ADDED
File without changes
wm_interactive/web/app.py ADDED
@@ -0,0 +1,241 @@
+ """
2
+ Main Flask application for the watermark detection web interface.
3
+ """
4
+
5
+ from flask import Flask, render_template, request, jsonify, Response, stream_with_context
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import torch
8
+ import json
9
+
10
+ from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ
11
+ from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator
12
+ from .utils import get_token_details, template_prompt
13
+
14
+ CACHE_DIR = "wm_interactive/static/hf_cache"
15
+
16
+ def convert_nan_to_null(obj):
17
+ """Convert NaN values to null for JSON serialization"""
18
+ import math
19
+ if isinstance(obj, float) and math.isnan(obj):
20
+ return None
21
+ elif isinstance(obj, dict):
22
+ return {k: convert_nan_to_null(v) for k, v in obj.items()}
23
+ elif isinstance(obj, list):
24
+ return [convert_nan_to_null(item) for item in obj]
25
+ return obj
26
+
27
+ def set_to_int(value, default_value = None):
28
+ try:
29
+ return int(value)
30
+ except (ValueError, TypeError):
31
+ return default_value
32
+
33
+ def create_detector(detector_type, tokenizer, **kwargs):
34
+ """Create a detector instance based on the specified type."""
35
+ detector_map = {
36
+ 'maryland': MarylandDetector,
37
+ 'marylandz': MarylandDetectorZ,
38
+ 'openai': OpenaiDetector,
39
+ 'openaiz': OpenaiDetectorZ
40
+ }
41
+
42
+ # Validate and set default values for parameters
43
+ if 'seed' in kwargs:
44
+ kwargs['seed'] = set_to_int(kwargs['seed'], default_value = 0)
45
+ if 'ngram' in kwargs:
46
+ kwargs['ngram'] = set_to_int(kwargs['ngram'], default_value = 1)
47
+
48
+ detector_class = detector_map.get(detector_type, MarylandDetector)
49
+ return detector_class(tokenizer=tokenizer, **kwargs)
50
+
51
+ def create_app():
52
+ app = Flask(__name__,
53
+ static_folder='../static',
54
+ template_folder='../templates')
55
+
56
+ # Add zip to Jinja's global context
57
+ app.jinja_env.globals.update(zip=zip)
58
+
59
+ # Pick a model
60
+ # model_id = "meta-llama/Llama-3.2-1B-Instruct"
61
+ model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
62
+ tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
63
+ model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR).to("cuda" if torch.cuda.is_available() else "cpu")
64
+
65
+ # Create default generator
66
+ generator = MarylandGenerator(model, tokenizer, ngram=1, seed=0)
67
+
68
+ @app.route("/", methods=["GET"])
69
+ def index():
70
+ return render_template("index.html")
71
+
72
+ @app.route("/tokenize", methods=["POST"])
73
+ def tokenize():
74
+ try:
75
+ data = request.get_json()
76
+ if not data:
77
+ return jsonify({'error': 'No JSON data received'}), 400
78
+
79
+ text = data.get('text', '')
80
+ params = data.get('params', {})
81
+
82
+ # Create a detector instance with the provided parameters
83
+ detector = create_detector(
84
+ detector_type=params.get('detector_type', 'maryland'),
85
+ tokenizer=tokenizer,
86
+ seed=params.get('seed', 0),
87
+ ngram=params.get('ngram', 1)
88
+ )
89
+
90
+ if text:
91
+ try:
92
+ display_info = get_token_details(text, detector)
93
+
94
+ # Extract summary stats (last item in display_info)
95
+ stats = display_info.pop()
96
+
97
+ response_data = {
98
+ 'token_count': len(display_info),
99
+ 'tokens': [info['token'] for info in display_info],
100
+ 'colors': [info['color'] for info in display_info],
101
+ 'scores': [info['score'] if info.get('is_scored', False) else None for info in display_info],
102
+ 'pvalues': [info['pvalue'] if info.get('is_scored', False) else None for info in display_info],
103
+ 'final_score': stats.get('final_score', 0) if stats.get('final_score') is not None else 0,
104
+ 'ntoks_scored': stats.get('ntoks_scored', 0) if stats.get('ntoks_scored') is not None else 0,
105
+ 'final_pvalue': stats.get('final_pvalue', 0.5) if stats.get('final_pvalue') is not None else 0.5
106
+ }
107
+
108
+ # Convert any NaN values to null before sending
109
+ response_data = convert_nan_to_null(response_data)
110
+
111
+ # Ensure numeric fields have default values if they became null
112
+ if response_data['final_score'] is None:
113
+ response_data['final_score'] = 0
114
+ if response_data['ntoks_scored'] is None:
115
+ response_data['ntoks_scored'] = 0
116
+ if response_data['final_pvalue'] is None:
117
+ response_data['final_pvalue'] = 0.5
118
+
119
+ return jsonify(response_data)
120
+
121
+ except Exception as e:
122
+ app.logger.error(f'Error processing text: {str(e)}')
123
+ return jsonify({'error': f'Error processing text: {str(e)}'}), 500
124
+
125
+ return jsonify({
126
+ 'token_count': 0,
127
+ 'tokens': [],
128
+ 'colors': [],
129
+ 'scores': [],
130
+ 'pvalues': [],
131
+ 'final_score': 0,
132
+ 'ntoks_scored': 0,
133
+ 'final_pvalue': 0.5
134
+ })
135
+
136
+ except Exception as e:
137
+ app.logger.error(f'Server error: {str(e)}')
138
+ return jsonify({'error': f'Server error: {str(e)}'}), 500
139
+
140
+ @app.route("/generate", methods=["POST"])
141
+ def generate():
142
+ try:
143
+ data = request.get_json()
144
+ if not data:
145
+ return jsonify({'error': 'No JSON data received'}), 400
146
+
147
+ prompt = template_prompt(data.get('prompt', ''))
148
+ params = data.get('params', {})
149
+ temperature = float(params.get('temperature', 0.8))
150
+
151
+ def generate_stream():
152
+ try:
153
+ # Create generator with correct parameters
154
+ generator_class = OpenaiGenerator if params.get('detector_type') == 'openai' else MarylandGenerator
155
+ generator = generator_class(
156
+ model=model,
157
+ tokenizer=tokenizer,
158
+ ngram=set_to_int(params.get('ngram', 1)),
159
+ seed=set_to_int(params.get('seed', 0)),
160
+ delta=float(params.get('delta', 2.0)),
161
+ )
162
+
163
+ # Get special tokens to filter out
164
+ special_tokens = {
165
+ '<|im_start|>', '<|im_end|>',
166
+ tokenizer.pad_token, tokenizer.eos_token,
167
+ tokenizer.bos_token if hasattr(tokenizer, 'bos_token') else None,
168
+ tokenizer.sep_token if hasattr(tokenizer, 'sep_token') else None
169
+ }
170
+ special_tokens = {t for t in special_tokens if t is not None}
171
+
172
+ # Encode prompt
173
+ prompt_tokens = tokenizer.encode(prompt)
174
+ prompt_size = len(prompt_tokens)
175
+ max_gen_len = 100
176
+ total_len = min(getattr(model.config, 'max_position_embeddings', 2048), max_gen_len + prompt_size)
177
+
178
+ # Initialize generation
179
+ tokens = torch.full((1, total_len), model.config.pad_token_id).to(model.device).long()
180
+ tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long()
181
+ input_text_mask = tokens != model.config.pad_token_id
182
+
183
+ # Generate token by token
184
+ prev_pos = 0
185
+ outputs = None # Initialize outputs to None
186
+ for cur_pos in range(prompt_size, total_len):
187
+ # Get model outputs
188
+ outputs = model.forward(
189
+ tokens[:, prev_pos:cur_pos],
190
+ use_cache=True,
191
+ past_key_values=outputs.past_key_values if prev_pos > 0 else None
192
+ )
193
+
194
+ # Sample next token using the generator's sampling method
195
+ ngram_tokens = tokens[0, cur_pos-generator.ngram:cur_pos].tolist()
196
+ aux = {
197
+ 'ngram_tokens': ngram_tokens,
198
+ 'cur_pos': cur_pos,
199
+ }
200
+ next_token = generator.sample_next(
201
+ outputs.logits[:, -1, :],
202
+ aux,
203
+ temperature=temperature,
204
+ top_p=0.9
205
+ )
206
+ # Check for EOS token
207
+ if next_token == model.config.eos_token_id:
208
+ break
209
+
210
+ # Decode and check if it's a special token
211
+ new_text = tokenizer.decode([next_token])
212
+ if new_text not in special_tokens and not any(st in new_text for st in special_tokens):
213
+ yield f"data: {json.dumps({'token': new_text, 'done': False})}\n\n"
214
+
215
+ # Update token and position
216
+ tokens[0, cur_pos] = next_token
217
+ prev_pos = cur_pos
218
+
219
+ # Send final complete text, filtering out special tokens
220
+ final_tokens = tokens[0, prompt_size:cur_pos+1].tolist()
221
+ final_text = tokenizer.decode(final_tokens)
222
+ for st in special_tokens:
223
+ final_text = final_text.replace(st, '')
224
+ yield f"data: {json.dumps({'text': final_text, 'done': True})}\n\n"
225
+
226
+ except Exception as e:
227
+ app.logger.error(f'Error generating text: {str(e)}')
228
+ yield f"data: {json.dumps({'error': str(e)})}\n\n"
229
+
230
+ return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')
231
+
232
+ except Exception as e:
233
+ app.logger.error(f'Server error: {str(e)}')
234
+ return jsonify({'error': f'Server error: {str(e)}'}), 500
235
+
236
+ return app
237
+
238
+ app = create_app()
239
+
240
+ if __name__ == "__main__":
241
+ app.run(host='0.0.0.0', port=7860)
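
The `/tokenize` route above accepts a JSON body with `text` and `params` and returns index-aligned per-token arrays plus summary statistics, with NaN values serialized as null. A minimal sketch of a round trip is shown below; the local URL and the sample text are illustrative assumptions only.

```python
# Minimal sketch of a /tokenize round trip against a locally running app
# (assumes http://localhost:7860 and the `requests` library).
import requests

resp = requests.post(
    "http://localhost:7860/tokenize",
    json={
        "text": "Some text to score for a watermark.",
        "params": {"detector_type": "maryland", "seed": 0, "ngram": 1},
    },
)
resp.raise_for_status()
data = resp.json()

# Per-token arrays are aligned by index; unscored tokens have null scores/p-values.
for token, score, pvalue in zip(data["tokens"], data["scores"], data["pvalues"]):
    print(f"{token!r:20} score={score} pvalue={pvalue}")

print("final score:", data["final_score"],
      "| scored tokens:", data["ntoks_scored"],
      "| final p-value:", data["final_pvalue"])
```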
wm_interactive/web/utils.py ADDED
@@ -0,0 +1,83 @@
+ import random
+ import numpy as np
+
+ from ..core.detector import WmDetector
+
+ def generate_pastel_color():
+     """Generate a pastel color in HSL format."""
+     h = random.random() # Random hue
+     s = 0.3 + random.random() * 0.2 # Saturation between 0.3-0.5
+     l = 0.8 + random.random() * 0.1 # Lightness between 0.8-0.9
+     return f"hsl({h*360}, {s*100}%, {l*100}%)"
+
+ def color_from_score(score: float):
+     """
+     Take a score between 0 and 1 and output a color.
+     If the score is NaN, return a pastel gray.
+     If the score is close to 0, return pastel red; if it is close to 1, return pastel green.
+     """
+     if isinstance(score, float) and not np.isnan(score):
+         # Red for low scores, green for high scores
+         h = 0 if score < 0.5 else 120 # 0 = red, 120 = green
+         s = 0.3 + 0.2 * abs(2 * score - 1) # Higher saturation for extreme values
+         l = 0.85 # Keep lightness constant for pastel colors
+         return f"hsl({h}, {s*100}%, {l*100}%)"
+     return "hsl(0, 0%, 85%)" # Pastel gray for NaN
+
+ def get_token_details(
+     text: str,
+     detector: WmDetector
+ ) -> list:
+     """
+     Run the detector on the text and output everything needed for display.
+     """
+     # Get scores for each token
+     token_details = detector.get_details(text)
+
+     # Get p-values for each token
+     pvalues, aux_info = detector.get_pvalues_by_tok(token_details)
+
+     display_info = []
+     for token_detail, pvalue in zip(token_details, pvalues):
+         score = token_detail['score'] if token_detail['is_scored'] else float('nan')
+         # Convert numpy types to native Python types
+         if isinstance(score, (np.floating, np.integer)):
+             score = float(score)
+         if isinstance(pvalue, (np.floating, np.integer)):
+             pvalue = float(pvalue)
+
+         display_info.append({
+             'is_scored': token_detail['is_scored'],
+             'token': token_detail['token_text'],
+             'color': color_from_score(score),
+             'score': score,
+             'pvalue': pvalue
+         })
+
+     # Add summary statistics and convert numpy types to native Python types
+     display_info.append({
+         'final_score': float(aux_info['final_score']),
+         'ntoks_scored': int(aux_info['ntoks_scored']),
+         'final_pvalue': float(aux_info['final_pvalue'])
+     })
+
+     return display_info
+
+ def template_prompt(instruction: str, prompt_type: str = "smollm") -> str:
+     """Template a prompt according to the model's format.
+
+     Args:
+         instruction: The raw prompt/instruction to template
+         prompt_type: Type of prompt format (smollm, alpaca)
+
+     Returns:
+         The formatted prompt ready for the model
+     """
+     if prompt_type == "alpaca":
+         return instruction
+     elif prompt_type == "smollm":
+         prompt = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n"
+         prompt += f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+         return prompt
+     else:
+         raise ValueError(f"Prompt type {prompt_type} not supported")
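
For reference, a quick sanity check of the helpers above is sketched below; it assumes the package is importable as `wm_interactive` from the repository root, and the printed values are indicative only.

```python
# Sketch: exercise the prompt templating and token coloring helpers
# (assumes `wm_interactive` is on the Python path).
from wm_interactive.web.utils import template_prompt, color_from_score

print(template_prompt("Tell me about watermarking."))
# <|im_start|>system
# You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
# <|im_start|>user
# Tell me about watermarking.<|im_end|>
# <|im_start|>assistant

print(color_from_score(0.9))           # pastel green HSL string
print(color_from_score(float("nan")))  # "hsl(0, 0%, 85%)" for unscored tokens
```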