rupeshs committed on
Commit df713c0 · 1 Parent(s): 1c32148

Update safety checker

backend/safety_check.py CHANGED
@@ -1,30 +1,16 @@
-from transformers import CLIPProcessor, CLIPModel
-from PIL import Image
-
-
-# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+from transformers import pipeline
 
 
 def is_safe_image(
-    model,
-    processor,
+    classifier,
     image,
 ):
-    # Load image
-    # image = Image.open(
-    #     r"F:\om\2025\fastsdcpumcp\fastsdcpu\results\829a2123-92c8-4957-ad2f-06365a19665a-1.png"
-    # )
-    categories = ["safe", "nsfw"]
-    inputs = processor(
-        text=categories,
-        images=image,
-        return_tensors="pt",
-        padding=True,
-    )
-    outputs = model(**inputs)
-    logits_per_image = outputs.logits_per_image
-    probs = logits_per_image.softmax(dim=1)
-    safe_prob = dict(zip(categories, probs[0].tolist()))
-    print(safe_prob)
-    return safe_prob["safe"] > safe_prob["nsfw"]
+    pred = classifier(image)
+    nsfw_score = 0
+    normal_score = 0
+    for label in pred:
+        if label["label"] == "nsfw":
+            nsfw_score = label["score"]
+        elif label["label"] == "normal":
+            normal_score = label["score"]
+    return normal_score > nsfw_score
frontend/webui/hf_demo.py CHANGED
@@ -13,15 +13,17 @@ from PIL import Image
 from backend.models.lcmdiffusion_setting import DiffusionTask
 from backend.safety_check import is_safe_image
 from pprint import pprint
-from transformers import CLIPProcessor, CLIPModel
+from transformers import pipeline
 
 lcm_text_to_image = LCMTextToImage()
 lcm_lora = LCMLora(
     base_model_id="Lykon/dreamshaper-7",
     lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
 )
-model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+classifier = pipeline(
+    "image-classification",
+    model="Falconsai/nsfw_image_detection",
+)
 
 
 # https://github.com/gradio-app/gradio/issues/2635#issuecomment-1423531319
@@ -69,7 +71,10 @@ def predict(
     latency = perf_counter() - start
     print(f"Latency: {latency:.2f} seconds")
     result = images[0]
-    if is_safe_image(model, processor, result):
+    if is_safe_image(
+        classifier,
+        result,
+    ):
         return result # .resize([512, 512], PIL.Image.ANTIALIAS)
     else:
         print("Unsafe image detected")