# pragmatic / app.py
# Last commit by NihalGazi: "Update app.py" (da48f05, verified)
import gradio as gr
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageOps
import random
# Function to download and extract the NIST SD19 dataset
def download_nist_sd19(url, dest_folder):
    """Download the zip archive at *url* into *dest_folder* and extract it there.

    The archive is stored as ``dest_folder/<last URL path segment>``. The
    download is skipped when that file already exists, and extraction is
    skipped when a previous run already finished extracting it (tracked by a
    ``<archive>.extracted`` marker file; delete it to force re-extraction).

    Args:
        url: HTTP(S) URL of a zip archive.
        dest_folder: Directory that receives the archive and its extracted
            contents; created if missing.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection failures or timeouts.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    os.makedirs(dest_folder, exist_ok=True)
    filename = os.path.join(dest_folder, url.split('/')[-1])

    if not os.path.exists(filename):
        # Stream into a temporary name and rename only after the transfer
        # completes, so an interrupted download never leaves a truncated
        # archive that later runs would mistake for a finished one.
        partial = filename + ".part"
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this, an error page would be written out as the "zip".
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(partial, 'wb') as file, tqdm(
                desc=filename,
                total=total_size or None,  # unknown length -> open-ended bar
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                # 1 MiB chunks: 1 KiB chunks make a large download CPU-bound.
                for data in response.iter_content(chunk_size=1 << 20):
                    bar.update(file.write(data))
        os.replace(partial, filename)

    # Extracting a large archive is slow; do it once and remember that it
    # completed, rather than re-extracting on every application start.
    marker = filename + ".extracted"
    if not os.path.exists(marker):
        with ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall(dest_folder)
        with open(marker, 'w'):
            pass
# Remote archive for NIST Special Database 19 (the "by_class" grouping) and
# the local directory it is cached in and unpacked under.
nist_sd19_url = "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip"
download_folder = "nist_sd19"

# Fetch and unpack the archive as soon as the module is loaded.
download_nist_sd19(url=nist_sd19_url, dest_folder=download_folder)
# Path to the NIST SD19 dataset.
# NOTE(review): per the SD19 documentation, by_class.zip unpacks to
# by_class/<hex ASCII code>/<partition>/*.png (e.g. by_class/4a/hsf_0/...);
# there is no top-level "hsf_0" directory in that layout -- confirm against
# the extracted tree.
nist_dataset_path = os.path.join(download_folder, "by_class")


def _label_from_dirs(rel_root):
    """Return the character named by the deepest hex-code directory in *rel_root*, or None.

    A directory qualifies when its name is exactly two hexadecimal digits
    encoding a printable, non-space ASCII character (``"4a"`` -> ``"J"``).
    """
    for part in reversed(rel_root.split(os.sep)):
        if len(part) != 2:
            continue
        try:
            code = int(part, 16)
        except ValueError:
            continue
        if 0x21 <= code <= 0x7E:
            return chr(code)
    return None


# Function to load the dataset
def load_nist_dataset(path):
    """Collect every ``.png`` under *path* together with its character label.

    Labels are resolved in this order:

    1. The nearest enclosing directory below *path* whose name is a two-digit
       hexadecimal ASCII code, decoded to that character, e.g.
       ``4a/hsf_0/hsf_0_00000.png`` -> ``"J"``.
    2. Otherwise the second ``_``-separated field of the file name
       (``img_A_1.png`` -> ``"A"``).

    Files matching neither rule are skipped.

    Args:
        path: Root directory to scan. A missing directory yields no files.

    Returns:
        ``(images, labels)``: parallel lists of file paths and label strings,
        in a deterministic (sorted) traversal order.
    """
    images = []
    labels = []
    for root, dirs, files in os.walk(path):
        dirs.sort()  # os.walk order is filesystem-dependent; make it stable
        dir_label = _label_from_dirs(os.path.relpath(root, path))
        for file in sorted(files):
            if not file.endswith(".png"):
                continue
            label = dir_label
            if label is None:
                parts = file.split('_')
                if len(parts) < 2:
                    continue  # no recognizable label in the name
                label = parts[1]
            images.append(os.path.join(root, file))
            labels.append(label)
    return images, labels
# Index the glyph images once at startup; generate_handwritten_text reads
# these two module-level lists.
images, labels = load_nist_dataset(path=nist_dataset_path)
# Function to generate handwritten text image
def generate_handwritten_text(input_text, cell_size=28):
    """Render *input_text* as a strip of randomly chosen handwritten glyphs.

    Each character occupies one ``cell_size`` x ``cell_size`` grayscale cell,
    left to right. For each character a random image is drawn from the
    module-level ``images``/``labels`` dataset: an exact-label match is
    preferred, then the uppercase form of the character. Characters with no
    match (e.g. spaces) are left as a blank white cell.

    Args:
        input_text: Text to render; ``None`` is treated as empty.
        cell_size: Side length in pixels of each character cell.

    Returns:
        A PIL ``'L'`` image of size ``(cell_size * max(len(text), 1), cell_size)``;
        empty input yields a single blank cell rather than a zero-width image.
    """
    text = input_text or ""

    # One pass over the whole dataset per call (instead of one per character),
    # keeping only the labels this text could use.
    wanted = set(text) | {ch.upper() for ch in text}
    candidates = {}
    for img_path, label in zip(images, labels):
        if label in wanted:
            candidates.setdefault(label, []).append(img_path)

    # The canvas starts white, so unmatched characters need no extra paste.
    output_image = Image.new('L', (cell_size * max(len(text), 1), cell_size), color=255)
    for idx, char in enumerate(text):
        matching = candidates.get(char) or candidates.get(char.upper())
        if not matching:
            continue
        with Image.open(random.choice(matching)) as src:
            glyph = src.convert('L')
        # Scale to fit the cell while keeping aspect ratio, padding with white.
        glyph = ImageOps.pad(glyph, (cell_size, cell_size), color='white')
        output_image.paste(glyph, (idx * cell_size, 0))
    return output_image
# Expose the generator as a text-in / image-out Gradio app and start serving it.
interface = gr.Interface(
    fn=generate_handwritten_text,
    inputs="text",
    outputs="image",
    title="NIST Handwritten Text Generator",
    description="Enter text to generate a handwritten text image using the NIST SD19 dataset.",
)
interface.launch()