# Author: shrey-14 -- last commit ba5d031 ("Update app.py", verified).
import gradio as gr
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# Hugging Face Hub repository id (or local path) of the fine-tuned seq2seq
# checkpoint; both the weights and the tokenizer are loaded from it.
model_path = "shrey-14/story_title_generation_model"

# Decoding settings read by every title-generation call.
num_beams = 10  # beam-search width
num_return_sequences = 3  # candidate titles per story (transformers requires <= num_beams)

# Load the tokenizer and the TensorFlow model weights at import time.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_path)

# Bundle model and tokenizer into a single text2text-generation pipeline.
generation_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
)
# Sample stories offered as one-click inputs; each key becomes a button label
# and each value is copied into the story text area when clicked.
predefined_stories = {
    "Story 1": (
        "Emily's family pretended to forget her birthday. "
        "Feeling disappointed, she spent the day alone, only to come home and find "
        "a surprise party with all her friends and family waiting."
    ),
    "Story 2": (
        'Alex found an old, crumpled note in the library. '
        'It read, "Meet me under the oak tree at midnight." '
        'Curiosity got the better of him, and he discovered a buried time capsule '
        'filled with letters from the past.'
    ),
    "Story 3": (
        "Mia was rushing to leave for work when she realized her house key was missing. "
        "She retraced her steps and found it under her cat's food bowl. "
        "The little rascal had been playing with it all along."
    ),
}
# Function to generate titles
def generate_title(story):
    """Generate candidate titles for *story* using the seq2seq pipeline.

    Calls the module-level ``generation_pipeline`` with beam search
    (``num_beams`` beams) and asks for ``num_return_sequences`` candidates.

    Args:
        story: Story text from the UI text area. ``None``, empty, or
            whitespace-only input is rejected without calling the model.

    Returns:
        One line per candidate formatted ``"Title N: <title>"`` (N starts at
        1), or ``"Please enter a valid story."`` for empty input.
    """
    if story is None or not story.strip():
        return "Please enter a valid story."

    generated = generation_pipeline(
        story,
        min_length=4,
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        # Let the tokenizer cut stories longer than its model_max_length
        # instead of feeding an over-long sequence to the model.
        truncation=True,
    )

    titles = []
    for item in generated:
        text = item["generated_text"].strip()
        # str.capitalize() would also lower-case the rest of the title
        # (e.g. "Emily" -> "emily"); only raise the first character.
        titles.append(text[:1].upper() + text[1:])

    return "\n".join(
        f"Title {number}: {title}" for number, title in enumerate(titles, start=1)
    )
# Function to update the story in the text area
def set_predefined_story(story):
    """Return *story* unchanged so it can be written into the story text area.

    Note: the predefined-story buttons currently use an inline lambda rather
    than this helper.
    """
    selected_story = story
    return selected_story
# Custom CSS injected into the Gradio page.
# NOTE(review): "#component-5" / "#component-9" look like Gradio's
# auto-assigned element ids, which shift when components are added or
# reordered -- consider giving those components explicit elem_id values.
custom_css = """
body {
    background-color: white !important;
}
#component-5, #component-9, label {
    background-color: #91BAD6;
}
span {
    font-size: 18px !important;
    color: white;
    font-weight: bold;
}
textarea {
    font-size: 16px !important;
    color: black;
}
#generate_button {
    font-size: 18px;
    background-color: #2E5984;
    color: white;
}
/* A class, not an id: several story buttons share this style and DOM ids must
   be unique. !important keeps this rule ahead of the theme's button styles,
   as the higher-specificity id selector used to. */
.story_button {
    background-color: #91BAD6 !important;
    color: white !important;
}
#title {
    font-size: 45px;
    text-align: center !important;
}
#description {
    font-size: 20px;
    text-align: center !important;
}
"""

# Define Gradio interface: a story input and a title output side by side,
# followed by one button per predefined story that fills the input.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("<h1 id='title'>📖 Story Title Generator 📚</h1>")
    gr.Markdown(
        "<p id='description'>Unleash your creativity!✨ Enter a short story, "
        "and watch as captivating titles are crafted for you. 🎉</p>"
    )

    with gr.Row():
        with gr.Column():
            story_input = gr.TextArea(
                label="Enter a Short Story",
                lines=7,
                placeholder="Write your story here...",
                elem_classes="story_input",
            )
            generate_button = gr.Button("Generate Title", elem_id="generate_button")
        with gr.Column():
            output = gr.Textbox(
                label="Generated Titles",
                placeholder="Your titles will appear here...",
                lines=7,
                elem_classes="output",
            )

    gr.Markdown("<h2>Select a Predefined Story:</h2>")
    with gr.Row():
        for name, story in predefined_stories.items():
            with gr.Column():
                # The story is bound as a default argument so each button keeps
                # its own text (avoids Python's late-binding closure pitfall).
                gr.Button(name, elem_classes="story_button").click(
                    fn=lambda s=story: s,
                    inputs=[],
                    outputs=story_input,
                )

    # Link the button click event to the generate_title function
    generate_button.click(fn=generate_title, inputs=story_input, outputs=output)

# Launch Gradio app
demo.launch()