File size: 1,666 Bytes
1c58247
d95efe1
1c58247
d95efe1
 
 
1c58247
b01becb
 
 
 
 
3474b25
b01becb
3474b25
 
 
 
 
 
 
751b8ae
d95efe1
751b8ae
 
3474b25
751b8ae
 
 
 
 
 
 
d727040
751b8ae
 
64dc307
 
 
3474b25
 
 
 
 
4aef217
d95efe1
 
4aef217
b01becb
 
 
 
 
 
 
 
 
3474b25
 
b01becb
 
 
 
3474b25
 
 
 
 
64dc307
 
 
1c58247
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from transformers import set_seed

from src.utils import load_pipelines_from_config

# Load every configured pipeline once, at import time, from the YAML config.
# The result is used below as a mapping: its keys populate the "Model"
# dropdown and `generate` looks up the selected entry and calls it.
# NOTE(review): the exact return type of `load_pipelines_from_config` is not
# visible here — presumably a dict of model name -> transformers pipeline.
pipelines = load_pipelines_from_config(config_path="model_config.yaml")

def generate(
    text: str,
    model: str,
    max_length: int = 100,
    temperature: float = 1.5,
    seed: int = 0
):
    """Generate text from a prompt with the selected pipeline.

    Args:
        text: Prompt, forwarded to the pipeline as ``text_inputs``.
        model: Key of the pipeline to use in the module-level ``pipelines``
            mapping.
        max_length: Forwarded to the pipeline as ``max_length``.
        temperature: Sampling temperature, forwarded as ``temperature``.
        seed: Value passed to ``transformers.set_seed`` before generating,
            so the same inputs and seed reproduce the same sample.

    Returns:
        The ``"generated_text"`` entry of the first returned sequence.

    Raises:
        KeyError: If ``model`` is not in ``pipelines`` (assuming
            ``pipelines`` is a plain dict — confirm against the loader).
    """
    # Seed the RNGs right before generation so sampling is reproducible.
    set_seed(seed)

    generated = pipelines[model](
        text_inputs=text,
        max_length=max_length,
        # Fixed decoding settings (not exposed in the UI).
        num_return_sequences=1,
        num_beams=5,
        no_repeat_ngram_size=3,
        early_stopping=True,
        skip_special_tokens=True,
        # transformers only applies `temperature` in sampling modes. Without
        # `do_sample=True` decoding is deterministic beam search, and both
        # the temperature and the seed are ignored unless the model's own
        # generation config happens to enable sampling.
        do_sample=True,
        temperature=temperature,
    )

    # Exactly one sequence is requested, so take the first result.
    return generated[0]["generated_text"]


# Build the demo UI. The input widgets are listed in the same order as the
# parameters of `generate` (text, model, max_length, temperature, seed), and
# their initial values mirror that function's defaults.
_model_names = list(pipelines.keys())

_prompt_box = gr.Textbox(
    value="[Verse]",
    placeholder="Input text...",
    label="Input Text",
)
# The first configured model is preselected.
_model_picker = gr.Dropdown(
    choices=_model_names,
    value=_model_names[0],
    label="Model",
)
_length_slider = gr.Slider(
    minimum=50,
    maximum=1000,
    value=100,
    step=10,
    label="Max Length",
)
_temperature_slider = gr.Slider(
    minimum=0.4,
    maximum=1.9,
    value=1.5,
    step=0.1,
    label="Temperature",
)
# precision=0 restricts the seed field to whole numbers.
_seed_field = gr.Number(
    value=0,
    precision=0,
    label="Seed",
)

iface = gr.Interface(
    fn=generate,
    inputs=[
        _prompt_box,
        _model_picker,
        _length_slider,
        _temperature_slider,
        _seed_field,
    ],
    outputs="text",
)
iface.launch()