import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
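
# To try this app locally (assuming the script is saved as app.py):
#   pip install streamlit torch transformers
#   streamlit run app.py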


@st.cache_resource
def load_model():
    # Load the tokenizer and model once; st.cache_resource keeps the returned
    # objects in memory so they are not re-downloaded or re-initialized on
    # every Streamlit rerun. (st.cache with allow_output_mutation is the
    # deprecated pre-1.18 equivalent.)
    tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
    model = AutoModelForCausalLM.from_pretrained("gpt2-large")
    return tokenizer, model


tokenizer, model = load_model()

st.title("Blog Post Generator")
st.write("Generate a blog post for a given topic using GPT-2 Large.")

topic = st.text_input("Enter the topic for your blog post:")

if st.button("Generate Blog Post"):
    if topic:
        # Instruction-style prompt that frames the task for the model.
        input_text = (
            f"Write a detailed blog post about {topic}. The post should "
            "cover various aspects of the topic and provide valuable "
            "information to the readers. Start with an introduction and "
            "follow with detailed paragraphs."
        )

        # Tokenize the prompt; this returns input_ids plus an attention_mask,
        # which generate() expects alongside the ids.
        inputs = tokenizer(input_text, return_tensors="pt")

        # Sample a continuation. do_sample=True is required for temperature
        # and top_p to take effect; without it they are silently ignored.
        # pad_token_id is set explicitly because GPT-2 has no pad token.
        outputs = model.generate(
            **inputs,
            max_length=500,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode the generated ids; note the result includes the prompt
        # followed by the model's continuation.
        blog_post = tokenizer.decode(outputs[0], skip_special_tokens=True)

        st.write("### Generated Blog Post:")
        st.write(blog_post)
    else:
        st.write("Please enter a topic to generate a blog post.")