nschenone commited on
Commit
04d336a
1 Parent(s): 044d15c

Changed censor to function instead of class

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. src/generate.py +4 -4
  3. src/profanity_filter.py +33 -34
app.py CHANGED
@@ -12,7 +12,7 @@ def fn(
12
  max_length: int = 100,
13
  temperature: float = 1.5,
14
  seed: int = 0,
15
- censor: bool = True,
16
  ):
17
 
18
  return generate(
@@ -23,7 +23,7 @@ def fn(
23
  "temperature": temperature,
24
  },
25
  seed=seed,
26
- censor=censor,
27
  )
28
 
29
 
 
12
  max_length: int = 100,
13
  temperature: float = 1.5,
14
  seed: int = 0,
15
+ censor_profanity: bool = True,
16
  ):
17
 
18
  return generate(
 
23
  "temperature": temperature,
24
  },
25
  seed=seed,
26
+ censor_profanity=censor_profanity,
27
  )
28
 
29
 
src/generate.py CHANGED
@@ -1,14 +1,14 @@
1
  import transformers
2
  from transformers import set_seed
3
 
4
- from src.profanity_filter import ProfanityFilter
5
 
6
 
7
  def generate(
8
  pipeline: transformers.Pipeline,
9
  pipeline_args: dict,
10
  seed: int = 0,
11
- censor: bool = True,
12
  ):
13
 
14
  set_seed(seed)
@@ -27,7 +27,7 @@ def generate(
27
  args.update(pipeline_args)
28
  generated = pipeline(**args)[0]["generated_text"]
29
 
30
- if censor:
31
- generated = ProfanityFilter.censor(generated)
32
 
33
  return generated
 
1
  import transformers
2
  from transformers import set_seed
3
 
4
+ from src.profanity_filter import censor
5
 
6
 
7
  def generate(
8
  pipeline: transformers.Pipeline,
9
  pipeline_args: dict,
10
  seed: int = 0,
11
+ censor_profanity: bool = True,
12
  ):
13
 
14
  set_seed(seed)
 
27
  args.update(pipeline_args)
28
  generated = pipeline(**args)[0]["generated_text"]
29
 
30
+ if censor_profanity:
31
+ generated = censor(generated)
32
 
33
  return generated
src/profanity_filter.py CHANGED
@@ -2,48 +2,47 @@ import string
2
 
3
  import requests
4
 
 
5
 
6
- class ProfanityFilter:
7
- def __init__(self):
8
- BANNED_LIST_URL = "https://raw.githubusercontent.com/snguyenthanh/better_profanity/master/better_profanity/profanity_wordlist.txt"
9
- self.banned_list = requests.get(BANNED_LIST_URL).text.split("\n")
10
 
11
- def censor(self, text="", censor_char="*", keep_first_letter=True):
12
 
13
- # Split sentences by newline
14
- sentence_list = text.split("\n")
15
- for s, sentence in enumerate(sentence_list):
16
 
17
- # Split words in sentence by space
18
- word_list = sentence.split()
19
- for w, word in enumerate(word_list):
20
 
21
- # Process word to match banned list
22
- processed_word = word.translate(
23
- str.maketrans("", "", string.punctuation)
24
- ).lower()
25
 
26
- # Replace if word is profane
27
- if processed_word in self.banned_list:
28
- censored_word = censor_char * len(word)
 
29
 
30
- # Keep first letter of word for context if desired
31
- if keep_first_letter:
32
- censored_word = word[0] + censored_word[1:]
33
 
34
- # Replcate punctuation
35
- censored_word_punc = ""
36
- for c, char in enumerate(word):
37
- if char in string.punctuation:
38
- censored_word_punc += word[c]
39
- else:
40
- censored_word_punc += censored_word[c]
41
 
42
- # Update word list
43
- word_list[w] = censored_word_punc
 
 
 
 
 
44
 
45
- # Update sentence list
46
- sentence_list[s] = word_list
47
 
48
- # Join everything back together
49
- return "\n".join([" ".join(word_list) for word_list in sentence_list])
 
 
 
 
2
 
3
  import requests
4
 
5
+ BANNED_LIST_URL = "https://raw.githubusercontent.com/snguyenthanh/better_profanity/master/better_profanity/profanity_wordlist.txt"
6
 
 
 
 
 
7
 
8
+ def censor(text="", censor_char="*", keep_first_letter=True):
9
 
10
+ banned_list = requests.get(BANNED_LIST_URL).text.split("\n")
 
 
11
 
12
+ # Split sentences by newline
13
+ sentence_list = text.split("\n")
14
+ for s, sentence in enumerate(sentence_list):
15
 
16
+ # Split words in sentence by space
17
+ word_list = sentence.split()
18
+ for w, word in enumerate(word_list):
 
19
 
20
+ # Process word to match banned list
21
+ processed_word = word.translate(
22
+ str.maketrans("", "", string.punctuation)
23
+ ).lower()
24
 
25
+ # Replace if word is profane
26
+ if processed_word in banned_list:
27
+ censored_word = censor_char * len(word)
28
 
29
+ # Keep first letter of word for context if desired
30
+ if keep_first_letter:
31
+ censored_word = word[0] + censored_word[1:]
 
 
 
 
32
 
33
+ # Replcate punctuation
34
+ censored_word_punc = ""
35
+ for c, char in enumerate(word):
36
+ if char in string.punctuation:
37
+ censored_word_punc += word[c]
38
+ else:
39
+ censored_word_punc += censored_word[c]
40
 
41
+ # Update word list
42
+ word_list[w] = censored_word_punc
43
 
44
+ # Update sentence list
45
+ sentence_list[s] = word_list
46
+
47
+ # Join everything back together
48
+ return "\n".join([" ".join(word_list) for word_list in sentence_list])