# lyric-buddy / app.py
# Author: nschenone
# Last commit: 64dc307 — "Added second model"
# (file size at that commit: 1.06 kB)
import gradio as gr
from transformers import pipeline, set_seed
# Registry of genre name -> Hugging Face text-generation pipeline.
# The keys double as the choices shown in the UI dropdown; each pipeline is
# built (and its weights loaded) once, at import time, in the order listed.
models = {
    genre: pipeline("text-generation", repo_id)
    for genre, repo_id in (
        ("Rap", "nschenone/rap-distil"),
        ("Metal", "nschenone/metal-distil"),
    )
}
def generate(
    text,
    model,
    max_length: int = 100,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    no_repeat_ngram_size: int = 3,
    early_stopping: bool = True,
    temperature: float = 1.5,
):
    """Continue ``text`` using the lyrics pipeline registered under ``model``.

    The generation settings are parameters with the values this app has
    always used, so the Gradio UI (which passes only ``text`` and ``model``)
    behaves as before while other callers can tune them.

    Args:
        text: Prompt passed to the pipeline as ``text_inputs``.
        model: Key into the module-level ``models`` dict (e.g. "Rap" or "Metal").
        max_length: Forwarded to the pipeline / ``generate`` as ``max_length``.
        num_beams: Beam width used for decoding.
        num_return_sequences: How many sequences to request; only the first
            is returned.
        no_repeat_ngram_size: Forbids repeating any n-gram of this size.
        early_stopping: Forwarded to beam search as ``early_stopping``.
        temperature: Forwarded as ``temperature``. NOTE(review): it only
            influences output when sampling is enabled; this call does not set
            ``do_sample``, so that depends on the model's generation config.

    Returns:
        The ``generated_text`` field of the first returned sequence.

    Raises:
        KeyError: If ``model`` is not a key of ``models``.
    """
    # Fixed seed: the same prompt and model always yield the same lyrics.
    set_seed(0)
    generated = models[model](
        text_inputs=text,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=early_stopping,
        temperature=temperature,
        # `skip_special_tokens` is intentionally not passed: the pipeline
        # already decodes its output with special tokens skipped. Passing it
        # here routes it into `model.generate()`, where newer transformers
        # versions reject it as an unused model kwarg.
    )
    return generated[0]["generated_text"]
# Gradio UI: a free-text prompt box plus a dropdown choosing which genre model
# `generate` should use. The dropdown options are the keys of `models`, and the
# first registered model is preselected.
iface = gr.Interface(
    fn=generate,
    inputs=[
        "text",
        # `models.keys()[0]` raised TypeError (dict_keys is not subscriptable in
        # Python 3); take the first key via an iterator and hand Gradio a list.
        gr.Dropdown(choices=list(models), value=next(iter(models)), label="Model"),
    ],
    outputs="text",
)
iface.launch()