File size: 6,254 Bytes
6ffd722 917022f 6ffd722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi
from torch.optim import AdamW
from tqdm import tqdm
import gc
from torch.cuda.amp import autocast
# Setare configurare CUDA pentru a reduce fragmentarea memoriei
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Verifică dacă GPU-ul este detectat
print(torch.cuda.is_available())
img_dir = '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src/experiments/full/primulTest/SD21data'
# Definirea dataset-ului
class ManualCaptionDataset(Dataset):
def __init__(self, img_dir, transform=None):
self.img_dir = img_dir
self.img_names = os.listdir(img_dir)
self.transform = transform
self.captions = []
for img_name in self.img_names:
caption = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market'
self.captions.append(caption)
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_name = os.path.join(self.img_dir, self.img_names[idx])
image = Image.open(img_name).convert("RGB")
caption = self.captions[idx]
if self.transform:
image = self.transform(image)
return image, caption
# Configurare transformări
transform = transforms.Compose([
transforms.Resize((256, 256)), # Dimensiune imagine redusă
transforms.ToTensor(),
])
# Crearea dataset-ului
dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True) # Dimensiune batch redusă
# Încărcare model UNet
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet", torch_dtype=torch.float16)
unet.to("cuda")
# Încărcare model pentru autoencoder
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", torch_dtype=torch.float16)
vae.to("cuda")
# Încărcare tokenizer și text model pentru CLIP
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_model.to("cuda")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Scheduler
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
# Pregătire optimizer
optimizer = AdamW(unet.parameters(), lr=5e-6)
# Setare model în modul de antrenament
unet.train()
text_model.train()
# Definire număr de epoci
num_epochs = 5
# Training loop
for epoch in range(num_epochs):
for images, captions in tqdm(dataloader):
images = images.to("cuda", dtype=torch.float16)
# Curăță memoria GPU înainte de fiecare iterare
gc.collect()
torch.cuda.empty_cache()
# Tokenizare captions
inputs = tokenizer(captions, padding="max_length", max_length=77, return_tensors="pt").to("cuda")
# Generare zgomot aleatoriu
noise = torch.randn_like(images).to("cuda", dtype=torch.float16)
# Codificare imagini în latențe
latents = vae.encode(images).latent_dist.sample()
latents = latents * 0.18215
# Generare timesteps
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (images.shape[0],), device="cuda").long()
# Forward pass prin UNet
encoder_hidden_states = text_model(inputs.input_ids)[0]
# Convertim encoder_hidden_states la float16
encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float16)
# Proiectăm dimensiunile `encoder_hidden_states` pentru a se potrivi cu cele așteptate de UNet
expected_dim = unet.config.cross_attention_dim
if encoder_hidden_states.shape[-1] != expected_dim:
projection_layer = torch.nn.Linear(encoder_hidden_states.shape[-1], expected_dim).to("cuda", dtype=torch.float16)
encoder_hidden_states = projection_layer(encoder_hidden_states)
# Generare predicție de zgomot
with autocast():
noise_pred = unet(latents, timesteps, encoder_hidden_states).sample
# Verifică dimensiunile tensorilor
print(f"noise_pred shape: {noise_pred.shape}")
print(f"noise shape: {noise.shape}")
# Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise
if noise_pred.shape[1] != noise.shape[1]:
# Ajustează numărul de canale pentru noise_pred
conv_layer = torch.nn.Conv2d(
in_channels=noise_pred.shape[1],
out_channels=noise.shape[1],
kernel_size=1
).to("cuda", dtype=torch.float16)
noise_pred = conv_layer(noise_pred)
# Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise
if noise_pred.shape[2:] != noise.shape[2:]:
noise_pred = torch.nn.functional.interpolate(noise_pred, size=images.shape[2:], mode='bilinear', align_corners=False)
# Calcul pierdere (loss) comparând ieșirea modelului cu zgomotul original
loss = torch.nn.functional.mse_loss(noise_pred, noise)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Curăță memoria GPU după fiecare iterare
gc.collect()
torch.cuda.empty_cache()
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
# Salvarea modelului antrenat
unet.save_pretrained("./finetuned-unet")
text_model.save_pretrained("./finetuned-text-model")
api = HfApi()
#api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-unet", repo_type="model")
#api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-text-model", repo_type="model")
# Încărcarea pe Hugging Face
api.upload_folder(
folder_path="./finetuned-unet",
path_in_repo=".",
repo_id="AndreiUrsu/finetuned-stable-diffusion-unet",
repo_type="model"
)
# Curăță memoria GPU la final
gc.collect()
torch.cuda.empty_cache()
|