Thouph commited on
Commit
5ba6e49
·
verified ·
1 Parent(s): e4197d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ random.seed(1234)
4
+ import torch
5
+ from transformers import Qwen2ForSequenceClassification, AutoTokenizer
6
+ import gradio as gr
7
+ from datetime import datetime
8
+ torch.set_grad_enabled(False)
9
+
10
+ model = Qwen2ForSequenceClassification.from_pretrained("Thouph/danbooru-to-e621-qwen2.5-0.5b", num_labels = 9086, device_map="cpu")
11
+ model.eval()
12
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
13
+
14
+ with open("tags_9083.json", "r") as file:
15
+ allowed_tags = json.load(file)
16
+
17
+ allowed_tags = sorted(allowed_tags)
18
+ allowed_tags.append("explicit")
19
+ allowed_tags.append("questionable")
20
+ allowed_tags.append("safe")
21
+
22
+ def create_tags(prompt, threshold):
23
+ inputs = tokenizer(
24
+ prompt,
25
+ padding="do_not_pad",
26
+ max_length=512,
27
+ truncation=True,
28
+ return_tensors="pt",
29
+ )
30
+
31
+ output = model(**inputs).logits
32
+ output = torch.nn.functional.sigmoid(output)
33
+ indices = torch.where(output > threshold)
34
+ values = output[indices]
35
+ indices = indices[1]
36
+ values = values.squeeze()
37
+
38
+ temp = []
39
+ tag_score = dict()
40
+ for i in range(indices.size(0)):
41
+ temp.append([allowed_tags[indices[i]], values[i].item()])
42
+ tag_score[allowed_tags[indices[i]]] = values[i].item()
43
+ temp = [t[0] for t in temp]
44
+ text_no_impl = " ".join(temp)
45
+ current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
46
+ print(f"{current_datetime}: finished.")
47
+ return text_no_impl, tag_score
48
+
49
+ demo = gr.Interface(
50
+ create_tags,
51
+ inputs=[
52
+ gr.TextArea(label="Prompt",),
53
+ gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="Threshold")
54
+ ],
55
+ outputs=[
56
+ gr.Textbox(label="Tag String"),
57
+ gr.Label(label="Tag Predictions", num_top_classes=200),
58
+ ],
59
+ allow_flagging="never",
60
+ )
61
+
62
+ demo.launch()