# Spaces:
# Sleeping
# Sleeping
import gradio as gr
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, pipeline

# Model identifier (Hugging Face Hub repo id or local path). The tokenizer
# and the model are both loaded from it so the two always match.
model_path = "shrey-14/story_title_generation_model"

# TensorFlow seq2seq model and its tokenizer, loaded from model_path.
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Text-to-text pipeline: accepts a raw story string, tokenizes it, runs
# generation with the model above and decodes the result back to text.
generation_pipeline = pipeline(
    task="text2text-generation",
    model=model,
    tokenizer=tokenizer,
)
# Decoding settings for title generation: how many candidate titles are
# returned per story, and how many beams the beam search keeps. The number
# of returned sequences must not exceed the number of beams.
num_return_sequences, num_beams = 3, 10
# Sample stories offered as one-click inputs. Keys ("Story 1", "Story 2", ...)
# are used as button labels in the UI; values are the full story text.
predefined_stories = {
    f"Story {index}": text
    for index, text in enumerate(
        (
            "Emily's family pretended to forget her birthday. Feeling disappointed, she spent the day alone, only to come home and find a surprise party with all her friends and family waiting.",
            "Alex found an old, crumpled note in the library. It read, \"Meet me under the oak tree at midnight.\" Curiosity got the better of him, and he discovered a buried time capsule filled with letters from the past.",
            "Mia was rushing to leave for work when she realized her house key was missing. She retraced her steps and found it under her cat's food bowl. The little rascal had been playing with it all along.",
        ),
        start=1,
    )
}
def generate_title(story):
    """Generate candidate titles for a short story.

    Runs ``story`` through the module-level ``generation_pipeline`` using beam
    search (``num_beams`` beams, ``num_return_sequences`` results). It returns
    the candidates as numbered lines.

    Args:
        story: Story text from the UI. ``None``, empty, or whitespace-only
            input is rejected without calling the model.

    Returns:
        A newline-separated string such as ``"Title 1: ...\\nTitle 2: ..."``,
        or ``"Please enter a valid story."`` for blank input.
    """
    # Guard against None as well as blank text before calling .strip().
    if not story or not story.strip():
        return "Please enter a valid story."

    generated = generation_pipeline(
        story,
        min_length=4,  # minimum generated length, counted in tokens
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        # The input box is unbounded: clip over-long stories to the
        # tokenizer's maximum length instead of overrunning the model.
        truncation=True,
    )

    lines = []
    for number, candidate in enumerate(generated, start=1):
        title = candidate["generated_text"].strip()
        # Upper-case only the first character. str.capitalize() would also
        # lower-case the rest, mangling names ("Emily") and acronyms.
        title = title[:1].upper() + title[1:]
        lines.append(f"Title {number}: {title}")
    return "\n".join(lines)
def set_predefined_story(story):
    """Return ``story`` unchanged.

    Identity passthrough meant for copying a selected predefined story into
    the story text area. NOTE(review): the visible button wiring uses an
    inline lambda rather than this helper.
    """
    selected_story = story
    return selected_story
# Custom CSS string, passed to the page through gr.Blocks(css=custom_css).
# NOTE(review): "#component-5" and "#component-9" are not ids set anywhere in
# this file; they look like Gradio's automatically assigned, creation-order
# ids. They will silently style different elements if the layout changes.
# TODO: confirm which components they target and switch to elem_id/elem_classes.
# (The text areas already carry elem_classes "story_input" and "output", which
# this stylesheet does not reference.)
# NOTE(review): "#story_button" is an id selector, but every predefined-story
# button is created with elem_id='story_button', so that id is not unique on
# the page.
custom_css = """
body {
    background-color: white !important;
}
#component-5, #component-9, label {
    background-color: #91BAD6;
}
span {
    font-size: 18px !important;
    color: white;
    font-weight: bold;
}
textarea {
    font-size: 16px !important;
    color: black;
}
#generate_button {
    font-size: 18px;
    background-color: #2E5984;
    color: white;
}
#story_button {
    background-color: #91BAD6;
    color: white;
}
#title {
    font-size: 45px;
    text-align: center !important;
}
#description {
    font-size: 20px;
    text-align: center !important;
}
"""
# Define Gradio interface.
# NOTE(review): custom_css targets "#component-5"/"#component-9", which appear
# to be Gradio's auto-assigned, order-based ids. Reordering or inserting
# components below may change which elements that CSS styles.
with gr.Blocks(css = custom_css) as demo:
    # Page header.
    # NOTE(review): the emoji in these two strings look mis-encoded ("π", "β¨").
    # "β¨" is likely a mangled "✨". Restore the intended characters.
    gr.Markdown("<h1 id='title'>π Story Title Generator π</h1>")
    gr.Markdown("<p id='description'>Unleash your creativity!β¨ Enter a short story, and watch as captivating titles are crafted for you. π</p>")
    # Left column: story input and the generate button.
    # Right column: generated titles.
    with gr.Row():
        with gr.Column():
            story_input = gr.TextArea(
                label="Enter a Short Story",
                lines=7,
                placeholder="Write your story here...",
                elem_classes="story_input"
            )
            generate_button = gr.Button("Generate Title", elem_id="generate_button")
        with gr.Column():
            output = gr.Textbox(
                label="Generated Titles",
                placeholder="Your titles will appear here...",
                lines=7,
                elem_classes="output"
            )
    gr.Markdown("<h2>Select a Predefined Story:</h2>")
    # One button per predefined story; a click copies that story into story_input.
    with gr.Row():
        for name, story in predefined_stories.items():
            with gr.Column():
                # `s=story` binds the current story when the lambda is created.
                # A plain closure over `story` would see only the last loop value.
                # NOTE(review): every button here shares elem_id 'story_button',
                # so the HTML id is not unique.
                gr.Button(name, elem_id='story_button').click(
                    fn=lambda s=story: s,
                    inputs=[],
                    outputs=story_input
                )
    # Link the button click event to the generate_title function
    generate_button.click(fn=generate_title, inputs=story_input, outputs=output)
# Launch Gradio app
demo.launch()