lyric-buddy / src /generate.py
nschenone's picture
Updated generation default args
044d15c
raw
history blame
694 Bytes
import transformers
from transformers import set_seed
from src.profanity_filter import ProfanityFilter
def generate(
    pipeline: transformers.Pipeline,
    pipeline_args: dict,
    seed: int = 0,
    censor: bool = True,
):
    """Run *pipeline* once and return the text of its first result.

    A set of default keyword arguments is built, overridden key-by-key by
    ``pipeline_args``, and passed to ``pipeline``. The ``"generated_text"``
    entry of the first returned item is the result.

    Args:
        pipeline: Callable pipeline object invoked with the merged keyword
            arguments.
        pipeline_args: Keyword arguments that override or extend the
            defaults below. NOTE(review): the default ``"text_inputs"`` is
            ``None``, so callers presumably supply the prompt here - confirm.
        seed: Value passed to ``set_seed`` before the pipeline is called.
        censor: When true, the text is passed through
            ``ProfanityFilter.censor`` before being returned.

    Returns:
        The ``"generated_text"`` value of the first pipeline result, censored
        if ``censor`` is true.
    """
    # Seed before invoking the pipeline so the seed applies to this call.
    set_seed(seed)

    # Defaults first; caller-supplied entries win on key collisions.
    call_kwargs = {
        "text_inputs": None,
        "max_length": 100,
        "num_return_sequences": 1,
        "num_beams": 5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
        "skip_special_tokens": True,
        "temperature": 1.5,
    }
    call_kwargs.update(pipeline_args)

    # Only the first returned item is used, even if more were requested.
    results = pipeline(**call_kwargs)
    text = results[0]["generated_text"]

    if not censor:
        return text
    return ProfanityFilter.censor(text)