import transformers
from transformers import set_seed

from src.profanity_filter import ProfanityFilter


def generate(
    pipeline: transformers.Pipeline,
    pipeline_args: dict,
    seed: int = 0,
    censor: bool = True,
):
    """Invoke a transformers pipeline once and return its first generated text.

    A fixed set of generation defaults is used as the baseline; every entry in
    ``pipeline_args`` overrides the default of the same name. The prompt is
    expected to be supplied by the caller under the ``"text_inputs"`` key
    (the built-in default for it is ``None``).

    Args:
        pipeline: The pipeline object to call; it receives the merged settings
            as keyword arguments.
        pipeline_args: Caller-supplied keyword arguments layered on top of the
            defaults. This mapping itself is not modified.
        seed: Value handed to ``transformers.set_seed`` right before the call.
        censor: When true, the text is passed through
            ``ProfanityFilter.censor`` before being returned.

    Returns:
        The ``"generated_text"`` entry of the first element of the pipeline's
        output, censored if ``censor`` is true.
    """
    # Seed first so the generation call below starts from a known RNG state.
    set_seed(seed)

    call_kwargs = dict(
        text_inputs=None,
        max_length=100,
        num_return_sequences=1,
        num_beams=5,
        no_repeat_ngram_size=3,
        early_stopping=True,
        skip_special_tokens=True,
        # NOTE(review): temperature usually only matters when sampling is
        # enabled; with beam search and no ``do_sample`` it may be ignored —
        # confirm against the installed transformers version.
        temperature=1.5,
    )
    # Caller-provided settings win over the defaults above.
    call_kwargs.update(pipeline_args)

    outputs = pipeline(**call_kwargs)
    first_output = outputs[0]
    text = first_output["generated_text"]

    return ProfanityFilter.censor(text) if censor else text