Welcome to Khmer Text Image Generation!
This model is based on the UNet2DConditionModel architecture from diffusers and is designed to generate images of Khmer text.
Model Overview
This is a conditional text-to-image generation model: it requires text input encoded with the channudam/roberta-khm-35 tokenizer and text encoder, both available in this collection. The model was trained from scratch, without any pre-trained initialization, so everything it knows about rendering Khmer text comes from its training data.
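To make "conditional" concrete: in a UNet2DConditionModel, the text embeddings enter through cross-attention layers. Image features act as queries, and the text encoder's per-token hidden states act as keys and values. The NumPy sketch below shows one simplified attention head with random weights. It illustrates the mechanism only and does not reproduce this model's weights or implementation. The sizes (`text_dim=768`, an 8 x 16 feature map, `feature_dim=64`, projection size 32) are hypothetical; only the 35-token length comes from the example usage.

```python
import numpy as np

def cross_attention(image_features, text_states, w_q, w_k, w_v):
    """Single-head cross-attention: every image position attends over the text tokens.

    image_features: (num_positions, feature_dim) flattened UNet feature map
    text_states:    (num_tokens, text_dim) text-encoder hidden states
    """
    q = image_features @ w_q                       # (num_positions, d)
    k = text_states @ w_k                          # (num_tokens, d)
    v = text_states @ w_v                          # (num_tokens, d)
    scores = q @ k.T / np.sqrt(q.shape[-1])        # (num_positions, num_tokens)
    scores -= scores.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True) # softmax over text tokens
    return weights @ v, weights

rng = np.random.default_rng(0)
num_tokens, text_dim = 35, 768                     # 35 = max_length in the example; 768 is hypothetical
height, width, feature_dim, d = 8, 16, 64, 32      # hypothetical feature-map and projection sizes

image_features = rng.standard_normal((height * width, feature_dim))
text_states = rng.standard_normal((num_tokens, text_dim))
w_q = rng.standard_normal((feature_dim, d)) / np.sqrt(feature_dim)
w_k = rng.standard_normal((text_dim, d)) / np.sqrt(text_dim)
w_v = rng.standard_normal((text_dim, d)) / np.sqrt(text_dim)

out, weights = cross_attention(image_features, text_states, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (128, 32) (128, 35)
```

With the real checkpoints, the practical requirement is that the UNet's `cross_attention_dim` matches the text encoder's hidden size, which you can compare via `model.config.cross_attention_dim` and `text_encoder.config.hidden_size`.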
Usage & Fine-Tuning
For best results, we recommend fine-tuning on your own dataset. The model is intended as a foundation that can be further adapted to specific downstream tasks.
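Fine-tuning a DDPM-style model like this one typically reuses the standard noise-prediction objective: pick a random timestep, add the matching amount of Gaussian noise to a clean image, and train the UNet to predict that noise. The NumPy sketch below builds one such training pair with the same `scaled_linear` schedule as the scheduler in the example usage. It assumes the model was trained with DDPMScheduler's default epsilon (noise) prediction, which is what the inference loop expects. The random `predicted_noise` is a placeholder for the UNet output, and gradients and optimizer steps are left to your training framework.

```python
import numpy as np

def scaled_linear_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    # "scaled_linear": interpolate sqrt(beta) linearly, then square it
    betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
    return np.cumprod(1.0 - betas)  # alpha-bar_t for t = 0 .. T-1

def make_training_pair(x0, alphas_cumprod, rng):
    """Noise a clean image x0 (values in [-1, 1]) to a uniformly sampled timestep t."""
    t = int(rng.integers(0, len(alphas_cumprod)))
    noise = rng.standard_normal(x0.shape)
    a_bar = alphas_cumprod[t]
    x_t = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise  # forward (noising) process
    return x_t, t, noise

def noise_prediction_loss(predicted_noise, noise):
    # Mean squared error between the predicted and the true noise
    return float(np.mean((predicted_noise - noise) ** 2))

rng = np.random.default_rng(0)
alphas_cumprod = scaled_linear_alphas_cumprod()
x0 = rng.uniform(-1.0, 1.0, size=(1, 32, 64))      # stand-in for one preprocessed image
x_t, t, noise = make_training_pair(x0, alphas_cumprod, rng)

predicted_noise = rng.standard_normal(x0.shape)    # placeholder for the UNet's noise prediction
print(t, x_t.shape, noise_prediction_loss(predicted_noise, noise))
```

In a PyTorch training loop, the same steps map onto `scheduler.add_noise(x0, noise, timesteps)` and `torch.nn.functional.mse_loss(model(x_t, timesteps, encoder_hidden_states).sample, noise)`, followed by a normal backward pass and optimizer step.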
Dataset
The model was trained on a dataset that is publicly available on Kaggle:
Khmer Text Recognition Dataset: https://www.kaggle.com/datasets/emhengly/khmer-text-recognition-dataset/data
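Before training or fine-tuning on this dataset (or your own), images need to match what the example usage works with: single-channel arrays of height 32 and width 64. The example maps model output back to pixels with `(x + 1) * 127.5`, which suggests training images were scaled to [-1, 1]. The sketch below assumes exactly that preprocessing (grayscale, bilinear resize, [-1, 1] scaling); the card does not confirm it. It uses Pillow and NumPy and makes no assumptions about the Kaggle dataset's folder layout.

```python
import numpy as np
from PIL import Image

IMAGE_WIDTH, IMAGE_HEIGHT = 64, 32  # size used in the example usage

def preprocess(image: Image.Image) -> np.ndarray:
    """Convert a PIL image to a (1, 32, 64) float32 array in [-1, 1] (assumed training format)."""
    gray = image.convert("L").resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.BILINEAR)
    pixels = np.asarray(gray, dtype=np.float32)    # (32, 64), values 0..255
    scaled = pixels / 127.5 - 1.0                  # inverse of (x + 1) * 127.5
    return scaled[np.newaxis, :, :]                # add channel axis -> (1, 32, 64)

def postprocess(sample: np.ndarray) -> Image.Image:
    """Map a (1, 32, 64) sample in [-1, 1] back to an 8-bit grayscale image."""
    pixels = np.round((np.clip(sample[0], -1.0, 1.0) + 1.0) * 127.5).astype(np.uint8)
    return Image.fromarray(pixels)                 # 2-D uint8 array -> mode "L"

sample = preprocess(Image.new("RGB", (200, 50), "white"))
print(sample.shape, sample.min(), sample.max())    # (1, 32, 64) 1.0 1.0
```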
Example Usage
To generate Khmer text images with the UNet2DConditionModel, run the following example (it assumes a CUDA-capable GPU):
import torch
import matplotlib.pyplot as plt
from diffusers import UNet2DConditionModel, DDPMScheduler
from transformers import RobertaTokenizerFast, RobertaModel
# Load the UNet model, tokenizer, and text encoder
model = UNet2DConditionModel.from_pretrained("channudam/unet2dcon-khm-35").to("cuda")
tokenizer = RobertaTokenizerFast.from_pretrained("channudam/roberta-khm-35")
text_encoder = RobertaModel.from_pretrained("channudam/roberta-khm-35").to("cuda")
# Load the DDPM scheduler
scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)
# Generate random noise for image generation
batch_size = 1
image_width, image_height, channels = 64, 32, 1
# Set manual seed for reproducibility
generator = torch.Generator(device="cuda").manual_seed(42)
latents = torch.randn((batch_size, channels, image_height, image_width), device="cuda", generator=generator)
# Encode input text
text = "សួស្តី"  # Example Khmer text ("hello"); replace with your own input
input_ids = tokenizer(
    text, max_length=35, padding="max_length", truncation=True, return_tensors="pt"
)["input_ids"].to("cuda")
with torch.no_grad():
    encoder_hidden_states = text_encoder(input_ids)[0]  # (batch, 35, hidden_size)
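# Denoising loop: predict the noise at each timestep and let the scheduler remove it
scheduler.set_timesteps(50)
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(latents, t, encoder_hidden_states).sample
    # Pass the seeded generator so the noise added at each step is reproducible too
    latents = scheduler.step(noise_pred, t, latents, generator=generator).prev_sample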
# Display results
print("Encoded Text: ", input_ids)
print("Decoded Text: ", tokenizer.batch_decode(input_ids))
print("Text Embedding Shape: ", encoder_hidden_states.shape)
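# Convert the denoised sample to an 8-bit grayscale image: map [-1, 1] to [0, 255],
# clamping first so out-of-range values don't overflow when cast to uint8
image = ((latents[0, 0].clamp(-1.0, 1.0) + 1.0) * 127.5).round().to(torch.uint8).cpu().numpy()
plt.imshow(image, cmap="gray")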
plt.axis("off")
plt.show()
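Each iteration of the denoising loop above calls `scheduler.step`, which turns the UNet's noise prediction into a slightly less noisy image. The NumPy sketch below re-implements that update for the configuration used here, based on our reading of diffusers' DDPMScheduler defaults (epsilon prediction, `timestep_spacing="leading"`, `clip_sample=True`, "fixed_small" variance). It is meant for understanding the loop, not as a replacement for the library, and should be checked against your installed diffusers version.

```python
import numpy as np

NUM_TRAIN_TIMESTEPS, NUM_INFERENCE_STEPS = 1000, 50

# Same "scaled_linear" schedule as the DDPMScheduler in the example
betas = np.linspace(0.00085**0.5, 0.012**0.5, NUM_TRAIN_TIMESTEPS) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

# "leading" spacing: 50 evenly spaced timesteps 980, 960, ..., 20, 0
step_ratio = NUM_TRAIN_TIMESTEPS // NUM_INFERENCE_STEPS
timesteps = (np.arange(NUM_INFERENCE_STEPS) * step_ratio)[::-1]

def ddpm_step(noise_pred, t, sample, rng):
    """One reverse step x_t -> x_(t - step_ratio) for a noise-predicting model."""
    prev_t = t - step_ratio
    a_bar_t = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[prev_t] if prev_t >= 0 else 1.0
    alpha_t = a_bar_t / a_bar_prev          # effective alpha across the skipped steps
    beta_t = 1.0 - alpha_t

    # Estimate the clean image from the predicted noise, then clip it to [-1, 1]
    x0_pred = (sample - np.sqrt(1.0 - a_bar_t) * noise_pred) / np.sqrt(a_bar_t)
    x0_pred = np.clip(x0_pred, -1.0, 1.0)

    # Mean of the posterior q(x_prev | x_t, x0_pred)
    mean = (
        (np.sqrt(a_bar_prev) * beta_t / (1.0 - a_bar_t)) * x0_pred
        + (np.sqrt(alpha_t) * (1.0 - a_bar_prev) / (1.0 - a_bar_t)) * sample
    )
    if t == 0:
        return mean  # no noise is added at the final step
    variance = max((1.0 - a_bar_prev) / (1.0 - a_bar_t) * beta_t, 1e-20)
    return mean + np.sqrt(variance) * rng.standard_normal(sample.shape)

# Sanity check: at t = 0, a perfect noise prediction recovers the clean image
rng = np.random.default_rng(0)
x0 = np.full((1, 32, 64), 0.5)              # a hypothetical "clean" image
noise = rng.standard_normal(x0.shape)
x_t = np.sqrt(alphas_cumprod[0]) * x0 + np.sqrt(1.0 - alphas_cumprod[0]) * noise
print(timesteps[:3], np.allclose(ddpm_step(noise, 0, x_t, rng), x0))  # [980 960 940] True
```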