Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import random | |
from datasets import load_dataset | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
# Load the training split of the Hindi poetry dataset (used to seed random poems).
dataset = load_dataset("rahul7star/hindi-poetry")["train"]

# Load the fine-tuned GPT-2 model and its tokenizer from the Hub.
model_name = "rahul7star/hindi_poetry_language_model"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 tokenizers may ship without a pad token. Asking such a tokenizer to
# pad raises a ValueError, and passing pad_token_id=None to generate() forces
# it to guess. Reuse EOS as the pad token so both paths get a valid id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 2️⃣ Function to Generate Hindi Poetry | |
def generate_poetry_base(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95): | |
input_ids = tokenizer(prompt, return_tensors="pt").input_ids | |
with torch.no_grad(): | |
output = model.generate( | |
input_ids, | |
max_length=max_length, | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
pad_token_id=tokenizer.pad_token_id | |
) | |
return tokenizer.decode(output[0], skip_special_tokens=True) | |
def generate_poetry(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95): | |
input_ids = tokenizer(prompt, return_tensors="pt").input_ids | |
with torch.no_grad(): | |
output = model.generate( | |
input_ids, | |
max_length=max_length, | |
temperature=temperature, # Increased randomness | |
top_p=top_p, | |
do_sample=True, | |
repetition_penalty=1.5, # Added repetition penalty to prevent duplicates | |
num_beams=5, # Use beam search for higher quality output | |
no_repeat_ngram_size=2, # Prevent repeating the same n-grams | |
early_stopping=True, | |
pad_token_id=tokenizer.pad_token_id | |
) | |
return tokenizer.decode(output[0], skip_special_tokens=True) | |
# Poetry Generation Function with Random Selection from Dataset and Explicit 4-Line Structure | |
def generate_random_poem(prompt, max_length=180, temperature=1.0, top_p=0.9, num_lines=4): | |
# Randomly select a line from the dataset | |
random_line = random.choice(dataset["poem"]) | |
# Prepare the input text with the random line selected, and start with a unique phrase to avoid repetition | |
input_text = f"{random_line} " # Unique start to force variety | |
# Tokenize the input text | |
encoding = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=max_length) | |
input_ids = encoding.input_ids.to(model.device) | |
attention_mask = encoding.attention_mask.to(model.device) | |
# Set pad_token_id to eos_token_id | |
pad_token_id = tokenizer.eos_token_id | |
# Generate the poem using the model with increased randomness | |
output = model.generate( | |
input_ids, | |
attention_mask=attention_mask, | |
max_length=max_length, | |
temperature=temperature, # Increased randomness | |
top_p=top_p, | |
do_sample=True, | |
repetition_penalty=1.5, # Added repetition penalty to prevent duplicates | |
num_beams=5, # Use beam search for higher quality output | |
no_repeat_ngram_size=2, # Prevent repeating the same n-grams | |
early_stopping=True, | |
pad_token_id=pad_token_id | |
) | |
# Decode the output and split into lines | |
generated_poem = tokenizer.decode(output[0], skip_special_tokens=True) | |
generated_poem = generated_poem.strip() | |
# Split the generated text into separate lines based on full stops (Hindi poems often end with "।") | |
poem_lines = generated_poem.split("।") | |
final_poem = "\n".join(poem_lines) | |
return final_poem | |
# 3️⃣ Gradio Interface
# Gradio passes the input components to ``fn`` positionally. The order below
# must therefore match generate_random_poem(prompt, max_length, temperature,
# top_p). That function has no top-k parameter, so no Top-k slider is offered.
# Adding one here would shift its value into ``top_p``, where values above 1
# are invalid.
interface = gr.Interface(
    fn=generate_random_poem,
    inputs=[
        gr.Textbox(label="Enter Prompt", placeholder="Start your Hindi poem..."),
        gr.Slider(50, 500, step=10, value=180, label="Max Length"),
        gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, step=0.05, value=0.95, label="Top-p Sampling"),
    ],
    outputs=gr.Textbox(label="Generated Hindi Poem"),
    title="Hindi Poetry Generator ✨",
    description="Generate beautiful Hindi poetry. Just enter a prompt and adjust parameters. Example: 'मैया मोरी'",
)

# 4️⃣ Run the Gradio App (only when executed as a script, not on import)
if __name__ == "__main__":
    interface.launch(share=True)