# hindi/app.py — Hindi poetry generator (Gradio app on Hugging Face Spaces)
# Author: rahul7star
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import random
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the Hindi poetry dataset once at import time; generate_random_poem
# below samples seed lines from its "poem" column on every request.
dataset = load_dataset("rahul7star/hindi-poetry")["train"]
# Load the fine-tuned GPT-2 Hindi poetry model and its tokenizer from the Hub.
# NOTE(review): both downloads happen at import time, so the app blocks on
# startup until they finish.
model_name = "rahul7star/hindi_poetry_language_model"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# 2️⃣ Function to Generate Hindi Poetry
def generate_poetry_base(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95):
    """Generate Hindi poetry continuing *prompt* with plain sampling.

    Args:
        prompt: Seed text the poem should continue from.
        max_length: Total token budget (prompt tokens + generated tokens).
        temperature: Softmax temperature; higher means more random output.
        top_k: Keep only the k most likely tokens at each step.
        top_p: Nucleus-sampling probability-mass cutoff.

    Returns:
        The decoded generated text with special tokens stripped.
    """
    encoding = tokenizer(prompt, return_tensors="pt")
    # GPT-2 tokenizers define no pad token by default; fall back to EOS so
    # generate() is not handed pad_token_id=None.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.no_grad():  # inference only; no gradients needed
        output = model.generate(
            encoding.input_ids,
            attention_mask=encoding.attention_mask,  # explicit mask avoids HF warning
            max_length=max_length,
            # Without do_sample=True, transformers uses greedy decoding and
            # silently ignores temperature/top_k/top_p — they were dead here.
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            pad_token_id=pad_id,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def generate_poetry(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95):
    """Generate Hindi poetry from *prompt* with beam-sampled decoding.

    Args:
        prompt: Seed text the poem should continue from.
        max_length: Total token budget (prompt tokens + generated tokens).
        temperature: Sampling temperature (higher = more random).
        top_k: Keep only the k most likely tokens at each step.
        top_p: Nucleus-sampling probability-mass cutoff.

    Returns:
        The decoded generated text with special tokens stripped.
    """
    encoding = tokenizer(prompt, return_tensors="pt")
    # GPT-2 tokenizers define no pad token by default; fall back to EOS so
    # generate() is not handed pad_token_id=None.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.no_grad():  # inference only; no gradients needed
        output = model.generate(
            encoding.input_ids,
            attention_mask=encoding.attention_mask,  # explicit mask avoids HF warning
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,              # was accepted but never forwarded to generate() — fixed
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.5,   # discourage verbatim repeats
            num_beams=5,              # beam-sample for higher-quality output
            no_repeat_ngram_size=2,   # forbid repeated bigrams
            early_stopping=True,
            pad_token_id=pad_id,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Poetry Generation Function with Random Selection from Dataset and Explicit 4-Line Structure
def generate_random_poem(prompt, max_length=180, temperature=1.0, top_p=0.9, num_lines=4):
    """Generate a Hindi poem seeded by a random dataset line plus the user prompt.

    Args:
        prompt: Optional user seed text, prepended to the random dataset line.
            (Previously this argument was silently ignored.)
        max_length: Total token budget for prompt + generation.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff; values outside (0, 1] are reset to the
            0.9 default — the UI caller can misalign sliders and pass e.g. 50 here.
        num_lines: Keep at most this many lines of the result; non-integer or
            < 1 values keep everything. (Previously unused.)

    Returns:
        The generated poem, one verse per line (split on the Hindi danda "।").
    """
    # Defensive clamp: a misaligned caller put the Top-k slider value here.
    if not 0 < top_p <= 1:
        top_p = 0.9

    # Seed with a random line from the dataset to force variety; prepend the
    # user's prompt when one was supplied.
    random_line = random.choice(dataset["poem"])
    prompt = (prompt or "").strip()
    input_text = f"{prompt} {random_line} " if prompt else f"{random_line} "

    # NOTE(review): padding=True assumes the saved tokenizer defines a pad
    # token; stock GPT-2 tokenizers do not — confirm against the checkpoint.
    encoding = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    input_ids = encoding.input_ids.to(model.device)
    attention_mask = encoding.attention_mask.to(model.device)

    with torch.no_grad():  # inference only; no gradients needed
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.5,   # discourage verbatim repeats
            num_beams=5,              # beam-sample for higher-quality output
            no_repeat_ngram_size=2,   # forbid repeated bigrams
            early_stopping=True,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
        )

    generated_poem = tokenizer.decode(output[0], skip_special_tokens=True).strip()
    # Hindi verses conventionally end with the danda "।"; split on it.
    poem_lines = generated_poem.split("।")

    # Apply the advertised line limit only when num_lines is a usable integer
    # (the misaligned UI caller passes 0.95 here, which keeps everything —
    # preserving the previously-deployed behavior).
    try:
        limit = int(num_lines)
    except (TypeError, ValueError):
        limit = 0
    if limit >= 1:
        poem_lines = poem_lines[:limit]

    return "\n".join(poem_lines)
# 3️⃣ Gradio Interface
def _on_generate(prompt, max_length, temperature, top_k, top_p):
    """Adapter between the five UI controls and generate_random_poem.

    generate_random_poem's positional signature is
    (prompt, max_length, temperature, top_p, num_lines): feeding the five
    controls in positionally put the Top-k slider value (e.g. 50) into top_p
    and the Top-p value into num_lines. Map by keyword instead; top_k is
    accepted from the UI but deliberately dropped because the generator does
    not take it.
    """
    return generate_random_poem(
        prompt,
        max_length=int(max_length),  # sliders deliver floats; max_length must be an int
        temperature=temperature,
        top_p=top_p,
    )

interface = gr.Interface(
    fn=_on_generate,
    inputs=[
        gr.Textbox(label="Enter Prompt", placeholder="Start your Hindi poem..."),
        gr.Slider(50, 500, step=10, value=180, label="Max Length"),
        gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(1, 100, step=1, value=50, label="Top-k Sampling"),
        gr.Slider(0.1, 1.0, step=0.05, value=0.95, label="Top-p Sampling"),
    ],
    outputs=gr.Textbox(label="Generated Hindi Poem"),
    title="Hindi Poetry Generator ✨",
    description="Generate beautiful Hindi poetry. Just enter a prompt and adjust parameters. Example: 'मैया मोरी'",
)
# 4️⃣ Run the Gradio App
interface.launch(share=True)