lyric-buddy / app.py
nschenone's picture
Added model config and dynamic pipeline loading
d95efe1
raw
history blame
No virus
1.67 kB
import gradio as gr
from transformers import set_seed
from src.utils import load_pipelines_from_config
# YAML file listing the models to load. The resulting mapping is indexed by
# model name inside `generate` and its keys populate the "Model" dropdown.
MODEL_CONFIG_PATH = "model_config.yaml"
pipelines = load_pipelines_from_config(config_path=MODEL_CONFIG_PATH)
def generate(
    text: str,
    model: str,
    max_length: int = 100,
    temperature: float = 1.5,
    seed: int = 0
):
    """Generate lyrics continuing ``text`` with the selected pipeline.

    Args:
        text: Prompt forwarded to the pipeline as ``text_inputs``.
        model: Key into the module-level ``pipelines`` mapping selecting
            which loaded pipeline to run.
        max_length: Maximum generation length forwarded to the pipeline.
        temperature: Sampling temperature forwarded to the pipeline. It only
            has an effect because sampling is enabled (``do_sample=True``).
        seed: Value passed to ``transformers.set_seed`` before generating,
            so a given seed reproduces the same sampled output.

    Returns:
        The ``"generated_text"`` entry of the first returned sequence.

    Raises:
        KeyError: If ``model`` is not one of the loaded pipeline names.
    """
    # Fixed decoding settings that are not exposed in the UI.
    num_beams: int = 5
    num_return_sequences: int = 1
    no_repeat_ngram_size: int = 3
    early_stopping: bool = True
    skip_special_tokens: bool = True

    # Values coming from gr.Number / gr.Slider may arrive as floats;
    # seeding and generation length need integers.
    set_seed(int(seed))

    generated = pipelines[model](
        text_inputs=text,
        max_length=int(max_length),
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=early_stopping,
        skip_special_tokens=skip_special_tokens,
        # Without do_sample=True, beam search is deterministic. In that mode
        # generate() ignores `temperature`, and the seed has no effect.
        do_sample=True,
        temperature=float(temperature),
    )
    return generated[0]["generated_text"]
# Names of the loaded pipelines, in the mapping's iteration order. The first
# name is the dropdown's initial selection.
_model_names = list(pipelines)

# Input widgets, in the same order as the parameters of `generate`:
# text, model, max_length, temperature, seed.
_input_components = [
    gr.Textbox(
        label="Input Text",
        placeholder="Input text...",
        value="[Verse]",
    ),
    gr.Dropdown(
        label="Model",
        choices=_model_names,
        value=_model_names[0],
    ),
    gr.Slider(
        label="Max Length",
        minimum=50,
        maximum=1000,
        step=10,
        value=100,
    ),
    gr.Slider(
        label="Temperature",
        minimum=0.4,
        maximum=1.9,
        step=0.1,
        value=1.5,
    ),
    gr.Number(
        label="Seed",
        precision=0,
        value=0,
    ),
]

iface = gr.Interface(
    fn=generate,
    inputs=_input_components,
    outputs="text",
)
iface.launch()