|
import gradio as gr |
|
import os |
|
import requests |
|
from zipfile import ZipFile |
|
from tqdm import tqdm |
|
import numpy as np |
|
from PIL import Image, ImageOps |
|
import random |
|
|
|
|
|
def download_nist_sd19(url, dest_folder):
    """Download the zip archive at *url* into *dest_folder* and extract it there.

    The archive is stored under the last path segment of *url*. When a file
    with that name already exists in *dest_folder* the download is skipped.
    The archive is extracted into *dest_folder* on every call either way.

    Args:
        url: Address of the zip archive.
        dest_folder: Directory that receives the archive and its extracted
            contents; created if it does not exist.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        zipfile.BadZipFile: If the stored archive is not a valid zip file.
    """
    os.makedirs(dest_folder, exist_ok=True)
    filename = os.path.join(dest_folder, url.split('/')[-1])

    if not os.path.exists(filename):
        # Stream into a side file first: an interrupted or failed download
        # must never leave a truncated archive under the final name, because
        # the existence check above would treat it as complete next time.
        partial = filename + '.part'
        try:
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                # Fail loudly instead of saving an error page as the archive.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(partial, 'wb') as file, tqdm(
                    desc=filename,
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=64 * 1024):
                        bar.update(file.write(data))
            # Only a fully written file is promoted to the final name.
            os.replace(partial, filename)
        finally:
            # After a successful replace the side file is gone; anything
            # still present here is the leftover of a failed attempt.
            if os.path.exists(partial):
                os.remove(partial)

    with ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(dest_folder)
|
|
|
|
|
# Source archive and the local folder that receives both the downloaded zip
# and its extracted contents.
nist_sd19_url = "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip"
download_folder = "nist_sd19"

# Runs at import time: fetches the archive unless it is already on disk and
# unpacks it into download_folder.
download_nist_sd19(nist_sd19_url, download_folder)

# Root directory that is scanned for glyph images. The original location
# <download_folder>/hsf_0 is kept whenever it exists.
# NOTE(review): by_class.zip appears to unpack as by_class/<hex>/hsf_N/...,
# in which case hsf_0 is never directly under download_folder and the scan
# would find nothing - confirm against the extracted tree. Falling back to
# the extraction folder itself covers whatever layout the archive has.
nist_dataset_path = os.path.join(download_folder, "hsf_0")
if not os.path.isdir(nist_dataset_path):
    nist_dataset_path = download_folder
|
|
|
|
|
def _class_label_from_dir(dirname):
    """Return the ASCII letter/digit encoded by a two-hex-digit name, else None."""
    if len(dirname) != 2:
        return None
    try:
        decoded = chr(int(dirname, 16))
    except ValueError:
        return None
    if decoded.isascii() and decoded.isalnum():
        return decoded
    return None


def load_nist_dataset(path):
    """Collect PNG file paths under *path* together with a label for each.

    The label is determined as follows:

    1. If the directory holding the file, or that directory's parent, is
       named with two hex digits that decode to an ASCII letter or digit
       (e.g. ``4a`` -> ``'J'``), that character is the label.
       NOTE(review): this matches how the SD19 ``by_class`` archive is
       understood to be organised (``<hex>/hsf_N/hsf_N_00000.png``), where
       the filename token is a partition number rather than the character -
       confirm against the dataset's user guide.
    2. Otherwise the second ``_``-separated token of the file name is used.
       Files without such a token carry no label and are skipped.

    Args:
        path: Directory to walk recursively. A missing directory yields two
            empty lists.

    Returns:
        A tuple ``(images, labels)`` of two equally long lists: image file
        paths and the label of each.
    """
    images = []
    labels = []
    for root, _dirs, files in os.walk(path):
        current = os.path.normpath(root)
        dir_label = (
            _class_label_from_dir(os.path.basename(current))
            or _class_label_from_dir(os.path.basename(os.path.dirname(current)))
        )
        for file in files:
            if not file.lower().endswith(".png"):
                continue
            if dir_label is not None:
                label = dir_label
            else:
                tokens = file.split('_')
                if len(tokens) < 2:
                    # No label can be derived; indexing tokens[1] would raise.
                    continue
                label = tokens[1]
            images.append(os.path.join(root, file))
            labels.append(label)
    return images, labels
|
|
|
|
|
# Scan the dataset folder once at import time. The two lists are parallel:
# images[i] is the file path whose label is labels[i].
images, labels = load_nist_dataset(path=nist_dataset_path)
|
|
|
|
|
def generate_handwritten_text(input_text):
    """Render *input_text* as a single row of 28x28 grayscale glyph cells.

    For every character, a random image whose label equals the upper-cased
    character is picked from the module-level ``images`` / ``labels`` lists,
    converted to grayscale and padded onto a white 28x28 cell. Characters
    without any matching image are left as a blank white cell.

    Args:
        input_text: Text to render. ``None`` is treated like an empty string.

    Returns:
        An ``L``-mode PIL image, 28 pixels high and 28 pixels wide per
        character. Empty input yields one blank 28x28 cell rather than a
        zero-width image.
    """
    cell = 28
    text = input_text or ""
    if not text:
        return Image.new('L', (cell, cell), color=255)

    # Index the dataset once per call for just the labels this text needs,
    # instead of rescanning every image for every character.
    wanted = {char.upper() for char in text}
    candidates = {}
    for img, label in zip(images, labels):
        if label in wanted:
            candidates.setdefault(label, []).append(img)

    # The canvas starts out white, so cells with no glyph need no paste.
    output_image = Image.new('L', (cell * len(text), cell), color=255)

    for idx, char in enumerate(text):
        matching_images = candidates.get(char.upper())
        if not matching_images:
            continue
        # Close the file handle as soon as the pixel data has been copied.
        with Image.open(random.choice(matching_images)) as source:
            char_image = source.convert('L')
        char_image = ImageOps.pad(char_image, (cell, cell), color='white')
        output_image.paste(char_image, (idx * cell, 0))

    return output_image
|
|
|
|
|
# Web UI: one text box in, one image out, backed by generate_handwritten_text.
interface = gr.Interface(
    fn=generate_handwritten_text,
    inputs="text",
    outputs="image",
    title="NIST Handwritten Text Generator",
    description="Enter text to generate a handwritten text image using the NIST SD19 dataset.",
)

# Start serving the interface as soon as the module is executed.
interface.launch()
|
|