* ONNX runtime
Browse files* use llm-guard 0.3.1
* google analytics tracking
* linter to fix code
- .pre-commit-config.yaml +38 -0
- Dockerfile +1 -1
- app.py +28 -19
- output.py +43 -27
- prompt.py +25 -45
- requirements.txt +4 -5
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.4.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: check-yaml
|
| 6 |
+
- id: end-of-file-fixer
|
| 7 |
+
- id: trailing-whitespace
|
| 8 |
+
- id: end-of-file-fixer
|
| 9 |
+
types: [ python ]
|
| 10 |
+
- id: requirements-txt-fixer
|
| 11 |
+
|
| 12 |
+
- repo: https://github.com/psf/black
|
| 13 |
+
rev: 23.7.0
|
| 14 |
+
hooks:
|
| 15 |
+
- id: black
|
| 16 |
+
args: [ --line-length=100, --exclude="" ]
|
| 17 |
+
|
| 18 |
+
# this is not technically always safe but usually is
|
| 19 |
+
# use comments `# isort: off` and `# isort: on` to disable/re-enable isort
|
| 20 |
+
- repo: https://github.com/pycqa/isort
|
| 21 |
+
rev: 5.12.0
|
| 22 |
+
hooks:
|
| 23 |
+
- id: isort
|
| 24 |
+
args: [ --line-length=100, --profile=black ]
|
| 25 |
+
|
| 26 |
+
# this is slightly dangerous because python imports have side effects
|
| 27 |
+
# and this tool removes unused imports, which may be providing
|
| 28 |
+
# necessary side effects for the code to run
|
| 29 |
+
- repo: https://github.com/PyCQA/autoflake
|
| 30 |
+
rev: v2.2.0
|
| 31 |
+
hooks:
|
| 32 |
+
- id: autoflake
|
| 33 |
+
args:
|
| 34 |
+
- "--in-place"
|
| 35 |
+
- "--expand-star-imports"
|
| 36 |
+
- "--remove-duplicate-keys"
|
| 37 |
+
- "--remove-unused-variables"
|
| 38 |
+
- "--remove-all-unused-imports"
|
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM python:3.
|
| 2 |
|
| 3 |
RUN apt-get update && apt-get install -y \
|
| 4 |
build-essential \
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
|
| 3 |
RUN apt-get update && apt-get install -y \
|
| 4 |
build-essential \
|
app.py
CHANGED
|
@@ -1,16 +1,33 @@
|
|
| 1 |
import logging
|
| 2 |
-
import time
|
| 3 |
import traceback
|
| 4 |
-
from datetime import timedelta
|
| 5 |
|
| 6 |
import pandas as pd
|
| 7 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
| 8 |
from output import init_settings as init_output_settings
|
| 9 |
from output import scan as scan_output
|
| 10 |
from prompt import init_settings as init_prompt_settings
|
| 11 |
from prompt import scan as scan_prompt
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
PROMPT = "prompt"
|
| 16 |
OUTPUT = "output"
|
|
@@ -48,6 +65,8 @@ if scanner_type == PROMPT:
|
|
| 48 |
elif scanner_type == OUTPUT:
|
| 49 |
enabled_scanners, settings = init_output_settings()
|
| 50 |
|
|
|
|
|
|
|
| 51 |
# Main pannel
|
| 52 |
with st.expander("About", expanded=False):
|
| 53 |
st.info(
|
|
@@ -93,32 +112,24 @@ elif scanner_type == OUTPUT:
|
|
| 93 |
st_result_text = None
|
| 94 |
st_analysis = None
|
| 95 |
st_is_valid = None
|
| 96 |
-
st_time_delta = None
|
| 97 |
|
| 98 |
try:
|
| 99 |
with st.form("text_form", clear_on_submit=False):
|
| 100 |
submitted = st.form_submit_button("Process")
|
| 101 |
if submitted:
|
| 102 |
-
|
| 103 |
-
results_score = {}
|
| 104 |
|
| 105 |
-
start_time = time.monotonic()
|
| 106 |
if scanner_type == PROMPT:
|
| 107 |
-
st_result_text,
|
| 108 |
vault, enabled_scanners, settings, st_prompt_text, st_fail_fast
|
| 109 |
)
|
| 110 |
elif scanner_type == OUTPUT:
|
| 111 |
-
st_result_text,
|
| 112 |
vault, enabled_scanners, settings, st_prompt_text, st_output_text, st_fail_fast
|
| 113 |
)
|
| 114 |
-
end_time = time.monotonic()
|
| 115 |
-
st_time_delta = timedelta(seconds=end_time - start_time)
|
| 116 |
|
| 117 |
-
st_is_valid = all(
|
| 118 |
-
st_analysis =
|
| 119 |
-
{"scanner": k, "is valid": results_valid[k], "risk score": results_score[k]}
|
| 120 |
-
for k in results_valid
|
| 121 |
-
]
|
| 122 |
|
| 123 |
except Exception as e:
|
| 124 |
logger.error(e)
|
|
@@ -127,9 +138,7 @@ except Exception as e:
|
|
| 127 |
|
| 128 |
# After:
|
| 129 |
if st_is_valid is not None:
|
| 130 |
-
st.subheader(
|
| 131 |
-
f"Results - {'valid' if st_is_valid else 'invalid'} ({round(st_time_delta.total_seconds())} seconds)"
|
| 132 |
-
)
|
| 133 |
|
| 134 |
col1, col2 = st.columns(2)
|
| 135 |
|
|
|
|
| 1 |
import logging
|
|
|
|
| 2 |
import traceback
|
|
|
|
| 3 |
|
| 4 |
import pandas as pd
|
| 5 |
import streamlit as st
|
| 6 |
+
from llm_guard.vault import Vault
|
| 7 |
+
from streamlit.components.v1 import html
|
| 8 |
+
|
| 9 |
from output import init_settings as init_output_settings
|
| 10 |
from output import scan as scan_output
|
| 11 |
from prompt import init_settings as init_prompt_settings
|
| 12 |
from prompt import scan as scan_prompt
|
| 13 |
|
| 14 |
+
|
| 15 |
+
def add_google_analytics(ga4_id):
|
| 16 |
+
"""
|
| 17 |
+
Add Google Analytics 4 to a Streamlit app
|
| 18 |
+
"""
|
| 19 |
+
ga_code = f"""
|
| 20 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id={ga4_id}"></script>
|
| 21 |
+
<script>
|
| 22 |
+
window.dataLayer = window.dataLayer || [];
|
| 23 |
+
function gtag(){{dataLayer.push(arguments);}}
|
| 24 |
+
gtag('js', new Date());
|
| 25 |
+
gtag('config', '{ga4_id}');
|
| 26 |
+
</script>
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
html(ga_code)
|
| 30 |
+
|
| 31 |
|
| 32 |
PROMPT = "prompt"
|
| 33 |
OUTPUT = "output"
|
|
|
|
| 65 |
elif scanner_type == OUTPUT:
|
| 66 |
enabled_scanners, settings = init_output_settings()
|
| 67 |
|
| 68 |
+
add_google_analytics("G-0HBVNHEZBW")
|
| 69 |
+
|
| 70 |
# Main pannel
|
| 71 |
with st.expander("About", expanded=False):
|
| 72 |
st.info(
|
|
|
|
| 112 |
st_result_text = None
|
| 113 |
st_analysis = None
|
| 114 |
st_is_valid = None
|
|
|
|
| 115 |
|
| 116 |
try:
|
| 117 |
with st.form("text_form", clear_on_submit=False):
|
| 118 |
submitted = st.form_submit_button("Process")
|
| 119 |
if submitted:
|
| 120 |
+
results = {}
|
|
|
|
| 121 |
|
|
|
|
| 122 |
if scanner_type == PROMPT:
|
| 123 |
+
st_result_text, results = scan_prompt(
|
| 124 |
vault, enabled_scanners, settings, st_prompt_text, st_fail_fast
|
| 125 |
)
|
| 126 |
elif scanner_type == OUTPUT:
|
| 127 |
+
st_result_text, results = scan_output(
|
| 128 |
vault, enabled_scanners, settings, st_prompt_text, st_output_text, st_fail_fast
|
| 129 |
)
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
st_is_valid = all(item["is_valid"] for item in results)
|
| 132 |
+
st_analysis = results
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
except Exception as e:
|
| 135 |
logger.error(e)
|
|
|
|
| 138 |
|
| 139 |
# After:
|
| 140 |
if st_is_valid is not None:
|
| 141 |
+
st.subheader(f"Results - {'valid' if st_is_valid else 'invalid'}")
|
|
|
|
|
|
|
| 142 |
|
| 143 |
col1, col2 = st.columns(2)
|
| 144 |
|
output.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import logging
|
|
|
|
|
|
|
| 2 |
from typing import Dict, List
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
-
from streamlit_tags import st_tags
|
| 6 |
-
|
| 7 |
from llm_guard.input_scanners.anonymize import default_entity_types
|
| 8 |
from llm_guard.output_scanners import (
|
| 9 |
JSON,
|
|
@@ -12,11 +12,11 @@ from llm_guard.output_scanners import (
|
|
| 12 |
Bias,
|
| 13 |
Code,
|
| 14 |
Deanonymize,
|
|
|
|
| 15 |
Language,
|
| 16 |
LanguageSame,
|
| 17 |
MaliciousURLs,
|
| 18 |
NoRefusal,
|
| 19 |
-
Refutation,
|
| 20 |
Regex,
|
| 21 |
Relevance,
|
| 22 |
Sensitive,
|
|
@@ -25,6 +25,7 @@ from llm_guard.output_scanners.relevance import all_models as relevance_models
|
|
| 25 |
from llm_guard.output_scanners.sentiment import Sentiment
|
| 26 |
from llm_guard.output_scanners.toxicity import Toxicity
|
| 27 |
from llm_guard.vault import Vault
|
|
|
|
| 28 |
|
| 29 |
logger = logging.getLogger("llm-guard-playground")
|
| 30 |
|
|
@@ -41,7 +42,7 @@ def init_settings() -> (List, Dict):
|
|
| 41 |
"LanguageSame",
|
| 42 |
"MaliciousURLs",
|
| 43 |
"NoRefusal",
|
| 44 |
-
"
|
| 45 |
"Regex",
|
| 46 |
"Relevance",
|
| 47 |
"Sensitive",
|
|
@@ -163,7 +164,12 @@ def init_settings() -> (List, Dict):
|
|
| 163 |
help="The minimum number of JSON elements that should be present",
|
| 164 |
)
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
if "Language" in st_enabled_scanners:
|
| 169 |
st_lan_expander = st.sidebar.expander(
|
|
@@ -274,23 +280,23 @@ def init_settings() -> (List, Dict):
|
|
| 274 |
|
| 275 |
settings["NoRefusal"] = {"threshold": st_no_ref_threshold}
|
| 276 |
|
| 277 |
-
if "
|
| 278 |
-
|
| 279 |
-
"
|
| 280 |
expanded=False,
|
| 281 |
)
|
| 282 |
|
| 283 |
-
with
|
| 284 |
-
|
| 285 |
-
label="
|
| 286 |
value=0.5,
|
| 287 |
min_value=0.0,
|
| 288 |
max_value=1.0,
|
| 289 |
step=0.05,
|
| 290 |
-
key="
|
| 291 |
)
|
| 292 |
|
| 293 |
-
settings["
|
| 294 |
|
| 295 |
if "Regex" in st_enabled_scanners:
|
| 296 |
st_regex_expander = st.sidebar.expander(
|
|
@@ -359,7 +365,7 @@ def init_settings() -> (List, Dict):
|
|
| 359 |
key="sensitive_entity_types",
|
| 360 |
)
|
| 361 |
st.caption(
|
| 362 |
-
"Check all supported entities: https://
|
| 363 |
)
|
| 364 |
st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
|
| 365 |
st_sens_threshold = st.slider(
|
|
@@ -434,13 +440,13 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
|
|
| 434 |
return BanTopics(topics=settings["topics"], threshold=settings["threshold"])
|
| 435 |
|
| 436 |
if scanner_name == "Bias":
|
| 437 |
-
return Bias(threshold=settings["threshold"])
|
| 438 |
|
| 439 |
if scanner_name == "Deanonymize":
|
| 440 |
return Deanonymize(vault=vault)
|
| 441 |
|
| 442 |
if scanner_name == "JSON":
|
| 443 |
-
return JSON(required_elements=settings["required_elements"])
|
| 444 |
|
| 445 |
if scanner_name == "Language":
|
| 446 |
return Language(valid_languages=settings["valid_languages"])
|
|
@@ -458,16 +464,16 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
|
|
| 458 |
elif mode == "denied":
|
| 459 |
denied_languages = settings["languages"]
|
| 460 |
|
| 461 |
-
return Code(allowed=allowed_languages, denied=denied_languages)
|
| 462 |
|
| 463 |
if scanner_name == "MaliciousURLs":
|
| 464 |
-
return MaliciousURLs(threshold=settings["threshold"])
|
| 465 |
|
| 466 |
if scanner_name == "NoRefusal":
|
| 467 |
return NoRefusal(threshold=settings["threshold"])
|
| 468 |
|
| 469 |
-
if scanner_name == "
|
| 470 |
-
return
|
| 471 |
|
| 472 |
if scanner_name == "Regex":
|
| 473 |
match_type = settings["type"]
|
|
@@ -491,13 +497,14 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
|
|
| 491 |
entity_types=settings["entity_types"],
|
| 492 |
redact=settings["redact"],
|
| 493 |
threshold=settings["threshold"],
|
|
|
|
| 494 |
)
|
| 495 |
|
| 496 |
if scanner_name == "Sentiment":
|
| 497 |
return Sentiment(threshold=settings["threshold"])
|
| 498 |
|
| 499 |
if scanner_name == "Toxicity":
|
| 500 |
-
return Toxicity(threshold=settings["threshold"])
|
| 501 |
|
| 502 |
raise ValueError("Unknown scanner name")
|
| 503 |
|
|
@@ -509,10 +516,9 @@ def scan(
|
|
| 509 |
prompt: str,
|
| 510 |
text: str,
|
| 511 |
fail_fast: bool = False,
|
| 512 |
-
) -> (str,
|
| 513 |
sanitized_output = text
|
| 514 |
-
|
| 515 |
-
results_score = {}
|
| 516 |
|
| 517 |
status_text = "Scanning prompt..."
|
| 518 |
if fail_fast:
|
|
@@ -524,13 +530,23 @@ def scan(
|
|
| 524 |
scanner = get_scanner(
|
| 525 |
scanner_name, vault, settings[scanner_name] if scanner_name in settings else {}
|
| 526 |
)
|
|
|
|
|
|
|
| 527 |
sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
|
| 528 |
-
|
| 529 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
if fail_fast and not is_valid:
|
| 532 |
break
|
| 533 |
|
| 534 |
status.update(label="Scanning complete", state="complete", expanded=False)
|
| 535 |
|
| 536 |
-
return sanitized_output,
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import time
|
| 3 |
+
from datetime import timedelta
|
| 4 |
from typing import Dict, List
|
| 5 |
|
| 6 |
import streamlit as st
|
|
|
|
|
|
|
| 7 |
from llm_guard.input_scanners.anonymize import default_entity_types
|
| 8 |
from llm_guard.output_scanners import (
|
| 9 |
JSON,
|
|
|
|
| 12 |
Bias,
|
| 13 |
Code,
|
| 14 |
Deanonymize,
|
| 15 |
+
FactualConsistency,
|
| 16 |
Language,
|
| 17 |
LanguageSame,
|
| 18 |
MaliciousURLs,
|
| 19 |
NoRefusal,
|
|
|
|
| 20 |
Regex,
|
| 21 |
Relevance,
|
| 22 |
Sensitive,
|
|
|
|
| 25 |
from llm_guard.output_scanners.sentiment import Sentiment
|
| 26 |
from llm_guard.output_scanners.toxicity import Toxicity
|
| 27 |
from llm_guard.vault import Vault
|
| 28 |
+
from streamlit_tags import st_tags
|
| 29 |
|
| 30 |
logger = logging.getLogger("llm-guard-playground")
|
| 31 |
|
|
|
|
| 42 |
"LanguageSame",
|
| 43 |
"MaliciousURLs",
|
| 44 |
"NoRefusal",
|
| 45 |
+
"FactualConsistency",
|
| 46 |
"Regex",
|
| 47 |
"Relevance",
|
| 48 |
"Sensitive",
|
|
|
|
| 164 |
help="The minimum number of JSON elements that should be present",
|
| 165 |
)
|
| 166 |
|
| 167 |
+
st_json_repair = st.checkbox("Repair", value=False, help="Attempt to repair the JSON")
|
| 168 |
+
|
| 169 |
+
settings["JSON"] = {
|
| 170 |
+
"required_elements": st_json_required_elements,
|
| 171 |
+
"repair": st_json_repair,
|
| 172 |
+
}
|
| 173 |
|
| 174 |
if "Language" in st_enabled_scanners:
|
| 175 |
st_lan_expander = st.sidebar.expander(
|
|
|
|
| 280 |
|
| 281 |
settings["NoRefusal"] = {"threshold": st_no_ref_threshold}
|
| 282 |
|
| 283 |
+
if "FactualConsistency" in st_enabled_scanners:
|
| 284 |
+
st_fc_expander = st.sidebar.expander(
|
| 285 |
+
"FactualConsistency",
|
| 286 |
expanded=False,
|
| 287 |
)
|
| 288 |
|
| 289 |
+
with st_fc_expander:
|
| 290 |
+
st_fc_minimum_score = st.slider(
|
| 291 |
+
label="Minimum score",
|
| 292 |
value=0.5,
|
| 293 |
min_value=0.0,
|
| 294 |
max_value=1.0,
|
| 295 |
step=0.05,
|
| 296 |
+
key="fc_threshold",
|
| 297 |
)
|
| 298 |
|
| 299 |
+
settings["FactualConsistency"] = {"minimum_score": st_fc_minimum_score}
|
| 300 |
|
| 301 |
if "Regex" in st_enabled_scanners:
|
| 302 |
st_regex_expander = st.sidebar.expander(
|
|
|
|
| 365 |
key="sensitive_entity_types",
|
| 366 |
)
|
| 367 |
st.caption(
|
| 368 |
+
"Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
|
| 369 |
)
|
| 370 |
st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
|
| 371 |
st_sens_threshold = st.slider(
|
|
|
|
| 440 |
return BanTopics(topics=settings["topics"], threshold=settings["threshold"])
|
| 441 |
|
| 442 |
if scanner_name == "Bias":
|
| 443 |
+
return Bias(threshold=settings["threshold"], use_onnx=True)
|
| 444 |
|
| 445 |
if scanner_name == "Deanonymize":
|
| 446 |
return Deanonymize(vault=vault)
|
| 447 |
|
| 448 |
if scanner_name == "JSON":
|
| 449 |
+
return JSON(required_elements=settings["required_elements"], repair=settings["repair"])
|
| 450 |
|
| 451 |
if scanner_name == "Language":
|
| 452 |
return Language(valid_languages=settings["valid_languages"])
|
|
|
|
| 464 |
elif mode == "denied":
|
| 465 |
denied_languages = settings["languages"]
|
| 466 |
|
| 467 |
+
return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)
|
| 468 |
|
| 469 |
if scanner_name == "MaliciousURLs":
|
| 470 |
+
return MaliciousURLs(threshold=settings["threshold"], use_onnx=True)
|
| 471 |
|
| 472 |
if scanner_name == "NoRefusal":
|
| 473 |
return NoRefusal(threshold=settings["threshold"])
|
| 474 |
|
| 475 |
+
if scanner_name == "FactualConsistency":
|
| 476 |
+
return FactualConsistency(minimum_score=settings["minimum_score"])
|
| 477 |
|
| 478 |
if scanner_name == "Regex":
|
| 479 |
match_type = settings["type"]
|
|
|
|
| 497 |
entity_types=settings["entity_types"],
|
| 498 |
redact=settings["redact"],
|
| 499 |
threshold=settings["threshold"],
|
| 500 |
+
use_onnx=True,
|
| 501 |
)
|
| 502 |
|
| 503 |
if scanner_name == "Sentiment":
|
| 504 |
return Sentiment(threshold=settings["threshold"])
|
| 505 |
|
| 506 |
if scanner_name == "Toxicity":
|
| 507 |
+
return Toxicity(threshold=settings["threshold"], use_onnx=True)
|
| 508 |
|
| 509 |
raise ValueError("Unknown scanner name")
|
| 510 |
|
|
|
|
| 516 |
prompt: str,
|
| 517 |
text: str,
|
| 518 |
fail_fast: bool = False,
|
| 519 |
+
) -> (str, List[Dict[str, any]]):
|
| 520 |
sanitized_output = text
|
| 521 |
+
results = []
|
|
|
|
| 522 |
|
| 523 |
status_text = "Scanning prompt..."
|
| 524 |
if fail_fast:
|
|
|
|
| 530 |
scanner = get_scanner(
|
| 531 |
scanner_name, vault, settings[scanner_name] if scanner_name in settings else {}
|
| 532 |
)
|
| 533 |
+
|
| 534 |
+
start_time = time.monotonic()
|
| 535 |
sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
|
| 536 |
+
end_time = time.monotonic()
|
| 537 |
+
|
| 538 |
+
results.append(
|
| 539 |
+
{
|
| 540 |
+
"scanner": scanner_name,
|
| 541 |
+
"is_valid": is_valid,
|
| 542 |
+
"risk_score": risk_score,
|
| 543 |
+
"took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
|
| 544 |
+
}
|
| 545 |
+
)
|
| 546 |
|
| 547 |
if fail_fast and not is_valid:
|
| 548 |
break
|
| 549 |
|
| 550 |
status.update(label="Scanning complete", state="complete", expanded=False)
|
| 551 |
|
| 552 |
+
return sanitized_output, results
|
prompt.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import logging
|
|
|
|
|
|
|
| 2 |
from typing import Dict, List
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
-
from streamlit_tags import st_tags
|
| 6 |
-
|
| 7 |
from llm_guard.input_scanners import (
|
| 8 |
Anonymize,
|
| 9 |
BanSubstrings,
|
|
@@ -11,7 +11,6 @@ from llm_guard.input_scanners import (
|
|
| 11 |
Code,
|
| 12 |
Language,
|
| 13 |
PromptInjection,
|
| 14 |
-
PromptInjectionV2,
|
| 15 |
Regex,
|
| 16 |
Secrets,
|
| 17 |
Sentiment,
|
|
@@ -19,8 +18,9 @@ from llm_guard.input_scanners import (
|
|
| 19 |
Toxicity,
|
| 20 |
)
|
| 21 |
from llm_guard.input_scanners.anonymize import default_entity_types
|
| 22 |
-
from llm_guard.input_scanners.
|
| 23 |
from llm_guard.vault import Vault
|
|
|
|
| 24 |
|
| 25 |
logger = logging.getLogger("llm-guard-playground")
|
| 26 |
|
|
@@ -33,7 +33,6 @@ def init_settings() -> (List, Dict):
|
|
| 33 |
"Code",
|
| 34 |
"Language",
|
| 35 |
"PromptInjection",
|
| 36 |
-
"PromptInjectionV2",
|
| 37 |
"Regex",
|
| 38 |
"Secrets",
|
| 39 |
"Sentiment",
|
|
@@ -67,7 +66,7 @@ def init_settings() -> (List, Dict):
|
|
| 67 |
key="anon_entity_types",
|
| 68 |
)
|
| 69 |
st.caption(
|
| 70 |
-
"Check all supported entities: https://
|
| 71 |
)
|
| 72 |
st_anon_hidden_names = st_tags(
|
| 73 |
label="Hidden names to be anonymized",
|
|
@@ -101,11 +100,6 @@ def init_settings() -> (List, Dict):
|
|
| 101 |
step=0.1,
|
| 102 |
key="anon_threshold",
|
| 103 |
)
|
| 104 |
-
st_anon_recognizer = st.selectbox(
|
| 105 |
-
"Recognizer",
|
| 106 |
-
[RECOGNIZER_SPACY_EN_PII_DISTILBERT, RECOGNIZER_SPACY_EN_PII_FAST],
|
| 107 |
-
index=1,
|
| 108 |
-
)
|
| 109 |
|
| 110 |
settings["Anonymize"] = {
|
| 111 |
"entity_types": st_anon_entity_types,
|
|
@@ -114,7 +108,6 @@ def init_settings() -> (List, Dict):
|
|
| 114 |
"preamble": st_anon_preamble,
|
| 115 |
"use_faker": st_anon_use_faker,
|
| 116 |
"threshold": st_anon_threshold,
|
| 117 |
-
"recognizer": st_anon_recognizer,
|
| 118 |
}
|
| 119 |
|
| 120 |
if "BanSubstrings" in st_enabled_scanners:
|
|
@@ -286,26 +279,6 @@ def init_settings() -> (List, Dict):
|
|
| 286 |
"threshold": st_pi_threshold,
|
| 287 |
}
|
| 288 |
|
| 289 |
-
if "PromptInjectionV2" in st_enabled_scanners:
|
| 290 |
-
st_piv2_expander = st.sidebar.expander(
|
| 291 |
-
"Prompt Injection V2",
|
| 292 |
-
expanded=False,
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
with st_piv2_expander:
|
| 296 |
-
st_piv2_threshold = st.slider(
|
| 297 |
-
label="Threshold",
|
| 298 |
-
value=0.5,
|
| 299 |
-
min_value=0.0,
|
| 300 |
-
max_value=1.0,
|
| 301 |
-
step=0.05,
|
| 302 |
-
key="prompt_injection_v2_threshold",
|
| 303 |
-
)
|
| 304 |
-
|
| 305 |
-
settings["PromptInjectionV2"] = {
|
| 306 |
-
"threshold": st_piv2_threshold,
|
| 307 |
-
}
|
| 308 |
-
|
| 309 |
if "Regex" in st_enabled_scanners:
|
| 310 |
st_regex_expander = st.sidebar.expander(
|
| 311 |
"Regex",
|
|
@@ -427,7 +400,7 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
|
|
| 427 |
preamble=settings["preamble"],
|
| 428 |
use_faker=settings["use_faker"],
|
| 429 |
threshold=settings["threshold"],
|
| 430 |
-
|
| 431 |
)
|
| 432 |
|
| 433 |
if scanner_name == "BanSubstrings":
|
|
@@ -452,16 +425,13 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
|
|
| 452 |
elif mode == "denied":
|
| 453 |
denied_languages = settings["languages"]
|
| 454 |
|
| 455 |
-
return Code(allowed=allowed_languages, denied=denied_languages)
|
| 456 |
|
| 457 |
if scanner_name == "Language":
|
| 458 |
return Language(valid_languages=settings["valid_languages"])
|
| 459 |
|
| 460 |
if scanner_name == "PromptInjection":
|
| 461 |
-
return PromptInjection(threshold=settings["threshold"])
|
| 462 |
-
|
| 463 |
-
if scanner_name == "PromptInjectionV2":
|
| 464 |
-
return PromptInjectionV2(threshold=settings["threshold"])
|
| 465 |
|
| 466 |
if scanner_name == "Regex":
|
| 467 |
match_type = settings["type"]
|
|
@@ -487,17 +457,16 @@ def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
|
|
| 487 |
return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])
|
| 488 |
|
| 489 |
if scanner_name == "Toxicity":
|
| 490 |
-
return Toxicity(threshold=settings["threshold"])
|
| 491 |
|
| 492 |
raise ValueError("Unknown scanner name")
|
| 493 |
|
| 494 |
|
| 495 |
def scan(
|
| 496 |
vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
|
| 497 |
-
) -> (str,
|
| 498 |
sanitized_prompt = text
|
| 499 |
-
|
| 500 |
-
results_score = {}
|
| 501 |
|
| 502 |
status_text = "Scanning prompt..."
|
| 503 |
if fail_fast:
|
|
@@ -507,12 +476,23 @@ def scan(
|
|
| 507 |
for scanner_name in enabled_scanners:
|
| 508 |
st.write(f"{scanner_name} scanner...")
|
| 509 |
scanner = get_scanner(scanner_name, vault, settings[scanner_name])
|
|
|
|
|
|
|
| 510 |
sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
if fail_fast and not is_valid:
|
| 515 |
break
|
|
|
|
| 516 |
status.update(label="Scanning complete", state="complete", expanded=False)
|
| 517 |
|
| 518 |
-
return sanitized_prompt,
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import time
|
| 3 |
+
from datetime import timedelta
|
| 4 |
from typing import Dict, List
|
| 5 |
|
| 6 |
import streamlit as st
|
|
|
|
|
|
|
| 7 |
from llm_guard.input_scanners import (
|
| 8 |
Anonymize,
|
| 9 |
BanSubstrings,
|
|
|
|
| 11 |
Code,
|
| 12 |
Language,
|
| 13 |
PromptInjection,
|
|
|
|
| 14 |
Regex,
|
| 15 |
Secrets,
|
| 16 |
Sentiment,
|
|
|
|
| 18 |
Toxicity,
|
| 19 |
)
|
| 20 |
from llm_guard.input_scanners.anonymize import default_entity_types
|
| 21 |
+
from llm_guard.input_scanners.prompt_injection import ALL_MODELS as PI_ALL_MODELS
|
| 22 |
from llm_guard.vault import Vault
|
| 23 |
+
from streamlit_tags import st_tags
|
| 24 |
|
| 25 |
logger = logging.getLogger("llm-guard-playground")
|
| 26 |
|
|
|
|
| 33 |
"Code",
|
| 34 |
"Language",
|
| 35 |
"PromptInjection",
|
|
|
|
| 36 |
"Regex",
|
| 37 |
"Secrets",
|
| 38 |
"Sentiment",
|
|
|
|
| 66 |
key="anon_entity_types",
|
| 67 |
)
|
| 68 |
st.caption(
|
| 69 |
+
"Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
|
| 70 |
)
|
| 71 |
st_anon_hidden_names = st_tags(
|
| 72 |
label="Hidden names to be anonymized",
|
|
|
|
| 100 |
step=0.1,
|
| 101 |
key="anon_threshold",
|
| 102 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
settings["Anonymize"] = {
|
| 105 |
"entity_types": st_anon_entity_types,
|
|
|
|
| 108 |
"preamble": st_anon_preamble,
|
| 109 |
"use_faker": st_anon_use_faker,
|
| 110 |
"threshold": st_anon_threshold,
|
|
|
|
| 111 |
}
|
| 112 |
|
| 113 |
if "BanSubstrings" in st_enabled_scanners:
|
|
|
|
| 279 |
"threshold": st_pi_threshold,
|
| 280 |
}
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
if "Regex" in st_enabled_scanners:
|
| 283 |
st_regex_expander = st.sidebar.expander(
|
| 284 |
"Regex",
|
|
|
|
| 400 |
preamble=settings["preamble"],
|
| 401 |
use_faker=settings["use_faker"],
|
| 402 |
threshold=settings["threshold"],
|
| 403 |
+
use_onnx=True,
|
| 404 |
)
|
| 405 |
|
| 406 |
if scanner_name == "BanSubstrings":
|
|
|
|
| 425 |
elif mode == "denied":
|
| 426 |
denied_languages = settings["languages"]
|
| 427 |
|
| 428 |
+
return Code(allowed=allowed_languages, denied=denied_languages, use_onnx=True)
|
| 429 |
|
| 430 |
if scanner_name == "Language":
|
| 431 |
return Language(valid_languages=settings["valid_languages"])
|
| 432 |
|
| 433 |
if scanner_name == "PromptInjection":
|
| 434 |
+
return PromptInjection(threshold=settings["threshold"], models=PI_ALL_MODELS, use_onnx=True)
|
|
|
|
|
|
|
|
|
|
| 435 |
|
| 436 |
if scanner_name == "Regex":
|
| 437 |
match_type = settings["type"]
|
|
|
|
| 457 |
return TokenLimit(limit=settings["limit"], encoding_name=settings["encoding_name"])
|
| 458 |
|
| 459 |
if scanner_name == "Toxicity":
|
| 460 |
+
return Toxicity(threshold=settings["threshold"], use_onnx=True)
|
| 461 |
|
| 462 |
raise ValueError("Unknown scanner name")
|
| 463 |
|
| 464 |
|
| 465 |
def scan(
|
| 466 |
vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
|
| 467 |
+
) -> (str, List[Dict[str, any]]):
|
| 468 |
sanitized_prompt = text
|
| 469 |
+
results = []
|
|
|
|
| 470 |
|
| 471 |
status_text = "Scanning prompt..."
|
| 472 |
if fail_fast:
|
|
|
|
| 476 |
for scanner_name in enabled_scanners:
|
| 477 |
st.write(f"{scanner_name} scanner...")
|
| 478 |
scanner = get_scanner(scanner_name, vault, settings[scanner_name])
|
| 479 |
+
|
| 480 |
+
start_time = time.monotonic()
|
| 481 |
sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
|
| 482 |
+
end_time = time.monotonic()
|
| 483 |
+
|
| 484 |
+
results.append(
|
| 485 |
+
{
|
| 486 |
+
"scanner": scanner_name,
|
| 487 |
+
"is_valid": is_valid,
|
| 488 |
+
"risk_score": risk_score,
|
| 489 |
+
"took_sec": round(timedelta(seconds=end_time - start_time).total_seconds(), 2),
|
| 490 |
+
}
|
| 491 |
+
)
|
| 492 |
|
| 493 |
if fail_fast and not is_valid:
|
| 494 |
break
|
| 495 |
+
|
| 496 |
status.update(label="Scanning complete", state="complete", expanded=False)
|
| 497 |
|
| 498 |
+
return sanitized_prompt, results
|
requirements.txt
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
llm-guard==0.3.
|
| 3 |
-
pandas==2.1.
|
| 4 |
-
streamlit==1.
|
| 5 |
streamlit-tags==1.2.8
|
| 6 |
-
https://huggingface.co/beki/en_spacy_pii_fast/resolve/main/en_spacy_pii_fast-any-py3-none-any.whl
|
|
|
|
| 1 |
+
llm-guard==0.3.1
|
| 2 |
+
llm-guard[onnxruntime]==0.3.1
|
| 3 |
+
pandas==2.1.2
|
| 4 |
+
streamlit==1.28.1
|
| 5 |
streamlit-tags==1.2.8
|
|
|