File size: 1,230 Bytes
1c58247
 
d9b077a
d95efe1
 
 
1c58247
d9b077a
 
 
b01becb
 
 
d9b077a
04d336a
b01becb
751b8ae
d9b077a
 
 
 
 
 
 
 
04d336a
751b8ae
 
 
64dc307
d9b077a
64dc307
d9b077a
4aef217
d95efe1
 
d9b077a
3474b25
d9b077a
880a360
 
8987832
64dc307
d9b077a
 
1c58247
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr

from src.generate import generate
from src.utils import load_pipelines_from_config

# Load the configured generation pipelines once, at import time. `fn` looks a
# pipeline up by name (`pipelines[model]`) and the Model dropdown lists the
# keys — presumably a dict of model name -> pipeline; the exact schema is
# defined by src.utils.load_pipelines_from_config (TODO confirm).
pipelines = load_pipelines_from_config(config_path="model_config.yaml")


def fn(
    text_inputs: str,
    model: str,
    max_length: int = 100,
    temperature: float = 1.5,
    seed: int = 0,
    censor_profanity: bool = True,
):
    """Generate text from ``text_inputs`` with the pipeline named ``model``.

    This is the Gradio callback; its positional parameters line up, in order,
    with the interface's input components.

    Args:
        text_inputs: Prompt text, forwarded to the pipeline as ``text_inputs``.
        model: Key into the module-level ``pipelines`` mapping.
        max_length: Forwarded to the pipeline as ``max_length``. Coerced to
            ``int`` because ``gr.Slider`` delivers its value as a float.
        temperature: Forwarded to the pipeline as ``temperature``.
        seed: Forwarded to ``generate`` as ``seed``.
        censor_profanity: Forwarded to ``generate`` as ``censor_profanity``.

    Returns:
        Whatever ``generate`` returns; the interface renders it as text.

    Raises:
        KeyError: If ``model`` is not one of the configured pipeline names.
    """
    try:
        pipeline = pipelines[model]
    except KeyError:
        # Re-raise with the valid choices so the failure is self-explanatory.
        raise KeyError(
            f"Unknown model {model!r}; expected one of {list(pipelines.keys())}"
        ) from None

    return generate(
        pipeline=pipeline,
        pipeline_args={
            "text_inputs": text_inputs,
            # A length is a token count; the slider may hand us e.g. 100.0.
            "max_length": int(max_length),
            "temperature": temperature,
        },
        seed=seed,
        censor_profanity=censor_profanity,
    )


# Gradio passes input component values to `fn` positionally, so this list's
# order must match fn's parameters: text_inputs, model, max_length,
# temperature, seed, censor_profanity.
model_names = list(pipelines.keys())

input_components = [
    gr.Textbox(value="[Verse]", placeholder="Input text...", label="Input Text"),
    gr.Dropdown(choices=model_names, value=model_names[0], label="Model"),
    gr.Slider(minimum=50, maximum=1000, value=100, step=10, label="Max Length"),
    # Labelled "Creativity" for users; it feeds fn's `temperature` parameter.
    gr.Slider(minimum=0.9, maximum=1.9, value=1.5, step=0.05, label="Creativity"),
    gr.Number(value=42, precision=0, label="Seed"),
    gr.Checkbox(value=True, label="Censor Profanity"),
]

iface = gr.Interface(fn=fn, inputs=input_components, outputs="text")
iface.launch()