Textwizai / text_generation.py
Erfan11's picture
Create text_generation.py
29de403 verified
raw
history blame
462 Bytes
import functools

from transformers import GPT2LMHeadModel, GPT2Tokenizer
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load and memoize the GPT-2 tokenizer/model pair for *model_name*.

    Loading weights is by far the most expensive step, so it happens once
    per distinct model name instead of on every generation call.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    return tokenizer, model


def generate_text(prompt, max_length=50, model_name="gpt2"):
    """Generate a continuation of *prompt* with a GPT-2 language model.

    Args:
        prompt: Text to condition generation on.
        max_length: Maximum total length, in tokens, of the generated
            sequence *including* the prompt tokens (this is how
            ``generate(max_length=...)`` counts length).
        model_name: Model id or local path loadable by the GPT-2 classes.
            Defaults to ``"gpt2"``.

    Returns:
        The decoded sequence (prompt followed by the continuation) with
        special tokens stripped.
    """
    tokenizer, model = _load_model(model_name)
    # Calling the tokenizer (instead of ``encode``) also returns the
    # attention mask, so generate() does not have to guess it.
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        num_return_sequences=1,
        # GPT-2 defines no pad token; pass EOS explicitly rather than letting
        # generate() warn and fall back to it on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)