nschenone committed on
Commit
d9b077a
1 Parent(s): fff55d3

Added dedicated generation function and profanity filter

Browse files
Files changed (4) hide show
  1. app.py +24 -50
  2. src/generate.py +33 -0
  3. src/profanity_filter.py +49 -0
  4. src/utils.py +5 -4
app.py CHANGED
@@ -1,72 +1,46 @@
1
  import gradio as gr
2
- from transformers import set_seed
3
 
 
4
  from src.utils import load_pipelines_from_config
5
 
6
  pipelines = load_pipelines_from_config(config_path="model_config.yaml")
7
 
8
- def generate(
9
- text: str,
 
10
  model: str,
11
  max_length: int = 100,
12
  temperature: float = 1.5,
13
- seed: int = 0
 
14
  ):
15
- num_beams: int = 5
16
- num_return_sequences: int = 1
17
- no_repeat_ngram_size: int = 3
18
- early_stopping: bool = True
19
- skip_special_tokens: bool = True
20
-
21
- set_seed(seed)
22
 
23
- generated = pipelines[model](
24
- text_inputs=text,
25
- max_length=max_length,
26
- num_return_sequences=num_return_sequences,
27
- num_beams=num_beams,
28
- no_repeat_ngram_size=no_repeat_ngram_size,
29
- early_stopping=early_stopping,
30
- skip_special_tokens=skip_special_tokens,
31
- temperature=temperature
32
  )
33
 
34
- return generated[0]["generated_text"]
35
-
36
 
37
  iface = gr.Interface(
38
- fn=generate,
39
  inputs=[
40
- gr.Textbox(
41
- value="[Verse]",
42
- placeholder="Input text...",
43
- label="Input Text"
44
- ),
45
  gr.Dropdown(
46
  choices=list(pipelines.keys()),
47
  value=list(pipelines.keys())[0],
48
- label="Model"
49
- ),
50
- gr.Slider(
51
- minimum=50,
52
- maximum=1000,
53
- value=100,
54
- step=10,
55
- label="Max Length"
56
- ),
57
- gr.Slider(
58
- minimum=0.4,
59
- maximum=1.9,
60
- value=1.5,
61
- step=0.1,
62
- label="Temperature"
63
- ),
64
- gr.Number(
65
- value=0,
66
- precision=0,
67
- label="Seed"
68
  ),
 
 
 
 
69
  ],
70
- outputs="text"
71
- )
72
  iface.launch()
 
1
  import gradio as gr
 
2
 
3
+ from src.generate import generate
4
  from src.utils import load_pipelines_from_config
5
 
6
  pipelines = load_pipelines_from_config(config_path="model_config.yaml")
7
 
8
+
9
+ def fn(
10
+ text_inputs: str,
11
  model: str,
12
  max_length: int = 100,
13
  temperature: float = 1.5,
14
+ seed: int = 0,
15
+ censor: bool = True,
16
  ):
 
 
 
 
 
 
 
17
 
18
+ return generate(
19
+ pipeline=pipelines[model],
20
+ pipeline_args={
21
+ "text_inputs": text_inputs,
22
+ "max_length": max_length,
23
+ "temperature": temperature,
24
+ },
25
+ seed=seed,
26
+ censor=censor,
27
  )
28
 
 
 
29
 
30
  iface = gr.Interface(
31
+ fn=fn,
32
  inputs=[
33
+ gr.Textbox(value="[Verse]", placeholder="Input text...", label="Input Text"),
 
 
 
 
34
  gr.Dropdown(
35
  choices=list(pipelines.keys()),
36
  value=list(pipelines.keys())[0],
37
+ label="Model",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ),
39
+ gr.Slider(minimum=50, maximum=1000, value=100, step=10, label="Max Length"),
40
+ gr.Slider(minimum=0.4, maximum=1.9, value=1.5, step=0.1, label="Temperature"),
41
+ gr.Number(value=0, precision=0, label="Seed"),
42
+ gr.CheckBox(value=True, label="Censor Profanity"),
43
  ],
44
+ outputs="text",
45
+ )
46
  iface.launch()
src/generate.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from transformers import set_seed
3
+
4
+ from src.profanity_filter import ProfanityFilter
5
+
6
+
7
+ def generate(
8
+ pipeline: transformers.Pipeline,
9
+ pipeline_args: dict,
10
+ seed: int = 0,
11
+ censor: bool = True,
12
+ ):
13
+
14
+ set_seed(seed)
15
+
16
+ default_pipline_args = {
17
+ "text_inputs": None,
18
+ "max_length": 100,
19
+ "num_return_sequences": 1,
20
+ "num_beams": 5,
21
+ "no_repeat_ngram_size": 3,
22
+ "early_stopping": True,
23
+ "skip_special_tokens": True,
24
+ "temperature": 1.5,
25
+ }
26
+
27
+ args = default_pipline_args.update(pipeline_args)
28
+ generated = pipeline(**args)[0]["generated_text"]
29
+
30
+ if censor:
31
+ generated = ProfanityFilter.censor(generated)
32
+
33
+ return generated
src/profanity_filter.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+
3
+ import requests
4
+
5
+
6
+ class ProfanityFilter:
7
+ def __init__(self):
8
+ BANNED_LIST_URL = "https://raw.githubusercontent.com/snguyenthanh/better_profanity/master/better_profanity/profanity_wordlist.txt"
9
+ self.banned_list = requests.get(BANNED_LIST_URL).text.split("\n")
10
+
11
+ def censor(self, text="", censor_char="*", keep_first_letter=True):
12
+
13
+ # Split sentences by newline
14
+ sentence_list = text.split("\n")
15
+ for s, sentence in enumerate(sentence_list):
16
+
17
+ # Split words in sentence by space
18
+ word_list = sentence.split()
19
+ for w, word in enumerate(word_list):
20
+
21
+ # Process word to match banned list
22
+ processed_word = word.translate(
23
+ str.maketrans("", "", string.punctuation)
24
+ ).lower()
25
+
26
+ # Replace if word is profane
27
+ if processed_word in self.banned_list:
28
+ censored_word = censor_char * len(word)
29
+
30
+ # Keep first letter of word for context if desired
31
+ if keep_first_letter:
32
+ censored_word = word[0] + censored_word[1:]
33
+
34
+ # Replcate punctuation
35
+ censored_word_punc = ""
36
+ for c, char in enumerate(word):
37
+ if char in string.punctuation:
38
+ censored_word_punc += word[c]
39
+ else:
40
+ censored_word_punc += censored_word[c]
41
+
42
+ # Update word list
43
+ word_list[w] = censored_word_punc
44
+
45
+ # Update sentence list
46
+ sentence_list[s] = word_list
47
+
48
+ # Join everything back together
49
+ return "\n".join([" ".join(word_list) for word_list in sentence_list])
src/utils.py CHANGED
@@ -1,16 +1,17 @@
1
  import yaml
2
  from transformers import pipeline
3
 
 
4
  def load_pipelines_from_config(config_path: str):
5
  with open(config_path, "r") as f:
6
  model_config = yaml.safe_load(f.read())
7
-
8
  models = {}
9
  for model, config in model_config.items():
10
  models[model] = pipeline(
11
  task=config["task"],
12
  model=config["model_name"],
13
- revision=config["hf_commit_hash"]
14
  )
15
-
16
- return models
 
1
  import yaml
2
  from transformers import pipeline
3
 
4
+
5
  def load_pipelines_from_config(config_path: str):
6
  with open(config_path, "r") as f:
7
  model_config = yaml.safe_load(f.read())
8
+
9
  models = {}
10
  for model, config in model_config.items():
11
  models[model] = pipeline(
12
  task=config["task"],
13
  model=config["model_name"],
14
+ revision=config["hf_commit_hash"],
15
  )
16
+
17
+ return models