Spaces:
Build error
Build error
Changed censor to function instead of class
Browse files
- app.py +2 -2
- src/generate.py +4 -4
- src/profanity_filter.py +33 -34
app.py
CHANGED
@@ -12,7 +12,7 @@ def fn(
|
|
12 |
max_length: int = 100,
|
13 |
temperature: float = 1.5,
|
14 |
seed: int = 0,
|
15 |
-
|
16 |
):
|
17 |
|
18 |
return generate(
|
@@ -23,7 +23,7 @@ def fn(
|
|
23 |
"temperature": temperature,
|
24 |
},
|
25 |
seed=seed,
|
26 |
-
|
27 |
)
|
28 |
|
29 |
|
|
|
12 |
max_length: int = 100,
|
13 |
temperature: float = 1.5,
|
14 |
seed: int = 0,
|
15 |
+
censor_profanity: bool = True,
|
16 |
):
|
17 |
|
18 |
return generate(
|
|
|
23 |
"temperature": temperature,
|
24 |
},
|
25 |
seed=seed,
|
26 |
+
censor_profanity=censor_profanity,
|
27 |
)
|
28 |
|
29 |
|
src/generate.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import transformers
|
2 |
from transformers import set_seed
|
3 |
|
4 |
-
from src.profanity_filter import
|
5 |
|
6 |
|
7 |
def generate(
|
8 |
pipeline: transformers.Pipeline,
|
9 |
pipeline_args: dict,
|
10 |
seed: int = 0,
|
11 |
-
|
12 |
):
|
13 |
|
14 |
set_seed(seed)
|
@@ -27,7 +27,7 @@ def generate(
|
|
27 |
args.update(pipeline_args)
|
28 |
generated = pipeline(**args)[0]["generated_text"]
|
29 |
|
30 |
-
if
|
31 |
-
generated =
|
32 |
|
33 |
return generated
|
|
|
1 |
import transformers
|
2 |
from transformers import set_seed
|
3 |
|
4 |
+
from src.profanity_filter import censor
|
5 |
|
6 |
|
7 |
def generate(
|
8 |
pipeline: transformers.Pipeline,
|
9 |
pipeline_args: dict,
|
10 |
seed: int = 0,
|
11 |
+
censor_profanity: bool = True,
|
12 |
):
|
13 |
|
14 |
set_seed(seed)
|
|
|
27 |
args.update(pipeline_args)
|
28 |
generated = pipeline(**args)[0]["generated_text"]
|
29 |
|
30 |
+
if censor_profanity:
|
31 |
+
generated = censor(generated)
|
32 |
|
33 |
return generated
|
src/profanity_filter.py
CHANGED
@@ -2,48 +2,47 @@ import string
|
|
2 |
|
3 |
import requests
|
4 |
|
|
|
5 |
|
6 |
-
class ProfanityFilter:
|
7 |
-
def __init__(self):
|
8 |
-
BANNED_LIST_URL = "https://raw.githubusercontent.com/snguyenthanh/better_profanity/master/better_profanity/profanity_wordlist.txt"
|
9 |
-
self.banned_list = requests.get(BANNED_LIST_URL).text.split("\n")
|
10 |
|
11 |
-
|
12 |
|
13 |
-
|
14 |
-
sentence_list = text.split("\n")
|
15 |
-
for s, sentence in enumerate(sentence_list):
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
).lower()
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
if char in string.punctuation:
|
38 |
-
censored_word_punc += word[c]
|
39 |
-
else:
|
40 |
-
censored_word_punc += censored_word[c]
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
# Update
|
46 |
-
|
47 |
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
2 |
|
3 |
import requests
|
4 |
|
5 |
+
BANNED_LIST_URL = "https://raw.githubusercontent.com/snguyenthanh/better_profanity/master/better_profanity/profanity_wordlist.txt"
|
6 |
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
def censor(text="", censor_char="*", keep_first_letter=True):
|
9 |
|
10 |
+
banned_list = requests.get(BANNED_LIST_URL).text.split("\n")
|
|
|
|
|
11 |
|
12 |
+
# Split sentences by newline
|
13 |
+
sentence_list = text.split("\n")
|
14 |
+
for s, sentence in enumerate(sentence_list):
|
15 |
|
16 |
+
# Split words in sentence by space
|
17 |
+
word_list = sentence.split()
|
18 |
+
for w, word in enumerate(word_list):
|
|
|
19 |
|
20 |
+
# Process word to match banned list
|
21 |
+
processed_word = word.translate(
|
22 |
+
str.maketrans("", "", string.punctuation)
|
23 |
+
).lower()
|
24 |
|
25 |
+
# Replace if word is profane
|
26 |
+
if processed_word in banned_list:
|
27 |
+
censored_word = censor_char * len(word)
|
28 |
|
29 |
+
# Keep first letter of word for context if desired
|
30 |
+
if keep_first_letter:
|
31 |
+
censored_word = word[0] + censored_word[1:]
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
# Replicate punctuation
|
34 |
+
censored_word_punc = ""
|
35 |
+
for c, char in enumerate(word):
|
36 |
+
if char in string.punctuation:
|
37 |
+
censored_word_punc += word[c]
|
38 |
+
else:
|
39 |
+
censored_word_punc += censored_word[c]
|
40 |
|
41 |
+
# Update word list
|
42 |
+
word_list[w] = censored_word_punc
|
43 |
|
44 |
+
# Update sentence list
|
45 |
+
sentence_list[s] = word_list
|
46 |
+
|
47 |
+
# Join everything back together
|
48 |
+
return "\n".join([" ".join(word_list) for word_list in sentence_list])
|