nschenone commited on
Commit
3474b25
1 Parent(s): b01becb

Added seed and removed parameters that will not change from UI

Browse files
Files changed (1) hide show
  1. app.py +29 -33
app.py CHANGED
@@ -2,27 +2,36 @@ import gradio as gr
2
  from transformers import pipeline, set_seed
3
 
4
  models = {
5
- "Rap" : pipeline("text-generation", "nschenone/rap-distil"),
6
- "Metal" : pipeline("text-generation", "nschenone/metal-distil")
 
 
 
 
 
 
7
  }
8
 
9
  def generate(
10
  text: str,
11
  model: str,
12
  max_length: int = 100,
13
- num_beams: int = 5,
14
- no_repeat_ngram_size: int = 3,
15
- early_stopping: bool = True,
16
- skip_special_tokens: bool = True,
17
  temperature: float = 1.5,
 
18
  ):
19
- set_seed(0)
 
 
 
 
 
 
20
 
21
  generated = models[model](
22
  text_inputs=text,
23
  max_length=max_length,
 
24
  num_beams=num_beams,
25
- num_return_sequences=1,
26
  no_repeat_ngram_size=no_repeat_ngram_size,
27
  early_stopping=early_stopping,
28
  skip_special_tokens=skip_special_tokens,
@@ -35,7 +44,11 @@ def generate(
35
  iface = gr.Interface(
36
  fn=generate,
37
  inputs=[
38
- "text",
 
 
 
 
39
  gr.Dropdown(
40
  choices=list(models.keys()),
41
  value=list(models.keys())[0],
@@ -49,34 +62,17 @@ iface = gr.Interface(
49
  label="Max Length"
50
  ),
51
  gr.Slider(
52
- minimum=1,
53
- maximum=5,
54
- value=5,
55
- step=1,
56
- label="Num Beams"
57
- ),
58
- gr.Slider(
59
- minimum=1,
60
- maximum=3,
61
- value=3,
62
- step=1,
63
- label="No Repeat N-Gram Size"
64
- ),
65
- gr.Checkbox(
66
- value=True,
67
- label="Early Stopping"
68
- ),
69
- gr.Checkbox(
70
- value=True,
71
- label="Skip Special Tokens"
72
- ),
73
- gr.Slider(
74
- minimum=0,
75
- maximum=2,
76
  value=1.5,
77
  step=0.1,
78
  label="Temperature"
79
  ),
 
 
 
 
 
80
  ],
81
  outputs="text"
82
  )
 
2
  from transformers import pipeline, set_seed
3
 
4
  models = {
5
+ "Rap" : pipeline(
6
+ task="text-generation",
7
+ model="nschenone/rap-distil"
8
+ ),
9
+ "Metal" : pipeline(
10
+ task="text-generation",
11
+ model="nschenone/metal-distil"
12
+ )
13
  }
14
 
15
  def generate(
16
  text: str,
17
  model: str,
18
  max_length: int = 100,
 
 
 
 
19
  temperature: float = 1.5,
20
+ seed: int = 0
21
  ):
22
+ num_beams: int = 5
23
+ num_return_sequences: int = 1
24
+ no_repeat_ngram_size: int = 3
25
+ early_stopping: bool = True
26
+ skip_special_tokens: bool = True
27
+
28
+ set_seed(seed)
29
 
30
  generated = models[model](
31
  text_inputs=text,
32
  max_length=max_length,
33
+ num_return_sequences=num_return_sequences,
34
  num_beams=num_beams,
 
35
  no_repeat_ngram_size=no_repeat_ngram_size,
36
  early_stopping=early_stopping,
37
  skip_special_tokens=skip_special_tokens,
 
44
  iface = gr.Interface(
45
  fn=generate,
46
  inputs=[
47
+ gr.Textbox(
48
+ value="[Verse]",
49
+ placeholder="Input text...",
50
+ label="Input Text"
51
+ ),
52
  gr.Dropdown(
53
  choices=list(models.keys()),
54
  value=list(models.keys())[0],
 
62
  label="Max Length"
63
  ),
64
  gr.Slider(
65
+ minimum=0.4,
66
+ maximum=1.9,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  value=1.5,
68
  step=0.1,
69
  label="Temperature"
70
  ),
71
+ gr.Number(
72
+ value=0,
73
+ precision=0,
74
+ label="Seed"
75
+ ),
76
  ],
77
  outputs="text"
78
  )