File size: 6,254 Bytes
6ffd722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
917022f
6ffd722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi
from torch.optim import AdamW
from tqdm import tqdm
import gc
from torch.cuda.amp import autocast

# Setare configurare CUDA pentru a reduce fragmentarea memoriei
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Verifică dacă GPU-ul este detectat
print(torch.cuda.is_available())  

img_dir = '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src/experiments/full/primulTest/SD21data'

# Definirea dataset-ului
class ManualCaptionDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform = transform
        self.captions = []
        
        for img_name in self.img_names:
            caption = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market'
            self.captions.append(caption)
        
    def __len__(self):
        return len(self.img_names)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_name).convert("RGB")
        caption = self.captions[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, caption

# Configurare transformări
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Dimensiune imagine redusă
    transforms.ToTensor(),
])

# Crearea dataset-ului
dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  # Dimensiune batch redusă

# Încărcare model UNet
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet", torch_dtype=torch.float16)
unet.to("cuda")

# Încărcare model pentru autoencoder
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", torch_dtype=torch.float16)
vae.to("cuda")

# Încărcare tokenizer și text model pentru CLIP
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_model.to("cuda")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Scheduler
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")

# Pregătire optimizer
optimizer = AdamW(unet.parameters(), lr=5e-6)

# Setare model în modul de antrenament
unet.train()
text_model.train()

# Definire număr de epoci
num_epochs = 5  

# Training loop
for epoch in range(num_epochs):
    for images, captions in tqdm(dataloader):
        images = images.to("cuda", dtype=torch.float16)
        
        # Curăță memoria GPU înainte de fiecare iterare
        gc.collect()
        torch.cuda.empty_cache()
        
        # Tokenizare captions
        inputs = tokenizer(captions, padding="max_length", max_length=77, return_tensors="pt").to("cuda")
        
        # Generare zgomot aleatoriu
        noise = torch.randn_like(images).to("cuda", dtype=torch.float16)
        
        # Codificare imagini în latențe
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215

        # Generare timesteps
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (images.shape[0],), device="cuda").long()

        # Forward pass prin UNet
        encoder_hidden_states = text_model(inputs.input_ids)[0]
        
        # Convertim encoder_hidden_states la float16
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float16)

        # Proiectăm dimensiunile `encoder_hidden_states` pentru a se potrivi cu cele așteptate de UNet
        expected_dim = unet.config.cross_attention_dim
        if encoder_hidden_states.shape[-1] != expected_dim:
            projection_layer = torch.nn.Linear(encoder_hidden_states.shape[-1], expected_dim).to("cuda", dtype=torch.float16)
            encoder_hidden_states = projection_layer(encoder_hidden_states)

        # Generare predicție de zgomot
        with autocast():
            noise_pred = unet(latents, timesteps, encoder_hidden_states).sample

        # Verifică dimensiunile tensorilor
        print(f"noise_pred shape: {noise_pred.shape}")
        print(f"noise shape: {noise.shape}")

        # Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise
        if noise_pred.shape[1] != noise.shape[1]:
            # Ajustează numărul de canale pentru noise_pred
            conv_layer = torch.nn.Conv2d(
                in_channels=noise_pred.shape[1],
                out_channels=noise.shape[1],
                kernel_size=1
            ).to("cuda", dtype=torch.float16)
            noise_pred = conv_layer(noise_pred)

        # Redimensionare noise_pred pentru a se potrivi cu dimensiunea noise
        if noise_pred.shape[2:] != noise.shape[2:]:
            noise_pred = torch.nn.functional.interpolate(noise_pred, size=images.shape[2:], mode='bilinear', align_corners=False)

        # Calcul pierdere (loss) comparând ieșirea modelului cu zgomotul original
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Curăță memoria GPU după fiecare iterare
        gc.collect()
        torch.cuda.empty_cache()
    
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# Salvarea modelului antrenat
unet.save_pretrained("./finetuned-unet")
text_model.save_pretrained("./finetuned-text-model")
api = HfApi()
#api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-unet", repo_type="model")
#api.create_repo(repo_id="AndreiUrsu/finetuned-stable-diffusion-text-model", repo_type="model")
# Încărcarea pe Hugging Face
api.upload_folder(
    folder_path="./finetuned-unet",
    path_in_repo=".",
    repo_id="AndreiUrsu/finetuned-stable-diffusion-unet",
    repo_type="model"
)


# Curăță memoria GPU la final
gc.collect()
torch.cuda.empty_cache()