lyric-buddy / src /generate.py
nschenone's picture
Changed censor to function instead of class
04d336a
raw
history blame
689 Bytes
import transformers
from transformers import set_seed
from src.profanity_filter import censor
def generate(
    pipeline: transformers.Pipeline,
    pipeline_args: dict,
    seed: int = 0,
    censor_profanity: bool = True,
):
    """Run ``pipeline`` once and return the text of its first result.

    A set of default call arguments is built here and then overridden by
    whatever ``pipeline_args`` supplies, so callers only need to pass the
    settings they want to change.

    Args:
        pipeline: Callable pipeline object; invoked with the merged
            arguments as keyword arguments.
        pipeline_args: Overrides/additions for the default call arguments.
            The default for ``"text_inputs"`` is ``None``, so the prompt is
            presumably expected to arrive through this dict -- verify
            against callers.
        seed: Value handed to ``transformers.set_seed`` before generating.
        censor_profanity: When true, the generated text is passed through
            ``censor`` before being returned.

    Returns:
        The ``"generated_text"`` entry of the first item returned by the
        pipeline, run through ``censor`` if ``censor_profanity`` is set.
        Only the first item is used, even if ``"num_return_sequences"`` is
        overridden to a larger value.
    """
    # Seed before anything else so the seed is applied even if merging the
    # arguments or calling the pipeline fails.
    set_seed(seed)

    # NOTE(review): "temperature" is set but "do_sample" is not; per the
    # transformers generation docs temperature only matters in sampling
    # modes -- confirm whether callers pass do_sample=True.
    call_kwargs = {
        "text_inputs": None,
        "max_length": 100,
        "num_return_sequences": 1,
        "num_beams": 5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
        "skip_special_tokens": True,
        "temperature": 1.5,
    }
    # Caller-supplied values win over the defaults above.
    call_kwargs.update(pipeline_args)

    outputs = pipeline(**call_kwargs)
    text = outputs[0]["generated_text"]

    return censor(text) if censor_profanity else text