Spaces:
Running
Running
Update safty checker
Browse files- backend/safety_check.py +11 -25
- frontend/webui/hf_demo.py +9 -4
backend/safety_check.py
CHANGED
@@ -1,30 +1,16 @@
|
|
1 |
-
from transformers import
|
2 |
-
from PIL import Image
|
3 |
-
|
4 |
-
|
5 |
-
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
6 |
-
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
7 |
|
8 |
|
9 |
def is_safe_image(
|
10 |
-
|
11 |
-
processor,
|
12 |
image,
|
13 |
):
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
padding=True,
|
24 |
-
)
|
25 |
-
outputs = model(**inputs)
|
26 |
-
logits_per_image = outputs.logits_per_image
|
27 |
-
probs = logits_per_image.softmax(dim=1)
|
28 |
-
safe_prob = dict(zip(categories, probs[0].tolist()))
|
29 |
-
print(safe_prob)
|
30 |
-
return safe_prob["safe"] > safe_prob["nsfw"]
|
|
|
1 |
+
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def is_safe_image(
|
5 |
+
classifier,
|
|
|
6 |
image,
|
7 |
):
|
8 |
+
pred = classifier(image)
|
9 |
+
nsfw_score = 0
|
10 |
+
normal_score = 0
|
11 |
+
for label in pred:
|
12 |
+
if label["label"] == "nsfw":
|
13 |
+
nsfw_score = label["score"]
|
14 |
+
elif label["label"] == "normal":
|
15 |
+
normal_score = label["score"]
|
16 |
+
return normal_score > nsfw_score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
frontend/webui/hf_demo.py
CHANGED
@@ -13,15 +13,17 @@ from PIL import Image
|
|
13 |
from backend.models.lcmdiffusion_setting import DiffusionTask
|
14 |
from backend.safety_check import is_safe_image
|
15 |
from pprint import pprint
|
16 |
-
from transformers import
|
17 |
|
18 |
lcm_text_to_image = LCMTextToImage()
|
19 |
lcm_lora = LCMLora(
|
20 |
base_model_id="Lykon/dreamshaper-7",
|
21 |
lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
|
22 |
)
|
23 |
-
|
24 |
-
|
|
|
|
|
25 |
|
26 |
|
27 |
# https://github.com/gradio-app/gradio/issues/2635#issuecomment-1423531319
|
@@ -69,7 +71,10 @@ def predict(
|
|
69 |
latency = perf_counter() - start
|
70 |
print(f"Latency: {latency:.2f} seconds")
|
71 |
result = images[0]
|
72 |
-
if is_safe_image(
|
|
|
|
|
|
|
73 |
return result # .resize([512, 512], PIL.Image.ANTIALIAS)
|
74 |
else:
|
75 |
print("Unsafe image detected")
|
|
|
13 |
from backend.models.lcmdiffusion_setting import DiffusionTask
|
14 |
from backend.safety_check import is_safe_image
|
15 |
from pprint import pprint
|
16 |
+
from transformers import pipeline
|
17 |
|
18 |
lcm_text_to_image = LCMTextToImage()
|
19 |
lcm_lora = LCMLora(
|
20 |
base_model_id="Lykon/dreamshaper-7",
|
21 |
lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
|
22 |
)
|
23 |
+
classifier = pipeline(
|
24 |
+
"image-classification",
|
25 |
+
model="Falconsai/nsfw_image_detection",
|
26 |
+
)
|
27 |
|
28 |
|
29 |
# https://github.com/gradio-app/gradio/issues/2635#issuecomment-1423531319
|
|
|
71 |
latency = perf_counter() - start
|
72 |
print(f"Latency: {latency:.2f} seconds")
|
73 |
result = images[0]
|
74 |
+
if is_safe_image(
|
75 |
+
classifier,
|
76 |
+
result,
|
77 |
+
):
|
78 |
return result # .resize([512, 512], PIL.Image.ANTIALIAS)
|
79 |
else:
|
80 |
print("Unsafe image detected")
|