.gitattributes
- .gitattributes +1 -0
- Untitled.ipynb +3 -0
- captions_moondream2.ipynb +0 -113
- captions_wd.ipynb +0 -421
- dataset_fromzip.ipynb +0 -154
- dataset_imagenet.ipynb +0 -0
- dataset_laion_coco.ipynb +0 -460
- dataset_mjnj.ipynb +0 -0
- dataset_mnist-te.ipynb +0 -458
- dataset_mnist.ipynb +0 -442
- inference.ipynb +0 -0
- model_index.json +28 -0
- pipeline_sdxs.py +295 -0
- samples/sdxs_320x576_0.jpg +2 -2
- samples/sdxs_384x576_0.jpg +2 -2
- samples/sdxs_448x576_0.jpg +2 -2
- samples/sdxs_512x576_0.jpg +2 -2
- samples/sdxs_576x320_0.jpg +2 -2
- samples/sdxs_576x384_0.jpg +2 -2
- samples/sdxs_576x448_0.jpg +2 -2
- samples/sdxs_576x512_0.jpg +2 -2
- samples/sdxs_576x576_0.jpg +2 -2
- scheduler/scheduler_config.json +19 -0
- sdxs/diffusion_pytorch_model.safetensors +1 -1
- sdxs_create.ipynb +0 -153
- src/captions_moondream2.ipynb +3 -0
- captions_qwen2-vl-7b.py → src/captions_qwen2-vl-7b.py +0 -0
- src/captions_wd.ipynb +3 -0
- dataset_combine.py → src/dataset_combine.py +0 -0
- src/dataset_fromzip.ipynb +3 -0
- src/dataset_imagenet.ipynb +3 -0
- src/dataset_laion_coco.ipynb +3 -0
- src/dataset_mjnj.ipynb +3 -0
- src/dataset_mnist-te.ipynb +3 -0
- src/dataset_mnist.ipynb +3 -0
- src/inference.ipynb +3 -0
- src/sdxs_create.ipynb +3 -0
- text_encoder/config.json +28 -0
- text_encoder/model.fp16.safetensors +3 -0
- text_projector/config.json +1 -0
- text_projector/model.safetensors +3 -0
- tokenizer/special_tokens_map.json +51 -0
- tokenizer/tokenizer_config.json +55 -0
- train.py_ +80 -47
- unet/config.json +78 -0
- unet/diffusion_pytorch_model.fp16.safetensors +3 -0
- vae/config.json +38 -0
- vae/diffusion_pytorch_model.fp16.safetensors +3 -0
.gitattributes
CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.jpg filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
 *.ipynb filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
Untitled.ipynb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b686074a7460b833710a3e653674756bc348b2ada9ecc939e4ea5d9fbcc9e05d
+size 7120435
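The added file is stored as a Git LFS pointer (version, oid, size) rather than as notebook JSON, per the new `*.ipynb` rule in .gitattributes. A minimal, hypothetical sketch of reading such a pointer in Python (the function and usage are illustrative, not part of this repo):

from pathlib import Path

def parse_lfs_pointer(path: str) -> dict:
    # Each pointer line is "key value"; collect the pairs into a dict.
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

# Hypothetical usage on a checked-out pointer file:
# parse_lfs_pointer("Untitled.ipynb")
# -> {'version': 'https://git-lfs.github.com/spec/v1',
#     'oid': 'sha256:b686074a...', 'size': '7120435'}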
captions_moondream2.ipynb
DELETED
@@ -1,113 +0,0 @@
The removed notebook (Python 3.11, ipykernel) contained a single code cell:

import os
import time
from PIL import Image, UnidentifiedImageError
from pathlib import Path
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision="2025-01-09",
    trust_remote_code=True,
    device_map={"": "cuda"}
)
print('ok')

def main(folder):
    # Recursively collect all images from the folder and its subfolders
    image_folder = Path(folder).resolve()
    jpeg_images = list(image_folder.rglob('*.jpeg'))
    jpg_images = list(image_folder.rglob('*.jpg'))
    png_images = list(image_folder.rglob('*.png'))
    all_images = sorted(jpeg_images + png_images + jpg_images)

    num_images = len(all_images)
    print(f"Images found: {num_images}")

    # Use tqdm for a progress bar
    printed = 0
    for image_path in tqdm(all_images, desc="Processing images", unit="img"):
        # Skip if a .txt caption file already exists
        text_filename = str(image_path.with_suffix('.txt'))
        if Path(text_filename).exists():
            continue

        try:
            # Load the image
            image = Image.open(image_path)

            # Optional thumbnail with a maximum size of 500 pixels
            #max_size = (500, 500)
            #image.thumbnail(max_size, Image.LANCZOS)

            # Derive a hint from the file name
            hint = ""
            parts = str(image_path.stem).split('-')
            if len(parts) >= 1:
                t = parts[len(parts) - 1].strip()
                parts = t.split('~')
                if len(parts) == 2:
                    hint = parts[1].replace("_", " ").replace("unknown", "").replace("(", "").replace(")", "").replace("+", " ").strip()
                    by = parts[0].replace("_", " ").replace("misc", "").strip()
                    if by != "":
                        hint += ", " + by

            # Generate the caption
            mdream_capt = model.caption(image, length="short")["caption"]
            caption = mdream_capt.replace("The image depicts ", "").replace("The image presents ", "").replace("The image features ", "").replace("The image portrays ", "").replace("The image is ", "").strip() + " " + hint
            if printed == 0:
                print(image_path, ": ", caption)
                printed += 1
            # Save the caption to a .txt file
            with open(text_filename, 'w') as file_txt:
                file_txt.write(caption)

        except UnidentifiedImageError:
            print(f"\nError: cannot identify image file '{image_path}'. Skipping.")
            # The unreadable file is also deleted
            os.remove(image_path)
            continue
        except Exception as e:
            print(f"\nError while processing {image_path}: {str(e)}. Skipping.")
            continue

    print("Done!")

if __name__ == "__main__":
    main("/workspace/all")
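The filename-hint branch above is the least obvious part of the cell: it takes the last dash-separated segment of the stem, splits it on "~", and builds "subject, author" from the two halves. A standalone sketch of the same string logic on a hypothetical filename (the name and its parts are illustrative, not from the dataset):

# Sketch of the hint extraction above; the filename is hypothetical.
stem = "12345-beach_scene~jane_doe"   # Path(image_path).stem
t = stem.split('-')[-1].strip()       # "beach_scene~jane_doe"
parts = t.split('~')                  # ["beach_scene", "jane_doe"]
hint = ""
if len(parts) == 2:
    hint = parts[1].replace("_", " ").strip()   # "jane doe"
    by = parts[0].replace("_", " ").strip()     # "beach scene"
    if by != "":
        hint += ", " + by                       # "jane doe, beach scene"
print(hint)  # jane doe, beach scene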
captions_wd.ipynb
DELETED
@@ -1,421 +0,0 @@
The removed notebook (Python 3.11, ipykernel) contained one executed code cell and one empty cell.

Stored stdout:

/workspace/wdv3-timm
Loading model 'vit' from 'SmilingWolf/wd-vit-tagger-v3'...
Loading tag list...
Creating data transform...
File: /workspace/dataset/EroticWallPhone/EroticWallPhone_057654d2.jpg, Caption: 1girl, ass, solo, long hair, barefoot, from behind, black hair, jacket, soles, photorealistic, feet, beach, kneeling, photo background, sand, outdoors, dirty feet, facing away, thong, bikini, black jacket, rating_questionable
Processed 1000/48086 files, approximate remaining time: 03:32:09
[... similar "File: ... Caption: ..." and "Processed N/48086 files" pairs, roughly every 1000 images, through 47000 ...]
Processed 48000/48086 files, approximate remaining time: 00:00:25
Done!

Source:

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from simple_parsing import field, parse_known_args
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F
import os
import time
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

new_path = '/workspace/wdv3-timm'
os.chdir(new_path)
print(os.getcwd())


MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    w, h = image.size
    # get the largest dimension so we can pad to a square
    px = max(image.size)
    # pad to square with white background
    canvas = Image.new("RGB", (px, px), (255, 255, 255))
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


@dataclass
class LabelData:
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    character: list[np.int64]


def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )

    return tag_data


def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    # Convert indices+probs to labels
    probs = list(zip(labels.names, probs.numpy()))

    # First 4 labels are actually ratings
    rating_labels = dict([probs[i] for i in labels.rating])
    rating_labels = dict(sorted(rating_labels.items(), key=lambda item: item[1], reverse=True))

    # General labels, pick any where prediction confidence > threshold
    gen_labels = [probs[i] for i in labels.general]
    gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    # Character labels, pick any where prediction confidence > threshold
    char_labels = [probs[i] for i in labels.character]
    char_labels = dict([x for x in char_labels if x[1] > char_threshold])
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # Combine general and character labels, sort by confidence
    combined_names = [x for x in gen_labels]
    combined_names.extend([x for x in char_labels])

    # Convert to a string suitable for use as a training caption
    caption = ", ".join(combined_names)

    taglist = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    caption = caption.replace("_", " ")
    # Append the top-scoring rating tag (or an empty string if none)
    caption += ", rating_" + next(iter(sorted(rating_labels, key=rating_labels.get, reverse=True)), '')

    return caption, taglist, rating_labels, char_labels, gen_labels

@dataclass
class ScriptOptions:
    image_folder: Path = "/workspace/dataset/EroticWallPhone"
    model: str = field(default="vit")
    gen_threshold: float = field(default=0.5)
    char_threshold: float = field(default=0.85)

def main(opts: ScriptOptions):
    repo_id = MODEL_REPO_MAP.get(opts.model)
    image_folder = Path(opts.image_folder).resolve()
    if not image_folder.is_dir():
        raise NotADirectoryError(f"Image folder not found: {image_folder}")

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    # Collect all images from the folder
    jpeg_images = image_folder.glob('*.jpeg')
    jpg_images = image_folder.glob('*.jpg')
    png_images = image_folder.glob('*.png')
    all_images = list(jpeg_images) + list(png_images) + list(jpg_images)
    # Optionally keep only files whose names start with a given prefix
    #all_images = [img for img in all_images if img.name.startswith('nus')]
    all_images = sorted(all_images)
    #all_images = all_images[:10]

    num_images = len(all_images)
    start_time = time.time()
    images_processed = 0
    for image_path in all_images:
        images_processed += 1
        # get image
        img_input: Image.Image = Image.open(image_path)
        # ensure image is RGB
        img_input = pil_ensure_rgb(img_input)
        # pad to square with white background
        img_input = pil_pad_square(img_input)
        # run the model's input transform to convert to tensor and rescale
        inputs: Tensor = transform(img_input).unsqueeze(0)
        # NCHW image RGB to BGR
        inputs = inputs[:, [2, 1, 0]]

        with torch.inference_mode():
            # move model to GPU, if available
            if torch_device.type != "cpu":
                model = model.to(torch_device)
                inputs = inputs.to(torch_device)
            # run the model
            outputs = model.forward(inputs)
            # apply the final activation function (timm doesn't support doing this internally)
            outputs = F.sigmoid(outputs)
            # move inputs, outputs, and model back to cpu if we were on GPU
            if torch_device.type != "cpu":
                inputs = inputs.to("cpu")
                outputs = outputs.to("cpu")
                model = model.to("cpu")

        caption, taglist, ratings, character, general = get_tags(
            probs=outputs.squeeze(0),
            labels=labels,
            gen_threshold=opts.gen_threshold,
            char_threshold=opts.char_threshold,
        )

        # Dataset-specific caption tweaks before saving the tags to files
        if 'nusha2' in str(image_path):
            caption = "nusha," + caption
            caption = caption.replace("dog", "nusha")
            caption = caption.replace("cat", "dog, nusha")
        if 'sh' in str(image_path) or 'tf' in str(image_path) or 'vkgirls' in str(image_path):
            caption = "mobile_photo," + caption

        tags_filename = str(image_path.with_suffix('.tag'))
        text_filename = str(image_path.with_suffix('.txt'))
        if images_processed % 1000 == 0:
            elapsed_time = time.time() - start_time
            estimated_total_time = (elapsed_time / images_processed) * num_images
            remaining_time = estimated_total_time - elapsed_time
            print(f"File: {image_path}, Caption: {caption}\n ")
            print(f"Processed {images_processed}/{num_images} files, approximate remaining time: {time.strftime('%H:%M:%S', time.gmtime(remaining_time))}")
        with open(tags_filename, 'w') as file_tag:
            file_tag.write(f"{caption}")
        with open(text_filename, 'w') as file_txt:
            #text = mdream_capt + ". " + caption
            #file_txt.write(f"{text}")
            file_txt.write(f"{caption}")

    print("Done!")


if __name__ == "__main__":
    opts, _ = parse_known_args(ScriptOptions)
    #if opts.model not in MODEL_REPO_MAP:
    #    print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
    #    raise ValueError(f"Unknown model name '{opts.model}'")
    main(opts)
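One easy-to-miss detail in the cell above is `inputs = inputs[:, [2, 1, 0]]`, which reorders the channel dimension of the NCHW tensor from RGB to BGR, as the comment notes the WD v3 taggers expect. A minimal self-contained demonstration of what that indexing does:

import torch

# Dummy NCHW batch: 1 image, 3 channels (0=R, 1=G, 2=B), 2x2 pixels
x = torch.arange(12, dtype=torch.float32).reshape(1, 3, 2, 2)

# Fancy indexing on the channel dimension reverses RGB -> BGR
bgr = x[:, [2, 1, 0]]

assert torch.equal(bgr[0, 0], x[0, 2])  # new channel 0 is the old B plane
assert torch.equal(bgr[0, 2], x[0, 0])  # new channel 2 is the old R plane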
dataset_fromzip.ipynb
DELETED
@@ -1,154 +0,0 @@
The removed notebook (Python 3.11, ipykernel) contained one executed code cell and one empty cell.

Stored stdout:

unzipped
Total number of images: 9463
copy from angelslove
ok

Source:

import os
import zipfile
import zlib
import shutil
from pathlib import Path
from PIL import Image
from tqdm import tqdm

maxsize = 1152

def unzip_files_in_directory(directory):
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.zip'):
                zip_path = os.path.join(root, file)
                # Create a new subdirectory for the extracted files
                extract_dir = os.path.join(root, os.path.splitext(file)[0])
                os.makedirs(extract_dir, exist_ok=True)
                try:
                    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                        zip_ref.extractall(extract_dir)
                except (zipfile.BadZipFile, FileNotFoundError, OSError) as e:
                    print(f"Skipping corrupted or unreadable file: {zip_path}. Error: {e}")

                # Optionally, remove the zip file after extraction
                os.remove(zip_path)

def list_image_files(directory, image_extensions=('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')):
    image_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(image_extensions):
                image_files.append(os.path.join(root, file))
    return image_files

def compute_crc32(file_path):
    crc_value = 0
    with open(file_path, 'rb') as file:
        for chunk in iter(lambda: file.read(4096), b''):
            crc_value = zlib.crc32(chunk, crc_value)
    return format(crc_value & 0xFFFFFFFF, '08x')


def copy_images_with_crc32(image_files, output_directory, base_folder_name):
    output_path = Path(output_directory) / f"{base_folder_name}"
    output_path.mkdir(parents=True, exist_ok=True)
    for image_file in tqdm(image_files, desc="Processing images"):
        try:
            crc32_hash = compute_crc32(image_file)
            file_extension = os.path.splitext(image_file)[1]

            # Open the image file
            with Image.open(image_file) as img:
                # Re-encode if the image exceeds maxsize in either dimension or is not RGB
                if max(img.size) > maxsize or img.mode != 'RGB':
                    # Resize while preserving the aspect ratio
                    if max(img.size) > maxsize:
                        img.thumbnail((maxsize, maxsize), Image.BICUBIC)
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    # Save as PNG
                    new_file_name = f"{base_folder_name}_{crc32_hash}.png"
                    new_file_path = output_path / new_file_name
                    img.save(new_file_path, 'PNG')  #, optimize=True)
                else:
                    # Simply copy the image if it needs no re-encoding
                    new_file_name = f"{base_folder_name}_{crc32_hash}{file_extension}"
                    new_file_path = output_path / new_file_name
                    shutil.copy(image_file, new_file_path)

            # Check if a corresponding txt file exists and copy it
            image_path = Path(image_file)
            txt_file_path = image_path.with_suffix('.txt')
            if txt_file_path.exists():
                new_txt_name = Path(new_file_name).with_suffix('.txt')
                new_txt_path = output_path / new_txt_name
                shutil.copy(txt_file_path, new_txt_path)

        except Exception as e:
            print(f"Skipping file due to error: {image_file}. Error: {e}")

def main(directory, output_directory):
    base_folder_name = os.path.basename(directory.rstrip('/').rstrip('\\'))

    # Unzip all zip files in the directory
    unzip_files_in_directory(directory)
    print('unzipped')

    # List all image files
    image_files = list_image_files(directory)
    print(f"Total number of images: {len(image_files)}")

    # Copy images with CRC32 hash in their names
    print('copy from', base_folder_name)
    copy_images_with_crc32(image_files, output_directory, base_folder_name)
    print('ok')

# Example usage
directory_path = '/workspace/all/MetalAlbumCovers'
output_directory_path = 'all2'
main(directory_path, output_directory_path)
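The chunked `compute_crc32` above should produce the same value as hashing the whole file in one call, since `zlib.crc32` accepts a running checksum as its second argument. A quick self-check of that equivalence on an in-memory buffer:

import zlib

data = b"example bytes" * 1000

# Whole-buffer CRC32
whole = zlib.crc32(data) & 0xFFFFFFFF

# Incremental CRC32 over 4096-byte chunks, as compute_crc32 does for files
crc = 0
for i in range(0, len(data), 4096):
    crc = zlib.crc32(data[i:i + 4096], crc)

assert (crc & 0xFFFFFFFF) == whole
print(format(whole, '08x'))  # the 8-hex-digit suffix used in the new file names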
dataset_imagenet.ipynb
DELETED
The diff for this file is too large to render; see the raw diff.
dataset_laion_coco.ipynb
DELETED
@@ -1,460 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 3,
|
6 |
-
"id": "248b87c8-e453-402a-bedd-e31119e3da19",
|
7 |
-
"metadata": {},
|
8 |
-
"outputs": [
|
9 |
-
{
|
10 |
-
"data": {
|
11 |
-
"application/vnd.jupyter.widget-view+json": {
|
12 |
-
"model_id": "736558aed6b04e93b50ccd0531523222",
|
13 |
-
"version_major": 2,
|
14 |
-
"version_minor": 0
|
15 |
-
},
|
16 |
-
"text/plain": [
|
17 |
-
"README.md: 0%| | 0.00/11.8k [00:00<?, ?B/s]"
|
18 |
-
]
|
19 |
-
},
|
20 |
-
"metadata": {},
|
21 |
-
"output_type": "display_data"
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"data": {
|
25 |
-
"application/vnd.jupyter.widget-view+json": {
|
26 |
-
"model_id": "6e62f403424f4282a4ce9a0fb371f6eb",
|
27 |
-
"version_major": 2,
|
28 |
-
"version_minor": 0
|
29 |
-
},
|
30 |
-
"text/plain": [
|
31 |
-
"Resolving data files: 0%| | 0/31 [00:00<?, ?it/s]"
|
32 |
-
]
|
33 |
-
},
|
34 |
-
"metadata": {},
|
35 |
-
"output_type": "display_data"
|
36 |
-
},
|
37 |
-
{
|
38 |
-
"name": "stdout",
|
39 |
-
"output_type": "stream",
|
40 |
-
"text": [
|
41 |
-
"Processed 10000 records\n",
|
42 |
-
"Processed 20000 records\n",
|
43 |
-
"Processed 30000 records\n",
|
44 |
-
"Processed 40000 records\n",
|
45 |
-
"Processed 50000 records\n",
|
46 |
-
"Processed 60000 records\n",
|
47 |
-
"Processed 70000 records\n",
|
48 |
-
"Processed 80000 records\n",
|
49 |
-
"Processed 90000 records\n",
|
50 |
-
"Processed 100000 records\n",
|
51 |
-
"Processed 110000 records\n",
|
52 |
-
"Processed 120000 records\n",
|
53 |
-
"Processed 130000 records\n",
|
54 |
-
"Processed 140000 records\n",
|
55 |
-
"Processed 150000 records\n",
|
56 |
-
"Processed 160000 records\n",
|
57 |
-
"Processed 170000 records\n",
|
58 |
-
"Processed 180000 records\n",
|
59 |
-
"Processed 190000 records\n",
|
60 |
-
"Processed 200000 records\n",
|
61 |
-
"Processed 210000 records\n",
|
62 |
-
"Processed 220000 records\n",
|
63 |
-
"Processed 230000 records\n",
|
64 |
-
"Processed 240000 records\n",
|
65 |
-
"Processed 250000 records\n",
|
66 |
-
"Processed 260000 records\n",
|
67 |
-
"Processed 270000 records\n",
|
68 |
-
"Processed 280000 records\n",
|
69 |
-
"Processed 290000 records\n",
|
70 |
-
"Processed 300000 records\n",
|
71 |
-
"Processed 310000 records\n",
|
72 |
-
"Processed 320000 records\n",
|
73 |
-
"Processed 330000 records\n",
|
74 |
-
"Processed 340000 records\n",
|
75 |
-
"Processed 350000 records\n",
|
76 |
-
"Processed 360000 records\n",
|
77 |
-
"Processed 370000 records\n",
|
78 |
-
"Processed 380000 records\n",
|
79 |
-
"Processed 390000 records\n",
|
80 |
-
"Processed 400000 records\n",
|
81 |
-
"Processed 410000 records\n",
|
82 |
-
"Processed 420000 records\n",
|
83 |
-
"Processed 430000 records\n",
|
84 |
-
"Processed 440000 records\n",
|
85 |
-
"Processed 450000 records\n",
|
86 |
-
"Processed 460000 records\n",
|
87 |
-
"Processed 470000 records\n",
|
88 |
-
"Processed 480000 records\n",
|
89 |
-
"Processed 490000 records\n",
|
90 |
-
"Processed 500000 records\n",
|
91 |
-
"Processed 510000 records\n",
|
92 |
-
"Processed 520000 records\n",
|
93 |
-
"Processed 530000 records\n",
|
94 |
-
"Processed 540000 records\n",
|
95 |
-
"Processed 550000 records\n",
|
96 |
-
"Processed 560000 records\n",
|
97 |
-
"Processed 570000 records\n",
|
98 |
-
"Processed 580000 records\n",
|
99 |
-
"Processed 590000 records\n",
|
100 |
-
"Processed 600000 records\n",
|
101 |
-
"Processed 610000 records\n",
|
102 |
-
"Processed 620000 records\n",
|
103 |
-
"Processed 630000 records\n",
|
104 |
-
"Processed 640000 records\n",
|
105 |
-
"Processed 650000 records\n",
|
106 |
-
"Processed 660000 records\n",
|
107 |
-
"Processed 670000 records\n",
|
108 |
-
"Processed 680000 records\n",
|
109 |
-
"Processed 690000 records\n",
|
110 |
-
"Processed 700000 records\n",
|
111 |
-
"Processed 710000 records\n",
|
112 |
-
"Processed 720000 records\n",
|
113 |
-
"Processed 730000 records\n",
|
114 |
-
"Processed 740000 records\n",
|
115 |
-
"Processed 750000 records\n",
|
116 |
-
"Processed 760000 records\n",
|
117 |
-
"Processed 770000 records\n",
|
118 |
-
"Processed 780000 records\n",
|
119 |
-
"Processed 790000 records\n",
|
120 |
-
"Processed 800000 records\n",
|
121 |
-
"Processed 810000 records\n",
|
122 |
-
"Processed 820000 records\n",
|
123 |
-
"Processed 830000 records\n",
|
124 |
-
"Processed 840000 records\n",
|
125 |
-
"Processed 850000 records\n",
|
126 |
-
"Processed 860000 records\n",
|
127 |
-
"Processed 870000 records\n",
|
128 |
-
"Processed 880000 records\n",
|
129 |
-
"Processed 890000 records\n",
|
130 |
-
"Processed 900000 records\n",
|
131 |
-
"Processed 910000 records\n",
|
132 |
-
"Processed 920000 records\n",
|
133 |
-
"Processed 930000 records\n",
|
134 |
-
"Processed 940000 records\n",
|
135 |
-
"Processed 950000 records\n",
|
136 |
-
"Processed 960000 records\n",
|
137 |
-
"Processed 970000 records\n",
|
138 |
-
"Processed 980000 records\n",
|
139 |
-
"Processed 990000 records\n",
|
140 |
-
"Processed 1000000 records\n",
|
141 |
-
"Processed 1010000 records\n",
|
142 |
-
"Processed 1020000 records\n",
|
143 |
-
"Processed 1030000 records\n",
|
144 |
-
"Processed 1040000 records\n",
|
145 |
-
"Processed 1050000 records\n",
|
146 |
-
"Processed 1060000 records\n",
|
147 |
-
"Processed 1070000 records\n",
|
148 |
-
"Processed 1080000 records\n",
|
149 |
-
"Processed 1090000 records\n",
|
150 |
-
"Processed 1100000 records\n",
|
151 |
-
"Processed 1110000 records\n",
|
152 |
-
"Processed 1120000 records\n",
|
153 |
-
"Processed 1130000 records\n",
|
154 |
-
"Processed 1140000 records\n",
|
155 |
-
"Processed 1150000 records\n",
|
156 |
-
"Processed 1160000 records\n",
|
157 |
-
"Processed 1170000 records\n",
|
158 |
-
"Processed 1180000 records\n",
|
159 |
-
"Processed 1190000 records\n",
|
160 |
-
"Processed 1200000 records\n",
|
161 |
-
"Processed 1210000 records\n",
|
162 |
-
"Processed 1220000 records\n",
|
163 |
-
"Processed 1230000 records\n",
|
164 |
-
"Processed 1240000 records\n",
|
165 |
-
"Processed 1250000 records\n",
|
166 |
-
"Processed 1260000 records\n",
|
167 |
-
"Processed 1270000 records\n",
|
168 |
-
"Processed 1280000 records\n",
|
169 |
-
"Processed 1290000 records\n",
|
170 |
-
"Processed 1300000 records\n",
|
171 |
-
"Processed 1310000 records\n",
|
172 |
-
"Processed 1320000 records\n",
|
173 |
-
"Processed 1330000 records\n",
|
174 |
-
"Processed 1340000 records\n",
|
175 |
-
"Processed 1350000 records\n",
|
176 |
-
"Processed 1360000 records\n",
|
177 |
-
"Processed 1370000 records\n",
|
178 |
-
"Processed 1380000 records\n",
|
179 |
-
"Processed 1390000 records\n",
|
180 |
-
"Processed 1400000 records\n",
|
181 |
-
"Processed 1410000 records\n",
|
182 |
-
"Processed 1420000 records\n",
|
183 |
-
"Processed 1430000 records\n",
|
184 |
-
"Processed 1440000 records\n",
|
185 |
-
"Processed 1450000 records\n",
|
186 |
-
"Processed 1460000 records\n",
|
187 |
-
"Processed 1470000 records\n",
|
188 |
-
"Processed 1480000 records\n",
|
189 |
-
"Processed 1490000 records\n",
|
190 |
-
"Processed 1500000 records\n",
|
191 |
-
"Processed 1510000 records\n",
|
192 |
-
"Processed 1520000 records\n",
|
193 |
-
"Processed 1530000 records\n",
|
194 |
-
"Processed 1540000 records\n",
|
195 |
-
"Processed 1550000 records\n",
|
196 |
-
"Processed 1560000 records\n",
|
197 |
-
"Processed 1570000 records\n",
|
198 |
-
"Processed 1580000 records\n",
|
199 |
-
"Processed 1590000 records\n",
|
200 |
-
"Processed 1600000 records\n",
|
201 |
-
"Processed 1610000 records\n",
|
202 |
-
"Processed 1620000 records\n",
|
203 |
-
"Processed 1630000 records\n",
|
204 |
-
"Processed 1640000 records\n",
|
205 |
-
"Processed 1650000 records\n",
|
206 |
-
"Processed 1660000 records\n",
|
207 |
-
"Processed 1670000 records\n",
|
208 |
-
"Processed 1680000 records\n",
|
209 |
-
"Processed 1690000 records\n",
|
210 |
-
"Processed 1700000 records\n",
|
211 |
-
"Processed 1710000 records\n",
|
212 |
-
"Processed 1720000 records\n",
|
213 |
-
"Processed 1730000 records\n",
|
214 |
-
"Processed 1740000 records\n",
|
215 |
-
"Processed 1750000 records\n",
|
216 |
-
"Processed 1760000 records\n",
|
217 |
-
"Processed 1770000 records\n",
|
218 |
-
"Processed 1780000 records\n",
|
219 |
-
"Processed 1790000 records\n",
|
220 |
-
"Processed 1800000 records\n",
|
221 |
-
"Processed 1810000 records\n",
|
222 |
-
"Processed 1820000 records\n",
|
223 |
-
"Processed 1830000 records\n",
|
224 |
-
"Processed 1840000 records\n",
|
225 |
-
"Processed 1850000 records\n",
|
226 |
-
"Processed 1860000 records\n",
|
227 |
-
"Processed 1870000 records\n",
|
228 |
-
"Processed 1880000 records\n",
|
229 |
-
"Processed 1890000 records\n",
|
230 |
-
"Processed 1900000 records\n",
|
231 |
-
"Processed 1910000 records\n",
|
232 |
-
"Processed 1920000 records\n",
|
233 |
-
"Processed 1930000 records\n",
|
234 |
-
"Processed 1940000 records\n",
|
235 |
-
"Processed 1950000 records\n",
|
236 |
-
"Processed 1960000 records\n",
|
237 |
-
"Processed 1970000 records\n",
|
238 |
-
"Processed 1980000 records\n",
|
239 |
-
"{'key': '092026937', 'caption': 'A modern living room featuring a white Chase sofa, sleek coffee tables, a gray shelf unit with decorative items, a cozy chair with a yellow blanket, and a dracaena plant, all bathed in natural light.', 'url': 'https://images.webfronts.com/cache/frbbxseehlqx.jpg?imgeng=/w_1000/h_1000/m_letterbox_ffffff_100'}\n"
|
240 |
-
]
|
241 |
-
},
|
242 |
-
{
|
243 |
-
"data": {
|
244 |
-
"application/vnd.jupyter.widget-view+json": {
|
245 |
-
"model_id": "24d1e19a18e649259aea982b3c561c95",
|
246 |
-
"version_major": 2,
|
247 |
-
"version_minor": 0
|
248 |
-
},
|
249 |
-
"text/plain": [
|
250 |
-
"Saving the dataset (0/1 shards): 0%| | 0/300000 [00:00<?, ? examples/s]"
|
251 |
-
]
|
252 |
-
},
|
253 |
-
"metadata": {},
|
254 |
-
"output_type": "display_data"
|
255 |
-
}
|
256 |
-
],
|
257 |
-
"source": [
|
258 |
-
"from datasets import load_dataset, Dataset\n",
|
259 |
-
"import random\n",
|
260 |
-
"\n",
|
261 |
-
"# Загрузка метаданных датасета\n",
|
262 |
-
"dataset = load_dataset(\"CaptionEmporium/laion-coco-13m-molmo-d-7b\", split=\"train\", streaming=True)\n",
|
263 |
-
"\n",
|
264 |
-
"# Условия для фильтрации\n",
|
265 |
-
"min_wh = 768 # Минимальное значение для min(width, height)\n",
|
266 |
-
"max_wh = 1152 # Максимальное значение для max(width, height)\n",
|
267 |
-
"max_cnt = 300000\n",
|
268 |
-
"\n",
|
269 |
-
"# Функция для выбора случайного значения из двух полей с вероятностью 0.5\n",
|
270 |
-
"def select_caption(example):\n",
|
271 |
-
" if random.random() < 0.5:\n",
|
272 |
-
" return example[\"caption_molmo_short\"]\n",
|
273 |
-
" else:\n",
|
274 |
-
" return example[\"caption_molmo_medium\"]\n",
|
275 |
-
"\n",
|
276 |
-
"# Фильтрация и выборка данных\n",
|
277 |
-
"counter = 0\n",
|
278 |
-
"filtered_examples = {\"key\": [],\"caption\": [], \"url\": []}\n",
|
279 |
-
"for example in dataset:\n",
|
280 |
-
" if (example.get(\"pwatermark\") is not None and example[\"pwatermark\"] < 0.9 and\n",
|
281 |
-
" example.get(\"width\") is not None and example.get(\"height\") is not None and\n",
|
282 |
-
" min(example[\"width\"], example[\"height\"]) > min_wh and\n",
|
283 |
-
" max(example[\"width\"], example[\"height\"]) < max_wh and\n",
|
284 |
-
" example.get(\"status\") == \"success\"):\n",
|
285 |
-
" filtered_examples[\"key\"].append(example[\"key\"])\n",
|
286 |
-
" filtered_examples[\"caption\"].append(select_caption(example))\n",
|
287 |
-
" filtered_examples[\"url\"].append(example[\"url\"])\n",
|
288 |
-
"\n",
|
289 |
-
" counter += 1\n",
|
290 |
-
" if counter % 10000 == 0:\n",
|
291 |
-
" print(f\"Processed {counter} records\")\n",
|
292 |
-
" if len(filtered_examples[\"caption\"]) >= max_cnt:\n",
|
293 |
-
" break\n",
|
294 |
-
"\n",
|
295 |
-
"# Преобразование в Dataset\n",
|
296 |
-
"filtered_dataset = Dataset.from_dict(filtered_examples)\n",
|
297 |
-
"\n",
|
298 |
-
"# Вывод результатов\n",
|
299 |
-
"print(filtered_dataset[0])\n",
|
300 |
-
"import os\n",
|
301 |
-
"# Сохранение отфильтрованного датасета\n",
|
302 |
-
"save_path = f\"laion-coco-{min_wh}-{max_wh}-{max_cnt}\"\n",
|
303 |
-
"os.makedirs(save_path, exist_ok=True)\n",
|
304 |
-
"filtered_dataset.save_to_disk(save_path)\n"
|
305 |
-
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "134f88b8-bdfa-44e2-8d51-73820ce6beac",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Processing dataset: 100%|██████████| 300000/300000 [00:13<00:00, 21914.44it/s]\n",
-      "Downloading images: 0%| | 656/300000 [01:43<17:12:32, 4.83it/s]/usr/local/lib/python3.11/dist-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
-      "  warnings.warn(\n",
-      "Downloading images: 28%|██▊ | 83916/300000 [5:04:44<6:53:39, 8.71it/s] IOPub message rate exceeded.\n",
-      "The Jupyter server will temporarily stop sending output\n",
-      "to the client in order to avoid crashing it.\n",
-      "To change this limit, set the config variable\n",
-      "`--ServerApp.iopub_msg_rate_limit`.\n",
-      "\n",
-      "Current values:\n",
-      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
-      "ServerApp.rate_limit_window=3.0 (secs)\n",
-      "\n",
-      "Downloading images: 37%|███▋ | 110322/300000 [6:29:51<14:58:09, 3.52it/s]IOPub message rate exceeded.\n",
-      "The Jupyter server will temporarily stop sending output\n",
-      "to the client in order to avoid crashing it.\n",
-      "To change this limit, set the config variable\n",
-      "`--ServerApp.iopub_msg_rate_limit`.\n",
-      "\n",
-      "Current values:\n",
-      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
-      "ServerApp.rate_limit_window=3.0 (secs)\n",
-      "\n",
-      "Downloading images: 50%|█████ | 150065/300000 [8:50:30<9:34:24, 4.35it/s]IOPub message rate exceeded.\n",
-      "The Jupyter server will temporarily stop sending output\n",
-      "to the client in order to avoid crashing it.\n",
-      "To change this limit, set the config variable\n",
-      "`--ServerApp.iopub_msg_rate_limit`.\n",
-      "\n",
-      "Current values:\n",
-      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
-      "ServerApp.rate_limit_window=3.0 (secs)\n",
-      "\n",
-      "Downloading images: 53%|█████▎ | 158966/300000 [9:26:00<8:13:29, 4.76it/s] IOPub message rate exceeded.\n",
-      "The Jupyter server will temporarily stop sending output\n",
-      "to the client in order to avoid crashing it.\n",
-      "To change this limit, set the config variable\n",
-      "`--ServerApp.iopub_msg_rate_limit`.\n",
-      "\n",
-      "Current values:\n",
-      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
-      "ServerApp.rate_limit_window=3.0 (secs)\n",
-      "\n"
-     ]
-    },
-    {
-     "ename": "TypeError",
-     "evalue": "cannot unpack non-iterable int object",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
-      "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 45\u001b[39m\n\u001b[32m 42\u001b[39m future_to_key = {executor.submit(download_and_save_image_and_caption, idx, row[\u001b[33m\"\u001b[39m\u001b[33mkey\u001b[39m\u001b[33m\"\u001b[39m], row[\u001b[33m\"\u001b[39m\u001b[33murl\u001b[39m\u001b[33m\"\u001b[39m], row[\u001b[33m\"\u001b[39m\u001b[33mcaption\u001b[39m\u001b[33m\"\u001b[39m]): row[\u001b[33m\"\u001b[39m\u001b[33mkey\u001b[39m\u001b[33m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m idx, row \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28menumerate\u001b[39m(dataset.select(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mmin\u001b[39m(max_rows, \u001b[38;5;28mlen\u001b[39m(dataset))))), total=\u001b[38;5;28mmin\u001b[39m(max_rows, \u001b[38;5;28mlen\u001b[39m(dataset)), desc=\u001b[33m\"\u001b[39m\u001b[33mProcessing dataset\u001b[39m\u001b[33m\"\u001b[39m)}\n\u001b[32m 44\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m future \u001b[38;5;129;01min\u001b[39;00m tqdm(as_completed(future_to_key), total=\u001b[38;5;28mlen\u001b[39m(future_to_key), desc=\u001b[33m\"\u001b[39m\u001b[33mDownloading images\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m---> \u001b[39m\u001b[32m45\u001b[39m \u001b[43mfuture\u001b[49m\u001b[43m.\u001b[49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Wait for all tasks to finish\u001b[39;00m\n\u001b[32m 47\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mProcessing complete.\u001b[39m\u001b[33m\"\u001b[39m)",
-      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.11/concurrent/futures/_base.py:449\u001b[39m, in \u001b[36mFuture.result\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[32m 448\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state == FINISHED:\n\u001b[32m--> \u001b[39m\u001b[32m449\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 451\u001b[39m \u001b[38;5;28mself\u001b[39m._condition.wait(timeout)\n\u001b[32m 453\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._state \u001b[38;5;129;01min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.11/concurrent/futures/_base.py:401\u001b[39m, in \u001b[36mFuture.__get_result\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception:\n\u001b[32m 400\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._exception\n\u001b[32m 402\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[32m 404\u001b[39m \u001b[38;5;28mself\u001b[39m = \u001b[38;5;28;01mNone\u001b[39;00m\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m/usr/lib/python3.11/concurrent/futures/thread.py:58\u001b[39m, in \u001b[36m_WorkItem.run\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 55\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m 57\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m58\u001b[39m result = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 59\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[32m 60\u001b[39m \u001b[38;5;28mself\u001b[39m.future.set_exception(exc)\n",
-      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 26\u001b[39m, in \u001b[36mdownload_and_save_image_and_caption\u001b[39m\u001b[34m(index, key, url, caption)\u001b[39m\n\u001b[32m 24\u001b[39m response.raise_for_status()\n\u001b[32m 25\u001b[39m img = Image.open(BytesIO(response.content)).convert(\u001b[33m\"\u001b[39m\u001b[33mRGB\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m26\u001b[39m \u001b[43mimg\u001b[49m\u001b[43m.\u001b[49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m.\u001b[49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mkey\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m.png\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mPNG\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimize\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 28\u001b[39m \u001b[38;5;66;03m# Save the caption text file\u001b[39;00m\n\u001b[32m 29\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(os.path.join(output_path, \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.txt\u001b[39m\u001b[33m\"\u001b[39m), \u001b[33m\"\u001b[39m\u001b[33mw\u001b[39m\u001b[33m\"\u001b[39m, encoding=\u001b[33m\"\u001b[39m\u001b[33mutf-8\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/PIL/Image.py:2605\u001b[39m, in \u001b[36mImage.save\u001b[39m\u001b[34m(self, fp, format, **params)\u001b[39m\n\u001b[32m 2602\u001b[39m fp = cast(IO[\u001b[38;5;28mbytes\u001b[39m], fp)\n\u001b[32m 2604\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2605\u001b[39m \u001b[43msave_handler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2606\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[32m 2607\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m open_fp:\n",
-      "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.11/dist-packages/PIL/PngImagePlugin.py:1442\u001b[39m, in \u001b[36m_save\u001b[39m\u001b[34m(im, fp, filename, chunk, save_all)\u001b[39m\n\u001b[32m 1440\u001b[39m chunk(fp, \u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mtRNS\u001b[39m\u001b[33m\"\u001b[39m, o16(transparency))\n\u001b[32m 1441\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m im.mode == \u001b[33m\"\u001b[39m\u001b[33mRGB\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m-> \u001b[39m\u001b[32m1442\u001b[39m red, green, blue = transparency\n\u001b[32m 1443\u001b[39m chunk(fp, \u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mtRNS\u001b[39m\u001b[33m\"\u001b[39m, o16(red) + o16(green) + o16(blue))\n\u001b[32m 1444\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
-      "\u001b[31mTypeError\u001b[39m: cannot unpack non-iterable int object"
-     ]
-    }
-   ],
-   "source": [
"import os\n",
|
382 |
-
"import requests\n",
|
383 |
-
"from PIL import Image\n",
|
384 |
-
"from io import BytesIO\n",
|
385 |
-
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
|
386 |
-
"from tqdm import tqdm\n",
|
387 |
-
"from datasets import load_from_disk, Dataset\n",
|
388 |
-
"from urllib.parse import unquote\n",
|
389 |
-
"\n",
|
390 |
-
"# Путь к сохраненному датасету\n",
|
391 |
-
"dataset_path = \"laion-coco-768-1152-300000\"\n",
|
392 |
-
"\n",
|
393 |
-
"# Максимальное количество строк для обработки\n",
|
394 |
-
"max_rows = 300000\n",
|
395 |
-
"\n",
|
396 |
-
"# Путь для сохранения данных\n",
|
397 |
-
"output_path = f\"{dataset_path}-data-{max_rows}\"\n",
|
398 |
-
"os.makedirs(output_path, exist_ok=True)\n",
|
399 |
-
"\n",
|
400 |
-
"# Функция для скачивания и сохранения изображения и текстового файла\n",
|
401 |
-
"def download_and_save_image_and_caption(index, key, url, caption):\n",
|
402 |
-
" try:\n",
|
403 |
-
" response = requests.get(unquote(url), timeout=10)\n",
|
404 |
-
" response.raise_for_status()\n",
|
405 |
-
" img = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
|
406 |
-
" img.save(os.path.join(output_path, f\"{key}.png\"), \"PNG\", optimize=True)\n",
|
407 |
-
"\n",
|
408 |
-
" # Сохранение текстового файла\n",
|
409 |
-
" with open(os.path.join(output_path, f\"{key}.txt\"), \"w\", encoding=\"utf-8\") as f:\n",
|
410 |
-
" f.write(caption)\n",
|
411 |
-
"\n",
|
412 |
-
" return True\n",
|
413 |
-
" except (requests.RequestException, IOError) as e:\n",
|
414 |
-
" #print(f\"Failed to download image for key {key} (row {index})\")\n",
|
415 |
-
" return False\n",
|
416 |
-
"\n",
|
417 |
-
"# Чтение датасета\n",
|
418 |
-
"dataset = load_from_disk(dataset_path)\n",
|
419 |
-
"\n",
|
420 |
-
"# Обработка датасета\n",
|
421 |
-
"with ThreadPoolExecutor(max_workers=5) as executor:\n",
|
422 |
-
" future_to_key = {executor.submit(download_and_save_image_and_caption, idx, row[\"key\"], row[\"url\"], row[\"caption\"]): row[\"key\"] for idx, row in tqdm(enumerate(dataset.select(range(min(max_rows, len(dataset))))), total=min(max_rows, len(dataset)), desc=\"Processing dataset\")}\n",
|
423 |
-
"\n",
|
424 |
-
" for future in tqdm(as_completed(future_to_key), total=len(future_to_key), desc=\"Downloading images\"):\n",
|
425 |
-
" future.result() # Ожидание завершения всех задач\n",
|
426 |
-
"\n",
|
427 |
-
"print(\"Processing complete.\")\n"
|
428 |
-
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "1de5c500-7d65-41a6-9afb-67b454646fb9",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.10"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
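
Note on the TypeError recorded above: this is a known PIL pitfall rather than a dataset problem. A palette PNG can carry an integer "transparency" entry that survives .convert("RGB") in img.info, and PIL's PNG writer then tries to unpack it as an (r, g, b) tuple, raising "cannot unpack non-iterable int object". A minimal sketch of a workaround, assuming the same download loop as in the deleted notebook (the helper name is illustrative, not part of the original file):

import os
from io import BytesIO

import requests
from PIL import Image

def download_and_save_image_safe(key, url, caption, output_path):
    """Like the notebook's helper, but drops the stale transparency entry."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        # An int "transparency" left over from palette mode breaks the RGB PNG writer.
        img.info.pop("transparency", None)
        img.save(os.path.join(output_path, f"{key}.png"), "PNG", optimize=True)
        with open(os.path.join(output_path, f"{key}.txt"), "w", encoding="utf-8") as f:
            f.write(caption)
        return True
    except (requests.RequestException, IOError):
        return False
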
dataset_mjnj.ipynb
DELETED
The diff for this file is too large to render.
See raw diff

dataset_mnist-te.ipynb
DELETED
@@ -1,458 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "48d89cbc-3660-49b9-be37-087c1b05bf78",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a776258b88794927b45bbfc8fc9bdcb0",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Encoding images to latents: 0%| | 0/60000 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "1f4adfe6d6794196af7755992bac151c",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Saving the dataset (0/143 shards): 0%| | 0/60000 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "ok\n"
-     ]
-    }
-   ],
-   "source": [
"# pip install datasets diffusers transformers\n",
|
47 |
-
"# pip install accelerate\n",
|
48 |
-
"# pip install flash-attn --no-build-isolation\n",
|
49 |
-
"# git config --global credential.helper store\n",
|
50 |
-
"# pip install -U \"huggingface_hub[cli]\"\n",
|
51 |
-
"# huggingface-cli login\n",
|
52 |
-
"from datasets import load_dataset, DatasetDict\n",
|
53 |
-
"from diffusers import AutoencoderKL\n",
|
54 |
-
"from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode\n",
|
55 |
-
"from transformers import AutoModel, AutoImageProcessor, AutoTokenizer\n",
|
56 |
-
"import torch\n",
|
57 |
-
"import os\n",
|
58 |
-
"import numpy as np\n",
|
59 |
-
"from PIL import Image\n",
|
60 |
-
"from tqdm import tqdm\n",
|
61 |
-
"import random\n",
|
62 |
-
"from pathlib import Path\n",
|
63 |
-
"\n",
|
64 |
-
"# ---------------- 1️⃣ Настройки ----------------\n",
|
65 |
-
"dtype = torch.float16\n",
|
66 |
-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
67 |
-
"batch_size = 64\n",
|
68 |
-
"img_size = 64\n",
|
69 |
-
"img_share = 0.0\n",
|
70 |
-
"empty_share = 0.01\n",
|
71 |
-
"\n",
|
72 |
-
"# 1. Явно создать все необходимые директории\n",
|
73 |
-
"cache_root = Path(\"cache\")\n",
|
74 |
-
"(cache_root/\"datasets\").mkdir(parents=True, exist_ok=True)\n",
|
75 |
-
"Path(\"datasets\").mkdir(parents=True, exist_ok=True)\n",
|
76 |
-
"\n",
|
77 |
-
"# 2. Установить переменные среды ПЕРЕД импортом библиотек\n",
|
78 |
-
"os.environ[\"HF_HOME\"] = str(cache_root)\n",
|
79 |
-
"os.environ[\"HF_DATASETS_CACHE\"] = str(cache_root/\"datasets\")\n",
|
80 |
-
"\n",
|
81 |
-
"\n",
|
82 |
-
"# ---------------- 2️⃣ Загрузка датасета ----------------\n",
|
83 |
-
"dataset = load_dataset(\"mnist\", split=\"train\", cache_dir=str(cache_root))\n",
|
84 |
-
"\n",
|
85 |
-
"# ---------------- 3️⃣ Загрузка моделей ----------------\n",
|
86 |
-
"vae = AutoencoderKL.from_pretrained(\"AuraDiffusion/16ch-vae\", torch_dtype=dtype).to(device).eval()\n",
|
87 |
-
"model = AutoModel.from_pretrained(\"visheratin/mexma-siglip\", torch_dtype=dtype, trust_remote_code=True, optimized=True).to(device)\n",
|
88 |
-
"processor = AutoImageProcessor.from_pretrained(\"visheratin/mexma-siglip\",use_fast=True)\n",
|
89 |
-
"tokenizer = AutoTokenizer.from_pretrained(\"visheratin/mexma-siglip\")\n",
|
90 |
-
"\n",
|
91 |
-
"# ---------------- 4️⃣ Трансформации ----------------\n",
|
92 |
-
"transform = Compose([\n",
|
93 |
-
" lambda img: img.convert(\"RGB\"), \n",
|
94 |
-
" Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC), # Ресайз\n",
|
95 |
-
" ToTensor(), # В тензор\n",
|
96 |
-
" Normalize(mean=0.5, std=0.5) # [-1, 1]\n",
|
97 |
-
"])\n",
|
98 |
-
"\n",
|
99 |
-
"# ---------------- 5️⃣ Функция обработки изображений ----------------\n",
|
100 |
-
"def encode_images_batch(images):\n",
|
101 |
-
" pixel_values = torch.stack([processor(images=img, return_tensors=\"pt\")[\"pixel_values\"].squeeze(0) for img in images]).to(device, dtype)\n",
|
102 |
-
" \n",
|
103 |
-
" with torch.inference_mode():\n",
|
104 |
-
" image_embeddings = model.vision_model(pixel_values).pooler_output #chang on last_hidden_state # (B, 729, 1152)\n",
|
105 |
-
"\n",
|
106 |
-
" return image_embeddings.unsqueeze(1).cpu().numpy()\n",
|
107 |
-
"\n",
|
108 |
-
"def encode_texts_batch(texts):\n",
|
109 |
-
" try:\n",
|
110 |
-
" with torch.inference_mode():\n",
|
111 |
-
" text_tokenized = tokenizer(texts, return_tensors=\"pt\", padding=\"max_length\").to(device)\n",
|
112 |
-
" features = model.text_model(\n",
|
113 |
-
" input_ids=text_tokenized.input_ids, attention_mask=text_tokenized.attention_mask\n",
|
114 |
-
" ).last_hidden_state\n",
|
115 |
-
" features_proj = model.text_projector(features)\n",
|
116 |
-
" return features_proj.cpu().numpy()\n",
|
117 |
-
" except Exception as e:\n",
|
118 |
-
" print(f\"Ошибка при кодировании текстов: {e}\")\n",
|
119 |
-
" raise\n",
|
120 |
-
"\n",
|
121 |
-
"\n",
|
122 |
-
"# return empty str with prob\n",
|
123 |
-
"def maybe_empty_label(label, prob=0.01):\n",
|
124 |
-
" return \"\" if random.random() < prob else label\n",
|
125 |
-
"\n",
|
126 |
-
"\n",
|
127 |
-
"def encode_to_latents(examples):\n",
|
128 |
-
" pixel_values = torch.stack([transform(img) for img in examples[\"image\"]]).to(device, dtype) # (B, 3, 256, 256)\n",
|
129 |
-
" \n",
|
130 |
-
" # VAE Latents\n",
|
131 |
-
" with torch.no_grad():\n",
|
132 |
-
" posterior = vae.encode(pixel_values.to(device)).latent_dist.mode()\n",
|
133 |
-
" z = (posterior - vae.config.shift_factor) * vae.config.scaling_factor\n",
|
134 |
-
" latents = z.cpu().numpy()\n",
|
135 |
-
"\n",
|
136 |
-
" # Преобразование числовых меток в строковые\n",
|
137 |
-
" text_labels = [str(lbl) for lbl in examples[\"label\"]]\n",
|
138 |
-
" \n",
|
139 |
-
" if random.random() < img_share:\n",
|
140 |
-
" # Image Embeddings\n",
|
141 |
-
" pil_images = [Image.fromarray(((img.cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)) for img in pixel_values]\n",
|
142 |
-
" embeddings = encode_images_batch(pil_images)\n",
|
143 |
-
" #print(\"image_embeddings\",embeddings.shape)\n",
|
144 |
-
" else:\n",
|
145 |
-
" text_labels_with_empty = [maybe_empty_label(lbl, empty_share) for lbl in text_labels]\n",
|
146 |
-
" #print(\"text_labels_with_empty\",text_labels_with_empty)\n",
|
147 |
-
" embeddings = encode_texts_batch(text_labels_with_empty)\n",
|
148 |
-
" #print(\"text_embeddings\",embeddings.shape)\n",
|
149 |
-
"\n",
|
150 |
-
" return {\n",
|
151 |
-
" \"vae\": latents,\n",
|
152 |
-
" \"embeddings\": embeddings,\n",
|
153 |
-
" \"text\": text_labels\n",
|
154 |
-
" }\n",
|
155 |
-
"\n",
|
156 |
-
"# ---------------- 6️⃣ Обработка датасета ----------------\n",
|
157 |
-
"limited_dataset = dataset#.select(range(10))#00000)) # Ограничиваем 1000 семплов\n",
|
158 |
-
"encoded_dataset = limited_dataset.map(\n",
|
159 |
-
" encode_to_latents,\n",
|
160 |
-
" batched=True,\n",
|
161 |
-
" batch_size=batch_size,\n",
|
162 |
-
" remove_columns=[\"image\"],\n",
|
163 |
-
" desc=\"Encoding images to latents\"\n",
|
164 |
-
")\n",
|
165 |
-
"\n",
|
166 |
-
"# ---------------- 7️⃣ Сохранение ----------------\n",
|
167 |
-
"save_path = \"datasets/mnist-te\"\n",
|
168 |
-
"os.makedirs(save_path, exist_ok=True)\n",
|
169 |
-
"#encoded_dataset.to_parquet(os.path.join(save_path, \"dataset.parquet\")) # Оптимальный формат\n",
|
170 |
-
"encoded_dataset.save_to_disk(save_path)\n",
|
171 |
-
"print(\"ok\")\n",
|
172 |
-
"\n"
|
173 |
-
]
|
174 |
-
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "c0507b7e-7dbe-43ca-999a-6dc38aa1fb40",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "06e183871e9f424ea684bacfc6f948a6",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Loading dataset from disk: 0%| | 0/143 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Latent representation shape: (16, 8, 8)\n",
-      "embedding shape: (512, 1152)\n"
-     ]
-    },
-    {
-     "data": {
-      "image/png": "<base64-encoded PNG thumbnail elided>",
-      "text/plain": [
-       "<Figure size 400x400 with 1 Axes>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "🔹 Text label: 4\n",
-      "ok\n"
-     ]
-    }
-   ],
-   "source": [
"from datasets import load_from_disk\n",
|
224 |
-
"import matplotlib.pyplot as plt\n",
|
225 |
-
"import numpy as np\n",
|
226 |
-
"import torch\n",
|
227 |
-
"from PIL import Image \n",
|
228 |
-
"\n",
|
229 |
-
"dtype = torch.float16\n",
|
230 |
-
"\n",
|
231 |
-
"# Загружаем сохраненный датасет\n",
|
232 |
-
"loaded_dataset = load_from_disk(save_path)\n",
|
233 |
-
"\n",
|
234 |
-
"# Проверяем структуру датасета\n",
|
235 |
-
"#print(\"Структура датасета:\", loaded_dataset.features)\n",
|
236 |
-
"\n",
|
237 |
-
"# Выбираем пример для демонстрации\n",
|
238 |
-
"example = loaded_dataset[2]\n",
|
239 |
-
"\n",
|
240 |
-
"# Выводим информацию о примере\n",
|
241 |
-
"print(\"Форма латентного представления:\", np.array(example[\"vae\"]).shape)\n",
|
242 |
-
"print(\"embedding shape:\", np.array(example[\"embeddings\"]).shape)\n",
|
243 |
-
"\n",
|
244 |
-
"# Преобразуем латентное представление в тензор PyTorch\n",
|
245 |
-
"latent_tensor = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n",
|
246 |
-
"\n",
|
247 |
-
"# Декодируем латентное представление обратно в изображение\n",
|
248 |
-
"with torch.no_grad():\n",
|
249 |
-
" #reconstructed_image = vae.decode(latent_tensor).sample # Результат — тензор\n",
|
250 |
-
" latent = (latent_tensor.detach() / vae.config.scaling_factor) + vae.config.shift_factor\n",
|
251 |
-
" reconstructed_image = vae.decode(latent).sample\n",
|
252 |
-
"\n",
|
253 |
-
"# Переносим тензор на CPU и преобразуем в NumPy массив\n",
|
254 |
-
"\n",
|
255 |
-
"reconstructed_image = reconstructed_image.squeeze(0).cpu().numpy() # Удаляем размерность батча\n",
|
256 |
-
"\n",
|
257 |
-
"# Переносим каналы в правильный формат (CHW -> HWC) и нормализуем значения пикселей\n",
|
258 |
-
"reconstructed_image = np.transpose(reconstructed_image, (1, 2, 0))\n",
|
259 |
-
"reconstructed_image = (reconstructed_image + 1) / 2 # Нормализация в диапазон [0, 1]\n",
|
260 |
-
"\n",
|
261 |
-
"# Преобразуем тип данных к float32\n",
|
262 |
-
"reconstructed_image = reconstructed_image.astype(np.float32)\n",
|
263 |
-
"reconstructed_image = np.clip(reconstructed_image, 0.0, 1.0)\n",
|
264 |
-
"\n",
|
265 |
-
"# Отображаем восстановленное изображение\n",
|
266 |
-
"plt.figure(figsize=(4, 4))\n",
|
267 |
-
"plt.imshow(reconstructed_image)\n",
|
268 |
-
"plt.title(f\"Reconstructed Image\")\n",
|
269 |
-
"plt.axis(\"off\")\n",
|
270 |
-
"plt.show()\n",
|
271 |
-
"print(f\"🔹 Текстовое описание: {example['text']}\")\n",
|
272 |
-
"print(\"ok\")"
|
273 |
-
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "3818a9a7-f72c-42ea-9805-43bfd4b214a0",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#!pip install matplotlib"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "f5336198-8925-4e03-ad81-72ee0eb5248a",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "HF_HOME is set to: cache\n",
-      "HF_DATASETS_CACHE is set to: cache/datasets\n"
-     ]
-    }
-   ],
-   "source": [
"# Проверьте переменные окружения\n",
|
302 |
-
"hf_home = os.environ.get(\"HF_HOME\")\n",
|
303 |
-
"hf_datasets_cache = os.environ.get(\"HF_DATASETS_CACHE\")\n",
|
304 |
-
"\n",
|
305 |
-
"if hf_home:\n",
|
306 |
-
" print(f\"HF_HOME is set to: {hf_home}\")\n",
|
307 |
-
"if hf_datasets_cache:\n",
|
308 |
-
" print(f\"HF_DATASETS_CACHE is set to: {hf_datasets_cache}\")"
|
309 |
-
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "94d7f0fe-3b27-4d08-ba35-ba7de096c635",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "cache/datasets\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(hf_datasets_cache)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "8206ec60-e828-4a56-b902-3b44b99b563c",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting transformers\n",
-      "  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)\n",
-      "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.13.1)\n",
-      "Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.28.1)\n",
-      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.3)\n",
-      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n",
-      "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
-      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n",
-      "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
-      "Collecting tokenizers<0.22,>=0.21 (from transformers)\n",
-      "  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n",
-      "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.2)\n",
-      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n",
-      "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2024.2.0)\n",
-      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.12.2)\n",
-      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n",
-      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n",
-      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.3)\n",
-      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.8.30)\n",
-      "Downloading transformers-4.49.0-py3-none-any.whl (10.0 MB)\n",
-      "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.0/10.0 MB\u001b[0m \u001b[31m98.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-      "\u001b[?25hDownloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)\n",
-      "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m94.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-      "\u001b[?25hInstalling collected packages: tokenizers, transformers\n",
-      "Successfully installed tokenizers-0.21.0 transformers-4.49.0\n",
-      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
-      "\u001b[0m"
-     ]
-    }
-   ],
-   "source": [
-    "!pip install -U transformers --break-system-packages"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "48e6e1d3-310c-4f67-a27b-90a54036dc4d",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting flash-attn\n",
-      "  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)\n",
-      "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m94.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-      "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
-      "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from flash-attn) (2.4.1+cu124)\n",
-      "Collecting einops (from flash-attn)\n",
-      "  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)\n",
-      "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.13.1)\n",
-      "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (4.12.2)\n",
-      "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (1.12)\n",
-      "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.2.1)\n",
-      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.1.3)\n",
-      "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2024.2.0)\n",
-      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (9.1.0.70)\n",
-      "Requirement already satisfied: nvidia-cublas-cu12==12.4.2.65 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.2.65)\n",
-      "Requirement already satisfied: nvidia-cufft-cu12==11.2.0.44 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.2.0.44)\n",
-      "Requirement already satisfied: nvidia-curand-cu12==10.3.5.119 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (10.3.5.119)\n",
-      "Requirement already satisfied: nvidia-cusolver-cu12==11.6.0.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.6.0.99)\n",
-      "Requirement already satisfied: nvidia-cusparse-cu12==12.3.0.142 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.3.0.142)\n",
-      "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2.20.5)\n",
-      "Requirement already satisfied: nvidia-nvtx-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.0.0)\n",
-      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch->flash-attn) (2.1.5)\n",
-      "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.11/dist-packages (from sympy->torch->flash-attn) (1.3.0)\n",
-      "Downloading einops-0.8.1-py3-none-any.whl (64 kB)\n",
-      "Building wheels for collected packages: flash-attn\n",
-      "  Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n",
-      "\u001b[?25h  Created wheel for flash-attn: filename=flash_attn-2.7.4.post1-cp311-cp311-linux_x86_64.whl size=187805408 sha256=92cf49e6f66795b6934cec0cba526ed6e45d3313de3f905d45df8773f19092a9\n",
-      "  Stored in directory: /root/.cache/pip/wheels/3d/88/d8/284b89f56af7d5bf366b10d6b8e251ac8a7c7bf3f04203fb4f\n",
-      "Successfully built flash-attn\n",
-      "Installing collected packages: einops, flash-attn\n",
-      "Successfully installed einops-0.8.1 flash-attn-2.7.4.post1\n",
-      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
-      "\u001b[0m"
-     ]
-    }
-   ],
-   "source": [
-    "!pip install flash-attn --no-build-isolation"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "24b9dd4e-789a-4162-b5b2-45c26f9b7504",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.10"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
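
For reference, the latent normalization in the deleted notebook above is symmetric: after encoding, z = (posterior - shift_factor) * scaling_factor, and before decoding the inverse is applied. A minimal round-trip sketch, using only the model and config attributes the notebook itself references (the helper names are illustrative):

import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae", torch_dtype=torch.float16).to(device).eval()

def to_normalized_latent(pixel_values: torch.Tensor) -> torch.Tensor:
    # pixel_values: (B, 3, H, W) in [-1, 1], as produced by the notebook's transform
    with torch.no_grad():
        posterior = vae.encode(pixel_values).latent_dist.mode()
    return (posterior - vae.config.shift_factor) * vae.config.scaling_factor

def from_normalized_latent(z: torch.Tensor) -> torch.Tensor:
    # Invert the normalization before decoding, mirroring the visualization cell
    latent = z / vae.config.scaling_factor + vae.config.shift_factor
    with torch.no_grad():
        return vae.decode(latent).sample  # (B, 3, H, W) in [-1, 1]
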
dataset_mnist.ipynb
DELETED
@@ -1,442 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "48d89cbc-3660-49b9-be37-087c1b05bf78",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "32539c59549c43f8a526d5143656eda5",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Encoding images to latents: 0%| | 0/60000 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "2c0184b4765644e4ab4f1114283e8cdb",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Saving the dataset (0/1 shards): 0%| | 0/60000 [00:00<?, ? examples/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "ok\n"
-     ]
-    }
-   ],
-   "source": [
-    "# pip install datasets diffusers transformers\n",
-    "# pip install accelerate\n",
-    "# pip install flash-attn --no-build-isolation\n",
-    "# git config --global credential.helper store\n",
-    "# pip install -U \"huggingface_hub[cli]\"\n",
-    "# huggingface-cli login\n",
-    "from datasets import load_dataset, DatasetDict\n",
-    "from diffusers import AutoencoderKL\n",
-    "from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode\n",
-    "from transformers import AutoModel, AutoImageProcessor, AutoTokenizer\n",
-    "import torch\n",
-    "import os\n",
-    "import numpy as np\n",
-    "from PIL import Image\n",
-    "from tqdm import tqdm\n",
-    "import random\n",
-    "from pathlib import Path\n",
-    "\n",
-    "# ---------------- 1️⃣ Settings ----------------\n",
-    "dtype = torch.float16\n",
-    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-    "batch_size = 64\n",
-    "img_size = 64\n",
-    "\n",
-    "# 1. Explicitly create all required directories\n",
-    "cache_root = Path(\"cache\")\n",
-    "(cache_root/\"datasets\").mkdir(parents=True, exist_ok=True)\n",
-    "Path(\"datasets\").mkdir(parents=True, exist_ok=True)\n",
-    "\n",
-    "# 2. Set the environment variables BEFORE importing the libraries\n",
-    "os.environ[\"HF_HOME\"] = str(cache_root)\n",
-    "os.environ[\"HF_DATASETS_CACHE\"] = str(cache_root/\"datasets\")\n",
-    "\n",
-    "\n",
-    "# ---------------- 2️⃣ Load the dataset ----------------\n",
-    "dataset = load_dataset(\"mnist\", split=\"train\", cache_dir=str(cache_root))\n",
-    "\n",
-    "# ---------------- 3️⃣ Load the models ----------------\n",
-    "vae = AutoencoderKL.from_pretrained(\"AuraDiffusion/16ch-vae\", torch_dtype=dtype).to(device).eval()\n",
-    "model = AutoModel.from_pretrained(\"visheratin/mexma-siglip\", torch_dtype=dtype, trust_remote_code=True, optimized=True).to(device)\n",
-    "processor = AutoImageProcessor.from_pretrained(\"visheratin/mexma-siglip\", use_fast=True)\n",
-    "tokenizer = AutoTokenizer.from_pretrained(\"visheratin/mexma-siglip\")\n",
-    "\n",
-    "# ---------------- 4️⃣ Transforms ----------------\n",
-    "transform = Compose([\n",
-    "    lambda img: img.convert(\"RGB\"),\n",
-    "    Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),  # Resize\n",
-    "    ToTensor(),  # To tensor\n",
-    "    Normalize(mean=0.5, std=0.5)  # [-1, 1]\n",
-    "])\n",
-    "\n",
-    "# ---------------- 5️⃣ Image processing functions ----------------\n",
-    "def encode_images_batch(images):\n",
-    "    pixel_values = torch.stack([processor(images=img, return_tensors=\"pt\")[\"pixel_values\"].squeeze(0) for img in images]).to(device, dtype)\n",
-    "\n",
-    "    with torch.inference_mode():\n",
-
" image_embeddings = model.vision_model(pixel_values).pooler_output #chang on last_hidden_state # (B, 729, 1152)\n",
|
110 |
-
"\n",
|
111 |
-
" return image_embeddings.unsqueeze(1).cpu().numpy()\n",
|
112 |
-
"\n",
|
113 |
-
"def encode_texts_batch(texts):\n",
|
114 |
-
" with torch.inference_mode():\n",
|
115 |
-
" text_tokenized = tokenizer(texts, return_tensors=\"pt\", padding=True).to(device)\n",
|
116 |
-
" text_embeddings = model.encode_texts(text_tokenized.input_ids,text_tokenized.attention_mask)\n",
|
117 |
-
" return text_embeddings.unsqueeze(1).cpu().numpy()\n",
|
118 |
-
"\n",
|
119 |
-
"\n",
|
120 |
-
"# return empty str with prob\n",
|
121 |
-
"def maybe_empty_label(label, prob=0.01):\n",
|
122 |
-
" return \"\" if random.random() < prob else label\n",
|
123 |
-
"\n",
|
124 |
-
"\n",
|
125 |
-
"def encode_to_latents(examples):\n",
|
126 |
-
" pixel_values = torch.stack([transform(img) for img in examples[\"image\"]]).to(device, dtype) # (B, 3, 256, 256)\n",
|
127 |
-
" \n",
|
128 |
-
" # VAE Latents\n",
|
129 |
-
" with torch.no_grad():\n",
|
130 |
-
" posterior = vae.encode(pixel_values.to(device)).latent_dist.mode()\n",
|
131 |
-
" z = (posterior - vae.config.shift_factor) * vae.config.scaling_factor\n",
|
132 |
-
" latents = z.cpu().numpy()\n",
|
133 |
-
"\n",
|
134 |
-
" # Преобразование числовых меток в строковые\n",
|
135 |
-
" text_labels = [str(lbl) for lbl in examples[\"label\"]]\n",
|
136 |
-
" \n",
|
137 |
-
" if random.random() < 0.5:\n",
|
138 |
-
" # Image Embeddings\n",
|
139 |
-
" pil_images = [Image.fromarray(((img.cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)) for img in pixel_values]\n",
|
140 |
-
" embeddings = encode_images_batch(pil_images)\n",
|
141 |
-
" #print(\"image_embeddings\",embeddings.shape)\n",
|
142 |
-
" else:\n",
|
143 |
-
" text_labels_with_empty = [maybe_empty_label(lbl) for lbl in text_labels]\n",
|
144 |
-
" #print(\"text_labels_with_empty\",text_labels_with_empty)\n",
|
145 |
-
" embeddings = encode_texts_batch(text_labels_with_empty)\n",
|
146 |
-
" #print(\"text_embeddings\",embeddings.shape)\n",
|
147 |
-
"\n",
|
148 |
-
" return {\n",
|
149 |
-
" \"vae\": latents,\n",
|
150 |
-
" \"embeddings\": embeddings,\n",
|
151 |
-
" \"text\": text_labels\n",
|
152 |
-
" }\n",
|
153 |
-
"\n",
|
154 |
-
"# ---------------- 6️⃣ Обработка датасета ----------------\n",
|
155 |
-
"limited_dataset = dataset#.select(range(10))#00000)) # Ограничиваем 1000 семплов\n",
|
156 |
-
"encoded_dataset = limited_dataset.map(\n",
|
157 |
-
" encode_to_latents,\n",
|
158 |
-
" batched=True,\n",
|
159 |
-
" batch_size=batch_size,\n",
|
160 |
-
" remove_columns=[\"image\"],\n",
|
161 |
-
" desc=\"Encoding images to latents\"\n",
|
162 |
-
")\n",
|
163 |
-
"\n",
|
164 |
-
"# ---------------- 7️⃣ Сохранение ----------------\n",
|
165 |
-
"save_path = \"datasets/mnist\"\n",
|
166 |
-
"os.makedirs(save_path, exist_ok=True)\n",
|
167 |
-
"#encoded_dataset.to_parquet(os.path.join(save_path, \"dataset.parquet\")) # Оптимальный формат\n",
|
168 |
-
"encoded_dataset.save_to_disk(save_path)\n",
|
169 |
-
"print(\"ok\")\n",
|
170 |
-
"\n"
|
171 |
-
]
|
172 |
-
},
|
173 |
-
{
|
174 |
-
"cell_type": "code",
|
175 |
-
"execution_count": 2,
|
176 |
-
"id": "c0507b7e-7dbe-43ca-999a-6dc38aa1fb40",
|
177 |
-
"metadata": {},
|
178 |
-
"outputs": [
|
179 |
-
{
|
180 |
-
"name": "stdout",
|
181 |
-
"output_type": "stream",
|
182 |
-
"text": [
|
183 |
-
"Форма латентного представления: (16, 8, 8)\n",
|
184 |
-
"embedding shape: (1, 1152)\n"
|
185 |
-
]
|
186 |
-
},
|
187 |
-
{
|
-     "data": {
-      "image/png": "<base64 PNG omitted: matplotlib preview of the reconstructed MNIST digit>",
-      "text/plain": [
-       "<Figure size 400x400 with 1 Axes>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "🔹 Text label: 4\n",
-      "ok\n"
-     ]
-    }
-   ],
-   "source": [
-    "from datasets import load_from_disk\n",
-    "import matplotlib.pyplot as plt\n",
-    "import numpy as np\n",
-    "import torch\n",
-    "from PIL import Image\n",
-    "\n",
-    "dtype = torch.float16\n",
-    "\n",
-    "# Load the saved dataset\n",
-    "loaded_dataset = load_from_disk(save_path)\n",
-    "\n",
-    "# Inspect the dataset structure\n",
-    "#print(\"Dataset structure:\", loaded_dataset.features)\n",
-    "\n",
-    "# Pick an example for the demo\n",
-    "example = loaded_dataset[2]\n",
-    "\n",
-    "# Print information about the example\n",
-    "print(\"Latent shape:\", np.array(example[\"vae\"]).shape)\n",
-    "print(\"embedding shape:\", np.array(example[\"embeddings\"]).shape)\n",
-    "\n",
-    "# Convert the latent into a PyTorch tensor\n",
-    "latent_tensor = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n",
-    "\n",
-    "# Decode the latent back into an image\n",
-    "with torch.no_grad():\n",
-    "    #reconstructed_image = vae.decode(latent_tensor).sample  # the result is a tensor\n",
-    "    latent = (latent_tensor.detach() / vae.config.scaling_factor) + vae.config.shift_factor\n",
-    "    reconstructed_image = vae.decode(latent).sample\n",
-    "\n",
-    "# Move the tensor to the CPU and convert it to a NumPy array\n",
-    "\n",
-    "reconstructed_image = reconstructed_image.squeeze(0).cpu().numpy()  # drop the batch dimension\n",
-    "\n",
-    "# Reorder the channels (CHW -> HWC) and normalize the pixel values\n",
-    "reconstructed_image = np.transpose(reconstructed_image, (1, 2, 0))\n",
-    "reconstructed_image = (reconstructed_image + 1) / 2  # normalize to [0, 1]\n",
-    "\n",
-    "# Cast to float32\n",
-    "reconstructed_image = reconstructed_image.astype(np.float32)\n",
-    "reconstructed_image = np.clip(reconstructed_image, 0.0, 1.0)\n",
-    "\n",
-    "# Display the reconstructed image\n",
-    "plt.figure(figsize=(4, 4))\n",
-    "plt.imshow(reconstructed_image)\n",
-    "plt.title(f\"Reconstructed Image\")\n",
-    "plt.axis(\"off\")\n",
-    "plt.show()\n",
-    "print(f\"🔹 Text label: {example['text']}\")\n",
-    "print(\"ok\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "3818a9a7-f72c-42ea-9805-43bfd4b214a0",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#!pip install matplotlib"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "f5336198-8925-4e03-ad81-72ee0eb5248a",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "HF_HOME is set to: cache\n",
-      "HF_DATASETS_CACHE is set to: cache/datasets\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Check the environment variables\n",
-    "hf_home = os.environ.get(\"HF_HOME\")\n",
-    "hf_datasets_cache = os.environ.get(\"HF_DATASETS_CACHE\")\n",
-    "\n",
-    "if hf_home:\n",
-    "    print(f\"HF_HOME is set to: {hf_home}\")\n",
-    "if hf_datasets_cache:\n",
-    "    print(f\"HF_DATASETS_CACHE is set to: {hf_datasets_cache}\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "94d7f0fe-3b27-4d08-ba35-ba7de096c635",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "cache/datasets\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(hf_datasets_cache)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "8206ec60-e828-4a56-b902-3b44b99b563c",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting transformers\n",
-      "  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)\n",
-      "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.13.1)\n",
-      "Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.28.1)\n",
-      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.3)\n",
-      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n",
-      "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
-      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n",
-      "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
-      "Collecting tokenizers<0.22,>=0.21 (from transformers)\n",
-      "  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n",
-      "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.2)\n",
-      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n",
-      "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2024.2.0)\n",
-      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.12.2)\n",
-      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n",
-      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n",
-      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.3)\n",
-      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.8.30)\n",
-      "Downloading transformers-4.49.0-py3-none-any.whl (10.0 MB)\n",
-      "   10.0/10.0 MB 98.5 MB/s eta 0:00:00\n",
-      "Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)\n",
-      "   3.0/3.0 MB 94.5 MB/s eta 0:00:00\n",
-      "Installing collected packages: tokenizers, transformers\n",
-      "Successfully installed tokenizers-0.21.0 transformers-4.49.0\n",
-      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n"
-     ]
-    }
-   ],
-   "source": [
-    "!pip install -U transformers --break-system-packages"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "48e6e1d3-310c-4f67-a27b-90a54036dc4d",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting flash-attn\n",
-      "  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)\n",
-      "   6.0/6.0 MB 94.3 MB/s eta 0:00:00\n",
-      "  Preparing metadata (setup.py) ... done\n",
-      "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from flash-attn) (2.4.1+cu124)\n",
-      "Collecting einops (from flash-attn)\n",
-      "  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)\n",
-      "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.13.1)\n",
-      "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (4.12.2)\n",
-      "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (1.12)\n",
-      "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.2.1)\n",
-      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.1.3)\n",
-      "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2024.2.0)\n",
-      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (9.1.0.70)\n",
-      "Requirement already satisfied: nvidia-cublas-cu12==12.4.2.65 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.2.65)\n",
-      "Requirement already satisfied: nvidia-cufft-cu12==11.2.0.44 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.2.0.44)\n",
-      "Requirement already satisfied: nvidia-curand-cu12==10.3.5.119 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (10.3.5.119)\n",
-      "Requirement already satisfied: nvidia-cusolver-cu12==11.6.0.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.6.0.99)\n",
-      "Requirement already satisfied: nvidia-cusparse-cu12==12.3.0.142 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.3.0.142)\n",
-      "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2.20.5)\n",
-      "Requirement already satisfied: nvidia-nvtx-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n",
-      "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.0.0)\n",
-      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch->flash-attn) (2.1.5)\n",
-      "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.11/dist-packages (from sympy->torch->flash-attn) (1.3.0)\n",
-      "Downloading einops-0.8.1-py3-none-any.whl (64 kB)\n",
-      "Building wheels for collected packages: flash-attn\n",
-      "  Building wheel for flash-attn (setup.py) ... done\n",
-      "  Created wheel for flash-attn: filename=flash_attn-2.7.4.post1-cp311-cp311-linux_x86_64.whl size=187805408 sha256=92cf49e6f66795b6934cec0cba526ed6e45d3313de3f905d45df8773f19092a9\n",
-      "  Stored in directory: /root/.cache/pip/wheels/3d/88/d8/284b89f56af7d5bf366b10d6b8e251ac8a7c7bf3f04203fb4f\n",
-      "Successfully built flash-attn\n",
-      "Installing collected packages: einops, flash-attn\n",
-      "Successfully installed einops-0.8.1 flash-attn-2.7.4.post1\n",
-      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n"
-     ]
-    }
-   ],
-   "source": [
-    "!pip install flash-attn --no-build-isolation"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "24b9dd4e-789a-4162-b5b2-45c26f9b7504",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.10"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
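
Note: the deleted notebook above and pipeline_sdxs.py later in this commit share one latent-normalization convention: encode as z = (posterior - shift_factor) * scaling_factor and decode with the exact inverse. A minimal round-trip sketch, assuming the AuraDiffusion/16ch-vae checkpoint named in the notebook (the dummy input tensor is illustrative):

```python
import torch
from diffusers import AutoencoderKL

# 16-channel VAE whose config defines both shift_factor and scaling_factor
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae", torch_dtype=torch.float16).to("cuda").eval()

image = torch.rand(1, 3, 64, 64, device="cuda", dtype=torch.float16) * 2 - 1  # dummy input in [-1, 1]

with torch.no_grad():
    posterior = vae.encode(image).latent_dist.mode()
    z = (posterior - vae.config.shift_factor) * vae.config.scaling_factor  # normalized latent, as stored in the dataset
    latent = (z / vae.config.scaling_factor) + vae.config.shift_factor     # exact inverse, as in decode_latents
    recon = vae.decode(latent).sample                                      # back to pixel space in [-1, 1]
```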
inference.ipynb
DELETED
The diff for this file is too large to render. See the raw diff.
model_index.json
ADDED
@@ -0,0 +1,28 @@
+{
+  "_class_name": "SdxsPipeline",
+  "_diffusers_version": "0.33.0.dev0",
+  "scheduler": [
+    "diffusers",
+    "DDPMScheduler"
+  ],
+  "text_encoder": [
+    "transformers",
+    "XLMRobertaModel"
+  ],
+  "text_projector": [
+    "torch.nn.modules.linear",
+    "Linear"
+  ],
+  "tokenizer": [
+    "transformers",
+    "XLMRobertaTokenizerFast"
+  ],
+  "unet": [
+    "diffusers",
+    "UNet2DConditionModel"
+  ],
+  "vae": [
+    "diffusers",
+    "AutoencoderKL"
+  ]
+}
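
For orientation: model_index.json maps each subfolder of the repo to the (library, class) pair that loads it, and names SdxsPipeline as the pipeline class. A minimal loading sketch under the assumption that you work from a local clone of this repo, so that pipeline_sdxs.py (added below) is importable; the path "." is a placeholder:

```python
# Sketch only, not the repo's documented entry point.
# Assumes the current directory is a local clone containing model_index.json
# and pipeline_sdxs.py; SdxsPipeline subclasses DiffusionPipeline, so
# from_pretrained resolves the subfolders listed in model_index.json.
import torch
from pipeline_sdxs import SdxsPipeline

pipe = SdxsPipeline.from_pretrained(".", torch_dtype=torch.float16)
```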
pipeline_sdxs.py
ADDED
@@ -0,0 +1,295 @@
+from diffusers import DiffusionPipeline
+import torch
+import torch.nn as nn
+import os
+from diffusers.utils import BaseOutput
+from dataclasses import dataclass
+from typing import List, Union, Optional
+from PIL import Image
+import numpy as np
+import json
+from safetensors.torch import load_file
+from tqdm import tqdm
+
+@dataclass
+class SdxsPipelineOutput(BaseOutput):
+    images: Union[List[Image.Image], np.ndarray]
+
+class SdxsPipeline(DiffusionPipeline):
+    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
+        super().__init__()
+
+        # Register components
+        self.register_modules(
+            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
+            unet=unet, scheduler=scheduler
+        )
+
+        # Get the model path, which is either provided directly or from the internal dict
+        model_path = None
+        if hasattr(self, '_internal_dict') and self._internal_dict.get('_name_or_path'):
+            model_path = self._internal_dict.get('_name_or_path')
+
+        # Get device and dtype from the existing components
+        device = "cuda"
+        dtype = torch.float16
+
+        # Always load text_projector, regardless of whether one was provided
+        projector_path = None
+
+        # Try to find the projector path
+        if model_path and os.path.exists(f"{model_path}/text_projector"):
+            projector_path = f"{model_path}/text_projector"
+        elif os.path.exists("./text_projector"):
+            projector_path = "./text_projector"
+
+        if projector_path:
+            # Create and load the projector
+            try:
+                with open(f"{projector_path}/config.json", "r") as f:
+                    projector_config = json.load(f)
+
+                # Create a Linear layer with bias=False
+                self.text_projector = nn.Linear(
+                    in_features=projector_config["in_features"],
+                    out_features=projector_config["out_features"],
+                    bias=False
+                )
+
+                # Load the state dict using safetensors
+                self.text_projector.load_state_dict(load_file(f"{projector_path}/model.safetensors"))
+                self.text_projector.to(device=device, dtype=dtype)
+                print(f"Successfully loaded text_projector from {projector_path}", device, dtype)
+            except Exception as e:
+                print(f"Error loading text_projector: {e}")
+
+        self.vae_scale_factor = 8
+
+    def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
+        """Encode text prompts into embeddings.
+
+        Returns:
+        - text_embeddings: embedding tensor [batch_size, 1, dim], or [2*batch_size, 1, dim] with guidance
+        """
+        if prompt is None and negative_prompt is None:
+            raise ValueError("At least one of prompt or negative_prompt is required")
+
+        # Set device and dtype
+        device = device or self.device
+        dtype = dtype or next(self.unet.parameters()).dtype
+
+        with torch.no_grad():
+            # Process the positive prompt
+            if prompt is not None:
+                if isinstance(prompt, str):
+                    prompt = [prompt]
+
+                text_inputs = self.tokenizer(
+                    prompt, return_tensors="pt", padding="max_length",
+                    max_length=512, truncation=True
+                ).to(device)
+
+                # Get the embeddings
+                outputs = self.text_encoder(text_inputs.input_ids, text_inputs.attention_mask)
+                last_hidden_state = outputs.last_hidden_state.to(device, dtype=dtype)
+                pos_embeddings = self.text_projector(last_hidden_state[:, 0])
+
+                # Add a dimension for batch processing
+                if pos_embeddings.ndim == 2:
+                    pos_embeddings = pos_embeddings.unsqueeze(1)
+            else:
+                # Create empty embeddings when there is no positive prompt
+                # (useful for some unconditional-generation scenarios)
+                batch_size = len(negative_prompt) if isinstance(negative_prompt, list) else 1
+                pos_embeddings = torch.zeros(
+                    batch_size, 1, self.unet.config.cross_attention_dim,
+                    device=device, dtype=dtype
+                )
+
+            # Process the negative prompt
+            if negative_prompt is not None:
+                if isinstance(negative_prompt, str):
+                    negative_prompt = [negative_prompt]
+
+                # Make sure the negative and positive prompt sizes match
+                if prompt is not None and len(negative_prompt) != len(prompt):
+                    neg_batch_size = len(prompt)
+                    if len(negative_prompt) == 1:
+                        negative_prompt = negative_prompt * neg_batch_size
+                    else:
+                        negative_prompt = negative_prompt[:neg_batch_size]
+
+                neg_inputs = self.tokenizer(
+                    negative_prompt, return_tensors="pt", padding="max_length",
+                    max_length=512, truncation=True
+                ).to(device)
+
+                neg_outputs = self.text_encoder(neg_inputs.input_ids, neg_inputs.attention_mask)
+                neg_last_hidden_state = neg_outputs.last_hidden_state.to(device, dtype=dtype)
+                neg_embeddings = self.text_projector(neg_last_hidden_state[:, 0])
+
+                if neg_embeddings.ndim == 2:
+                    neg_embeddings = neg_embeddings.unsqueeze(1)
+
+                # Concatenate for classifier-free guidance
+                text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
+            else:
+                # Without a negative prompt, use zero embeddings
+                batch_size = pos_embeddings.shape[0]
+                neg_embeddings = torch.zeros_like(pos_embeddings)
+                text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
+
+        return text_embeddings.to(device=device, dtype=dtype)
+
+    @torch.no_grad()
+    def generate_latents(
+        self,
+        text_embeddings,
+        height: int = 576,
+        width: int = 576,
+        num_inference_steps: int = 40,
+        guidance_scale: float = 5.0,
+        latent_channels: int = 16,
+        batch_size: int = 1,
+        generator=None,
+    ):
+        """Generate latents from prompt embeddings."""
+        device = self.device
+        dtype = next(self.unet.parameters()).dtype
+
+        # Check the embedding size
+        do_classifier_free_guidance = guidance_scale > 0
+        embedding_dim = text_embeddings.shape[0] // 2 if do_classifier_free_guidance else text_embeddings.shape[0]
+
+        if batch_size > embedding_dim:
+            # Repeat the embeddings up to the requested batch size
+            if do_classifier_free_guidance:
+                neg_embeds, pos_embeds = text_embeddings.chunk(2)
+                neg_embeds = neg_embeds.repeat(batch_size // embedding_dim, 1, 1)
+                pos_embeds = pos_embeds.repeat(batch_size // embedding_dim, 1, 1)
+                text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
+            else:
+                text_embeddings = text_embeddings.repeat(batch_size // embedding_dim, 1, 1)
+
+        # Set the timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        # Initialize the latents with the given seed
+        latent_shape = (
+            batch_size,
+            latent_channels,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor
+        )
+        latents = torch.randn(
+            latent_shape,
+            device=device,
+            dtype=dtype,
+            generator=generator
+        )
+
+        # Diffusion process
+        for t in tqdm(self.scheduler.timesteps, desc="Generating"):
+            # Prepare the inputs
+            if do_classifier_free_guidance:
+                latent_input = torch.cat([latents] * 2)
+            else:
+                latent_input = latents
+
+            latent_input = self.scheduler.scale_model_input(latent_input, t)
+
+            # Predict the noise
+            noise_pred = self.unet(latent_input, t, text_embeddings).sample
+
+            # Apply guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )
+
+            # Update the latents
+            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+        return latents
+
+    def decode_latents(self, latents, output_type="pil"):
+        """Decode latents into images."""
+        # Un-normalize the latents
+        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+        # Decode
+        with torch.no_grad():
+            images = self.vae.decode(latents).sample
+
+        # Normalize the images
+        images = (images / 2 + 0.5).clamp(0, 1)
+
+        # Convert to the requested format
+        if output_type == "pil":
+            images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+            images = (images * 255).round().astype("uint8")
+            return [Image.fromarray(image) for image in images]
+        else:
+            return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 576,
+        width: int = 576,
+        num_inference_steps: int = 40,
+        guidance_scale: float = 5.0,
+        latent_channels: int = 16,
+        output_type: str = "pil",
+        return_dict: bool = True,
+        batch_size: int = 1,
+        seed: Optional[int] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        text_embeddings: Optional[torch.FloatTensor] = None,
+    ):
+        """Generate images from text prompts or embeddings."""
+        device = self.device
+
+        # Set up a seeded generator for reproducibility
+        generator = None
+        if seed is not None:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        # Compute the embeddings if they are not provided
+        if text_embeddings is None:
+            if prompt is None and negative_prompt is None:
+                raise ValueError("One of prompt, negative_prompt or text_embeddings must be provided")
+
+            # Compute the embeddings
+            text_embeddings = self.encode_prompt(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                device=device
+            )
+        else:
+            # Make sure the embeddings are on the right device
+            text_embeddings = text_embeddings.to(device)
+
+        # Generate the latents
+        latents = self.generate_latents(
+            text_embeddings=text_embeddings,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            generator=generator
+        )
+
+        # Decode the latents into images
+        images = self.decode_latents(latents, output_type=output_type)
+
+        if not return_dict:
+            return images
+
+        return SdxsPipelineOutput(images=images)
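
A hedged usage sketch for the pipeline defined above; the prompt, seed, and output filename are illustrative, and the argument names follow the __call__ signature in this file. Height and width should be divisible by vae_scale_factor (8):

```python
import torch
from pipeline_sdxs import SdxsPipeline  # from a local clone of this repo

pipe = SdxsPipeline.from_pretrained(".", torch_dtype=torch.float16)  # "." is a placeholder path

result = pipe(
    prompt="a cat in a spacesuit",   # illustrative prompt
    height=576, width=576,           # 576/8 = 72x72 latent
    num_inference_steps=40,
    guidance_scale=5.0,              # > 0 enables classifier-free guidance
    seed=42,                         # reproducible via torch.Generator
)
result.images[0].save("sample.jpg")
```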
samples/sdxs_320x576_0.jpg
CHANGED
samples/sdxs_384x576_0.jpg
CHANGED
samples/sdxs_448x576_0.jpg
CHANGED
samples/sdxs_512x576_0.jpg
CHANGED
samples/sdxs_576x320_0.jpg
CHANGED
samples/sdxs_576x384_0.jpg
CHANGED
samples/sdxs_576x448_0.jpg
CHANGED
samples/sdxs_576x512_0.jpg
CHANGED
samples/sdxs_576x576_0.jpg
CHANGED
(Each sample is a Git LFS-tracked JPEG; the before/after image previews are omitted here.)
scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,19 @@
+{
+  "_class_name": "DDPMScheduler",
+  "_diffusers_version": "0.33.0.dev0",
+  "beta_end": 0.02,
+  "beta_schedule": "linear",
+  "beta_start": 0.0001,
+  "clip_sample": true,
+  "clip_sample_range": 1.0,
+  "dynamic_thresholding_ratio": 0.995,
+  "num_train_timesteps": 1000,
+  "prediction_type": "v_prediction",
+  "rescale_betas_zero_snr": true,
+  "sample_max_value": 1.0,
+  "steps_offset": 1,
+  "thresholding": false,
+  "timestep_spacing": "leading",
+  "trained_betas": null,
+  "variance_type": "fixed_small"
+}
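
Two settings in this config are worth flagging: prediction_type is "v_prediction" (the UNet predicts v, not epsilon) and rescale_betas_zero_snr is enabled (zero terminal SNR). A small sketch that reconstructs the scheduler standalone from this folder and asserts those settings; "." is a placeholder for a local clone of the repo:

```python
from diffusers import DDPMScheduler

# Load the scheduler exactly as the pipeline would, from the scheduler/ subfolder.
scheduler = DDPMScheduler.from_pretrained(".", subfolder="scheduler")

assert scheduler.config.prediction_type == "v_prediction"  # v-objective, not epsilon
assert scheduler.config.rescale_betas_zero_snr             # zero-terminal-SNR beta rescaling
```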
sdxs/diffusion_pytorch_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ddd8d0c467b18d515d1505d21905c36081896edbeb0813f8ffc92dd62d2405b7
 size 4529095968
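
In a Git LFS pointer, the oid is simply the SHA-256 of the tracked file's content, so the pointer update above (same size, new oid) can be verified against a local copy. A small sketch, to be run from a checkout with the LFS objects pulled:

```python
# Verify a downloaded weights file against the Git LFS pointer oid from the diff above.
import hashlib

def sha256_of(path: str, chunk: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):  # stream in 1 MiB chunks to bound memory
            h.update(block)
    return h.hexdigest()

print(sha256_of("sdxs/diffusion_pytorch_model.safetensors"))
# expected: ddd8d0c467b18d515d1505d21905c36081896edbeb0813f8ffc92dd62d2405b7
```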
sdxs_create.ipynb
DELETED
@@ -1,153 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "84d99aec-8010-4ad8-be83-a2db26d78b7a",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "test unet\n",
-      "Parameter count: 2264320720\n",
-      "Output shape: torch.Size([1, 16, 60, 48])\n"
-     ]
-    }
-   ],
-   "source": [
-    "\n",
-    "config_sdxs = {\n",
-    "    # === Core sizes and channels ===\n",
-    "    \"sample_size\": 64,  # latent-space size (must match the VAE)\n",
-    "    \"in_channels\": 16,  # number of input channels (compatible with the 16-channel VAE)\n",
-    "    \"out_channels\": 16,  # number of output channels (symmetric with in_channels)\n",
-    "    \"center_input_sample\": False,  # disable input centering (standard for diffusion models)\n",
-    "    \"flip_sin_to_cos\": True,  # automatic sin/cos flip in the time embeddings (for stability)\n",
-    "    \"freq_shift\": 0,  # frequency shift (0 is the standard value for frequency embeddings)\n",
-    "\n",
-    "    # === Block architecture ===\n",
-    "    \"down_block_types\": [  # encoder block types (processing hierarchy):\n",
-    "        \"CrossAttnDownBlock2D\",\n",
-    "        \"CrossAttnDownBlock2D\",  # high level - integrating text conditions via cross-attention\n",
-    "        \"CrossAttnDownBlock2D\",  # mid-high level - continued context integration\n",
-    "        \"CrossAttnDownBlock2D\",  # middle level - focus on semantic features\n",
-    "    ],\n",
-    "    \"mid_block_type\": \"UNetMidBlock2DCrossAttn\",  # central block with cross-attention (the network bottleneck)\n",
-    "    \"up_block_types\": [  # decoder block types (image reconstruction):\n",
-    "        \"CrossAttnUpBlock2D\",  # middle level - integrating text conditions\n",
-    "        \"CrossAttnUpBlock2D\",  # mid-high level - refining details\n",
-    "        \"CrossAttnUpBlock2D\",  # high level - final detailing guided by the text\n",
-    "        \"CrossAttnUpBlock2D\"\n",
-    "    ],\n",
-    "    \"only_cross_attention\": False,  # use both cross-attention and self-attention\n",
-    "\n",
-    "    # === Channel configuration ===\n",
-    "    \"block_out_channels\": [384, 576, 768, 960],  #[256, 512, 768, 1024],  # channels per level (balanced parameter budget)\n",
-    "    \"layers_per_block\": [2, 2, 2, 2],  # layers per block\n",
-    "    \"downsample_padding\": 1,  # padding when downsampling (prevents artifacts)\n",
-    "    \"mid_block_scale_factor\": 1.0,  # signal gain in the central block (improves detail)\n",
-    "    \"dropout\": 0.1,  # regularization against overfitting\n",
-    "    \"act_fn\": \"silu\",  # activation function (balance of speed and quality)\n",
-    "\n",
-    "    # === Normalization ===\n",
-    "    \"norm_num_groups\": 16,  # number of groups for GroupNorm (good for stability)\n",
-    "    \"norm_eps\": 1e-05,  # normalization epsilon (standard value)\n",
-    "\n",
-    "    # === Cross-attention ===\n",
-    "    \"cross_attention_dim\": 1152,  # text-embedding dimension (must match the text_encoder)\n",
-    "    \"transformer_layers_per_block\": [4, 6, 8, 10],  # number of transformer layers (increasing with depth)\n",
-    "    \"attention_head_dim\": 48,  # attention head dimension (1152/48 = 24 heads)\n",
-    "    \"dual_cross_attention\": False,  # disable dual attention (simpler architecture)\n",
-    "    \"use_linear_projection\": True,  # changed to True for better memory layout\n",
-    "\n",
-    "    # === ResNet blocks ===\n",
-    "    \"resnet_time_scale_shift\": \"default\",  # how the time embeddings are integrated\n",
-    "    \"resnet_skip_time_act\": False,  # disable activation in the skip connections\n",
-    "    \"resnet_out_scale_factor\": 1.0,  # ResNet output scaling factor\n",
-    "\n",
-    "    # === Time embeddings ===\n",
-    "    \"time_embedding_type\": \"positional\",  # time-embedding type (standard approach)\n",
-    "\n",
-    "    # === Convolutions ===\n",
-    "    \"conv_in_kernel\": 3,  # input convolution kernel (balance of receptive field and parameters)\n",
-    "    \"conv_out_kernel\": 3,  # output convolution kernel (symmetric with the input)\n",
-    "\n",
-    "    # Other parameters\n",
-    "    \"attention_type\": \"default\",\n",
-    "    \"class_embeddings_concat\": False,\n",
-    "    \"mid_block_only_cross_attention\": None,\n",
-    "    \"cross_attention_norm\": None,\n",
-    "    \"addition_embed_type_num_heads\": 48  # reduced from 64\n",
-    "    # === Testing ===\n",
-    "    # Check that channels are multiples of 32 (required by some optimizations)\n",
-    "    # The output shape matches expectations: [1, 16, 60, 48]\n",
-    "}\n",
-    "# Key features:\n",
-    "\n",
-    "# Hierarchical architecture with an emphasis on cross-attention in the upper layers\n",
-    "# Progressively more transformer layers with depth ([4, 6, 8, 10])\n",
-    "# Compatibility with the 16-channel VAE via symmetric in/out channels\n",
-    "\n",
-    "if 1:\n",
-    "    checkpoint_path = \"test\"  #\"sdxs\"\n",
-    "    import torch\n",
-    "    from diffusers import UNet2DConditionModel\n",
-    "    print(\"test unet\")\n",
-    "    new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
-    "\n",
-    "    assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Channels must be multiples of 32\"\n",
-    "    num_params = sum(p.numel() for p in new_unet.parameters())\n",
-    "    print(f\"Parameter count: {num_params}\")\n",
-    "\n",
-    "    # Generate a test latent\n",
-    "    test_latent = torch.randn(1, 16, 60, 48).to(\"cuda\", dtype=torch.float16)  # 60x48 latent = 480x384 px\n",
-    "    timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
-    "    encoder_hidden_states = torch.randn(1, 77, 1152).to(\"cuda\", dtype=torch.float16)\n",
-    "\n",
-    "    with torch.no_grad():\n",
-    "        output = new_unet(\n",
-    "            test_latent,\n",
-    "            timesteps,\n",
-    "            encoder_hidden_states\n",
-    "        ).sample\n",
-    "\n",
-    "    print(f\"Output shape: {output.shape}\")\n",
-    "    new_unet.save_pretrained(checkpoint_path)\n",
-    "    del new_unet\n",
-    "    torch.cuda.empty_cache()\n",
-    "    # Parameter count: 2264320720"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "64232924-6bdd-4373-a9cd-058b696d6394",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.10"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/captions_moondream2.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d07dc72b1cf8b504632c01ce54688f431978fdf4772ac1ab568c1a6ebebb790
|
3 |
+
size 4897
|
captions_qwen2-vl-7b.py → src/captions_qwen2-vl-7b.py
RENAMED
File without changes
|
src/captions_wd.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b46993285995bdc52e69d82e900f239ebafd9dc924be046e80372639c8796ed8
|
3 |
+
size 29850
|
dataset_combine.py → src/dataset_combine.py
RENAMED
File without changes
|
src/dataset_fromzip.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d12f225412aeed4f5ff6cd2dc23db6bcdfd944ff00fe6f0d9c2e8fe0ec426ee
|
3 |
+
size 6167
|
src/dataset_imagenet.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:92ea6bc9e4033a778b9e36defff6adf481baae0d8b0a3fa537313df6fb5b4472
|
3 |
+
size 318505
|
src/dataset_laion_coco.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:31395e7f40ef370971b523fb9d9ab56b404ca8cc1e8e932cc602beaf72140411
|
3 |
+
size 25403
|
src/dataset_mjnj.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d09861f88b7ce2eec45aec8c10879baf5043253c2cfe7444cc3b6a63019d3b30
|
3 |
+
size 2631555
|
src/dataset_mnist-te.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac1571369244de9ff15d4b1785e962e06521630fa1be32f0471175e42ef00630
|
3 |
+
size 34388
|
src/dataset_mnist.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c644e111748cb374d2fb9fec28ef99a5ed616898100e689cd02c6ba80b3431a7
|
3 |
+
size 33829
|
src/inference.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:16438cb1f590bf011c890a0c5e6fb0ffd70ae313107e485a7aef51d4e39a6d08
|
3 |
+
size 6997529
|
src/sdxs_create.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a1a303d31ce80486d5a00149df2551340dfbf21e588b81bd3c2f5c909a0fb017
|
3 |
+
size 9510
|
text_encoder/config.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/home/recoilme/sdxs576/text_encoder",
|
3 |
+
"architectures": [
|
4 |
+
"XLMRobertaModel"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"bos_token_id": 0,
|
8 |
+
"classifier_dropout": null,
|
9 |
+
"eos_token_id": 2,
|
10 |
+
"hidden_act": "gelu",
|
11 |
+
"hidden_dropout_prob": 0.1,
|
12 |
+
"hidden_size": 1024,
|
13 |
+
"initializer_range": 0.02,
|
14 |
+
"intermediate_size": 4096,
|
15 |
+
"layer_norm_eps": 1e-05,
|
16 |
+
"max_position_embeddings": 514,
|
17 |
+
"model_type": "xlm-roberta",
|
18 |
+
"num_attention_heads": 16,
|
19 |
+
"num_hidden_layers": 24,
|
20 |
+
"output_past": true,
|
21 |
+
"pad_token_id": 1,
|
22 |
+
"position_embedding_type": "absolute",
|
23 |
+
"torch_dtype": "float16",
|
24 |
+
"transformers_version": "4.48.3",
|
25 |
+
"type_vocab_size": 1,
|
26 |
+
"use_cache": true,
|
27 |
+
"vocab_size": 250002
|
28 |
+
}
|
text_encoder/model.fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:107fe15da52fe6d13d877512fa36861d1100534d1b9b88015ad9fd017db095a7
|
3 |
+
size 1119825680
|
text_projector/config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"in_features": 1024, "out_features": 1152, "bias": false, "_class_name": "Linear"}
|
text_projector/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e7060a387b4a6419f9d1d852759cb5b94541a1845e996f6062a07462d8b7b6a
|
3 |
+
size 2359384
|
tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"cls_token": {
|
10 |
+
"content": "<s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"eos_token": {
|
17 |
+
"content": "</s>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"mask_token": {
|
24 |
+
"content": "<mask>",
|
25 |
+
"lstrip": true,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
},
|
30 |
+
"pad_token": {
|
31 |
+
"content": "<pad>",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false
|
36 |
+
},
|
37 |
+
"sep_token": {
|
38 |
+
"content": "</s>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false
|
43 |
+
},
|
44 |
+
"unk_token": {
|
45 |
+
"content": "<unk>",
|
46 |
+
"lstrip": false,
|
47 |
+
"normalized": false,
|
48 |
+
"rstrip": false,
|
49 |
+
"single_word": false
|
50 |
+
}
|
51 |
+
}
|
tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<s>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "<pad>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "</s>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"3": {
|
28 |
+
"content": "<unk>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"250001": {
|
36 |
+
"content": "<mask>",
|
37 |
+
"lstrip": true,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"bos_token": "<s>",
|
45 |
+
"clean_up_tokenization_spaces": false,
|
46 |
+
"cls_token": "<s>",
|
47 |
+
"eos_token": "</s>",
|
48 |
+
"extra_special_tokens": {},
|
49 |
+
"mask_token": "<mask>",
|
50 |
+
"model_max_length": 512,
|
51 |
+
"pad_token": "<pad>",
|
52 |
+
"sep_token": "</s>",
|
53 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
54 |
+
"unk_token": "<unk>"
|
55 |
+
}
|
train.py_
CHANGED
@@ -18,33 +18,36 @@ from accelerate.state import DistributedType
|
|
18 |
from torch.distributed import broadcast_object_list
|
19 |
from torch.utils.checkpoint import checkpoint
|
20 |
from diffusers.models.attention_processor import AttnProcessor2_0
|
21 |
-
|
22 |
|
23 |
# --------------------------- Параметры ---------------------------
|
24 |
-
save_path = "datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
|
25 |
-
batch_size =
|
26 |
-
base_learning_rate = 5e-5 #8e-5
|
27 |
min_learning_rate = 2.5e-5 #2e-5
|
28 |
-
num_epochs =
|
29 |
project = "sdxs"
|
30 |
use_wandb = True
|
31 |
-
limit = 0
|
32 |
save_model = True
|
|
|
|
|
33 |
checkpoints_folder = ""
|
34 |
-
|
35 |
use_lr_decay = False # отключить затухание
|
36 |
|
37 |
# Параметры для диффузии
|
38 |
n_diffusion_steps = 40
|
39 |
samples_to_generate = 12
|
40 |
guidance_scale = 5
|
|
|
41 |
|
42 |
# Папки для сохранения результатов
|
43 |
generated_folder = "samples"
|
44 |
os.makedirs(generated_folder, exist_ok=True)
|
45 |
|
46 |
# Настройка seed для воспроизводимости
|
47 |
-
|
|
|
48 |
torch.manual_seed(seed)
|
49 |
np.random.seed(seed)
|
50 |
random.seed(seed)
|
@@ -59,7 +62,7 @@ dtype = torch.bfloat16
|
|
59 |
accelerator = Accelerator(mixed_precision="bf16")
|
60 |
device = accelerator.device
|
61 |
gen = torch.Generator(device=device)
|
62 |
-
gen.manual_seed(
|
63 |
|
64 |
# --------------------------- Инициализация WandB ---------------------------
|
65 |
if use_wandb and accelerator.is_main_process:
|
@@ -234,29 +237,55 @@ if os.path.isdir(latest_checkpoint):
|
|
234 |
unet.set_use_memory_efficient_attention_xformers(True)
|
235 |
|
236 |
# --------------------------- Оптимизатор и кастомный LR scheduler ---------------------------
|
237 |
-
if
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
[p], # Каждый параметр получает свой оптимизатор
|
245 |
-
lr=base_learning_rate,
|
246 |
-
betas=(0.9, 0.999),
|
247 |
-
weight_decay=1e-5,
|
248 |
-
eps=1e-8
|
249 |
-
) for p in unet.parameters()
|
250 |
-
}
|
251 |
-
|
252 |
-
# [2] Опред��ляем hook для применения оптимизатора сразу после накопления градиента
|
253 |
-
def optimizer_hook(param):
|
254 |
-
optimizer_dict[param].step()
|
255 |
-
optimizer_dict[param].zero_grad(set_to_none=True)
|
256 |
-
|
257 |
-
# [3] Регистрируем hook для всех параметров модели
|
258 |
-
for param in unet.parameters():
|
259 |
-
param.register_post_accumulate_grad_hook(optimizer_hook)
|
260 |
else:
|
261 |
# Улучшенный AdamW с правильными параметрами
|
262 |
from optimi import StableAdamW, Lion
|
@@ -300,7 +329,7 @@ def custom_lr_lambda(step):
|
|
300 |
use_lr_decay) / base_learning_rate
|
301 |
|
302 |
# Подготовка через Accelerator
|
303 |
-
if
|
304 |
unet, optimizer = accelerator.prepare(unet, optimizer_dict)
|
305 |
else:
|
306 |
lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
|
@@ -396,9 +425,8 @@ def generate_and_save_samples(fixed_samples,step):
|
|
396 |
padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white')
|
397 |
|
398 |
all_generated_images.append(padded_img)
|
399 |
-
#all_generated_images.append(pil_img)
|
400 |
|
401 |
-
caption_text = sample_text[img_idx][:
|
402 |
all_captions.append(caption_text)
|
403 |
|
404 |
# Сохраняем с информацией о размере в имени файла
|
@@ -413,11 +441,15 @@ def generate_and_save_samples(fixed_samples,step):
|
|
413 |
]
|
414 |
wandb.log({"generated_images": wandb_images, "global_step": step})
|
415 |
|
416 |
-
finally:
|
417 |
# Гарантированное перемещение VAE обратно на CPU
|
418 |
vae.to("cpu")
|
419 |
if original_model is not None:
|
420 |
del original_model
|
|
|
|
|
|
|
|
|
421 |
torch.cuda.empty_cache()
|
422 |
gc.collect()
|
423 |
|
@@ -438,7 +470,7 @@ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local
|
|
438 |
|
439 |
# Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
|
440 |
steps_per_epoch = len(dataloader)
|
441 |
-
sample_interval = max(1, steps_per_epoch //
|
442 |
|
443 |
# Начинаем с указанной эпохи (полезно при возобновлении)
|
444 |
for epoch in range(start_epoch, start_epoch + num_epochs):
|
@@ -447,7 +479,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
447 |
|
448 |
for step, (latents, embeddings) in enumerate(dataloader):
|
449 |
with accelerator.accumulate(unet):
|
450 |
-
if step ==
|
451 |
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
452 |
print(f"Шаг {step}: {used_gb:.2f} GB")
|
453 |
# Forward pass
|
@@ -475,12 +507,10 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
475 |
# Делаем backward через Accelerator
|
476 |
accelerator.backward(loss)
|
477 |
|
478 |
-
|
479 |
-
if adam8bit:
|
480 |
-
grad_norm = accelerator.clip_grad_norm_(unet.parameters(), 1.0)
|
481 |
-
else:
|
482 |
optimizer.step()
|
483 |
lr_scheduler.step()
|
|
|
484 |
optimizer.zero_grad(set_to_none=True)
|
485 |
|
486 |
# Увеличиваем счетчик глобальных шагов
|
@@ -491,7 +521,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
491 |
|
492 |
# Логируем метрики
|
493 |
if accelerator.is_main_process:
|
494 |
-
if
|
495 |
current_lr = base_learning_rate
|
496 |
else:
|
497 |
current_lr = lr_scheduler.get_last_lr()[0]
|
@@ -499,17 +529,20 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
499 |
batch_losses.append(loss.detach().item())
|
500 |
|
501 |
# Логируем в Wandb
|
502 |
-
if
|
503 |
wandb.log({
|
504 |
"loss": loss.detach().item(),
|
505 |
"learning_rate": current_lr,
|
506 |
"epoch": epoch,
|
507 |
-
"grad_norm": grad_norm.item(),
|
508 |
"global_step": global_step
|
509 |
})
|
510 |
|
511 |
# Генерируем сэмплы с заданным интервалом
|
512 |
if global_step % sample_interval == 0:
|
|
|
|
|
|
|
513 |
generate_and_save_samples(fixed_samples,global_step)
|
514 |
|
515 |
# Выводим текущий лосс
|
@@ -520,13 +553,13 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
520 |
|
521 |
|
522 |
# По окончании эпохи
|
523 |
-
accelerator.wait_for_everyone()
|
524 |
# Сохраняем чекпоинт в конце каждой эпохи
|
525 |
if accelerator.is_main_process:
|
526 |
|
527 |
# Сохраняем UNet отдельно для удобства использования
|
528 |
-
if save_model:
|
529 |
-
|
530 |
avg_epoch_loss = np.mean(batch_losses)
|
531 |
print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
|
532 |
if use_wandb:
|
|
|
18 |
from torch.distributed import broadcast_object_list
|
19 |
from torch.utils.checkpoint import checkpoint
|
20 |
from diffusers.models.attention_processor import AttnProcessor2_0
|
21 |
+
from datetime import datetime
|
22 |
|
23 |
# --------------------------- Параметры ---------------------------
|
24 |
+
save_path = "datasets/576" #"datasets/576p2" #"datasets/1152p2" #"datasets/576p2" #"datasets/dataset384_temp" #"datasets/dataset384" #"datasets/imagenet-1kk" #"datasets/siski576" #"datasets/siski384" #"datasets/siski64" #"datasets/mnist"
|
25 |
+
batch_size = 45 #11 #45 #555 #35 #7
|
26 |
+
base_learning_rate = 1e-6 #2e-6 #1e-6 #2e-6 #6e-6 #2e-6 #8e-7 #6e-6 #2e-5 #4e-5 #3e-5 #5e-5 #8e-5
|
27 |
min_learning_rate = 2.5e-5 #2e-5
|
28 |
+
num_epochs = 4 #2 #36 #18
|
29 |
project = "sdxs"
|
30 |
use_wandb = True
|
|
|
31 |
save_model = True
|
32 |
+
adamw8bit = True
|
33 |
+
limit = 0 #200000 #0
|
34 |
checkpoints_folder = ""
|
35 |
+
lowram = True
|
36 |
use_lr_decay = False # отключить затухание
|
37 |
|
38 |
# Параметры для диффузии
|
39 |
n_diffusion_steps = 40
|
40 |
samples_to_generate = 12
|
41 |
guidance_scale = 5
|
42 |
+
sample_interval_share = 20
|
43 |
|
44 |
# Папки для сохранения результатов
|
45 |
generated_folder = "samples"
|
46 |
os.makedirs(generated_folder, exist_ok=True)
|
47 |
|
48 |
# Настройка seed для воспроизводимости
|
49 |
+
current_date = datetime.now()
|
50 |
+
seed = int(current_date.strftime("%Y%m%d"))
|
51 |
torch.manual_seed(seed)
|
52 |
np.random.seed(seed)
|
53 |
random.seed(seed)
|
|
|
62 |
accelerator = Accelerator(mixed_precision="bf16")
|
63 |
device = accelerator.device
|
64 |
gen = torch.Generator(device=device)
|
65 |
+
gen.manual_seed(seed)
|
66 |
|
67 |
# --------------------------- Инициализация WandB ---------------------------
|
68 |
if use_wandb and accelerator.is_main_process:
|
|
|
237 |
unet.set_use_memory_efficient_attention_xformers(True)
|
238 |
|
239 |
# --------------------------- Оптимизатор и кастомный LR scheduler ---------------------------
|
240 |
+
if lowram:
|
241 |
+
if adamw8bit:
|
242 |
+
# pip install bitsandbytes
|
243 |
+
import bitsandbytes as bnb
|
244 |
+
|
245 |
+
# [1] Создаем словарь оптимизаторов (fused backward)
|
246 |
+
optimizer_dict = {
|
247 |
+
p: bnb.optim.AdamW8bit(
|
248 |
+
[p], # Каждый параметр получает свой оптимизатор
|
249 |
+
lr=base_learning_rate,
|
250 |
+
betas=(0.9, 0.999),
|
251 |
+
weight_decay=1e-5,
|
252 |
+
eps=1e-8
|
253 |
+
) for p in unet.parameters()
|
254 |
+
}
|
255 |
+
|
256 |
+
# [2] Определяем hook для применения оптимизатора сразу после накопления градиента
|
257 |
+
def optimizer_hook(param):
|
258 |
+
optimizer_dict[param].step()
|
259 |
+
optimizer_dict[param].zero_grad(set_to_none=True)
|
260 |
+
|
261 |
+
# [3] Регистрируем hook для всех параметров модели
|
262 |
+
for param in unet.parameters():
|
263 |
+
param.register_post_accumulate_grad_hook(optimizer_hook)
|
264 |
+
|
265 |
+
else:
|
266 |
+
# Импортируем Adafactor из transformers
|
267 |
+
from transformers.optimization import Adafactor, AdafactorSchedule
|
268 |
+
|
269 |
+
# [1] Создаем словарь оптимизаторов (для каждого параметра)
|
270 |
+
base_learning_rate = 0
|
271 |
+
optimizer_dict = {
|
272 |
+
p: Adafactor(
|
273 |
+
[p],
|
274 |
+
relative_step=True, # Важно! Включает режим без явного LR
|
275 |
+
scale_parameter=True, # Автоматически масштабирует параметры
|
276 |
+
warmup_init=False, # Постепенно увеличивает шаги в начале
|
277 |
+
weight_decay=1e-5, # Можно оставить как есть
|
278 |
+
) for p in unet.parameters()
|
279 |
+
}
|
280 |
+
|
281 |
+
# [2] Определяем hook для применения оптимизатора сразу после накопления градиента
|
282 |
+
def optimizer_hook(param):
|
283 |
+
optimizer_dict[param].step()
|
284 |
+
optimizer_dict[param].zero_grad(set_to_none=True)
|
285 |
|
286 |
+
# [3] Регистрируем hook для всех параметров модели
|
287 |
+
for param in unet.parameters():
|
288 |
+
param.register_post_accumulate_grad_hook(optimizer_hook)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
else:
|
290 |
# Улучшенный AdamW с правильными параметрами
|
291 |
from optimi import StableAdamW, Lion
|
|
|
329 |
use_lr_decay) / base_learning_rate
|
330 |
|
331 |
# Подготовка через Accelerator
|
332 |
+
if lowram:
|
333 |
unet, optimizer = accelerator.prepare(unet, optimizer_dict)
|
334 |
else:
|
335 |
lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
|
|
|
425 |
padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white')
|
426 |
|
427 |
all_generated_images.append(padded_img)
|
|
|
428 |
|
429 |
+
caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
|
430 |
all_captions.append(caption_text)
|
431 |
|
432 |
# Сохраняем с информацией о размере в имени файла
|
|
|
441 |
]
|
442 |
wandb.log({"generated_images": wandb_images, "global_step": step})
|
443 |
|
444 |
+
finally:
|
445 |
# Гарантированное перемещение VAE обратно на CPU
|
446 |
vae.to("cpu")
|
447 |
if original_model is not None:
|
448 |
del original_model
|
449 |
+
# Очистка всех тензоров
|
450 |
+
for var in list(locals().keys()):
|
451 |
+
if isinstance(locals()[var], torch.Tensor):
|
452 |
+
del locals()[var]
|
453 |
torch.cuda.empty_cache()
|
454 |
gc.collect()
|
455 |
|
|
|
470 |
|
471 |
# Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
|
472 |
steps_per_epoch = len(dataloader)
|
473 |
+
sample_interval = max(1, steps_per_epoch // sample_interval_share)
|
474 |
|
475 |
# Начинаем с указанной эпохи (полезно при возобновлении)
|
476 |
for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
|
479 |
|
480 |
for step, (latents, embeddings) in enumerate(dataloader):
|
481 |
with accelerator.accumulate(unet):
|
482 |
+
if save_model == False and step == 3 :
|
483 |
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
484 |
print(f"Шаг {step}: {used_gb:.2f} GB")
|
485 |
# Forward pass
|
|
|
507 |
# Делаем backward через Accelerator
|
508 |
accelerator.backward(loss)
|
509 |
|
510 |
+
if not lowram:
|
|
|
|
|
|
|
511 |
optimizer.step()
|
512 |
lr_scheduler.step()
|
513 |
+
# Используем ограничение нормы градиентов через Accelerator
|
514 |
optimizer.zero_grad(set_to_none=True)
|
515 |
|
516 |
# Увеличиваем счетчик глобальных шагов
|
|
|
521 |
|
522 |
# Логируем метрики
|
523 |
if accelerator.is_main_process:
|
524 |
+
if lowram:
|
525 |
current_lr = base_learning_rate
|
526 |
else:
|
527 |
current_lr = lr_scheduler.get_last_lr()[0]
|
|
|
529 |
batch_losses.append(loss.detach().item())
|
530 |
|
531 |
# Логируем в Wandb
|
532 |
+
if use_wandb:
|
533 |
wandb.log({
|
534 |
"loss": loss.detach().item(),
|
535 |
"learning_rate": current_lr,
|
536 |
"epoch": epoch,
|
537 |
+
#"grad_norm": grad_norm.item(),
|
538 |
"global_step": global_step
|
539 |
})
|
540 |
|
541 |
# Генерируем сэмплы с заданным интервалом
|
542 |
if global_step % sample_interval == 0:
|
543 |
+
if save_model:
|
544 |
+
accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
|
545 |
+
|
546 |
generate_and_save_samples(fixed_samples,global_step)
|
547 |
|
548 |
# Выводим текущий лосс
|
|
|
553 |
|
554 |
|
555 |
# По окончании эпохи
|
556 |
+
#accelerator.wait_for_everyone()
|
557 |
# Сохраняем чекпоинт в конце каждой эпохи
|
558 |
if accelerator.is_main_process:
|
559 |
|
560 |
# Сохраняем UNet отдельно для удобства использования
|
561 |
+
#if save_model:
|
562 |
+
# accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
|
563 |
avg_epoch_loss = np.mean(batch_losses)
|
564 |
print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
|
565 |
if use_wandb:
|
unet/config.json
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.33.0.dev0",
|
4 |
+
"_name_or_path": "sdxs",
|
5 |
+
"act_fn": "silu",
|
6 |
+
"addition_embed_type": null,
|
7 |
+
"addition_embed_type_num_heads": 48,
|
8 |
+
"addition_time_embed_dim": null,
|
9 |
+
"attention_head_dim": 48,
|
10 |
+
"attention_type": "default",
|
11 |
+
"block_out_channels": [
|
12 |
+
384,
|
13 |
+
576,
|
14 |
+
768,
|
15 |
+
960
|
16 |
+
],
|
17 |
+
"center_input_sample": false,
|
18 |
+
"class_embed_type": null,
|
19 |
+
"class_embeddings_concat": false,
|
20 |
+
"conv_in_kernel": 3,
|
21 |
+
"conv_out_kernel": 3,
|
22 |
+
"cross_attention_dim": 1152,
|
23 |
+
"cross_attention_norm": null,
|
24 |
+
"down_block_types": [
|
25 |
+
"CrossAttnDownBlock2D",
|
26 |
+
"CrossAttnDownBlock2D",
|
27 |
+
"CrossAttnDownBlock2D",
|
28 |
+
"CrossAttnDownBlock2D"
|
29 |
+
],
|
30 |
+
"downsample_padding": 1,
|
31 |
+
"dropout": 0.1,
|
32 |
+
"dual_cross_attention": false,
|
33 |
+
"encoder_hid_dim": null,
|
34 |
+
"encoder_hid_dim_type": null,
|
35 |
+
"flip_sin_to_cos": true,
|
36 |
+
"freq_shift": 0,
|
37 |
+
"in_channels": 16,
|
38 |
+
"layers_per_block": [
|
39 |
+
2,
|
40 |
+
2,
|
41 |
+
2,
|
42 |
+
2
|
43 |
+
],
|
44 |
+
"mid_block_only_cross_attention": null,
|
45 |
+
"mid_block_scale_factor": 1.0,
|
46 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
47 |
+
"norm_eps": 1e-05,
|
48 |
+
"norm_num_groups": 16,
|
49 |
+
"num_attention_heads": null,
|
50 |
+
"num_class_embeds": null,
|
51 |
+
"only_cross_attention": false,
|
52 |
+
"out_channels": 16,
|
53 |
+
"projection_class_embeddings_input_dim": null,
|
54 |
+
"resnet_out_scale_factor": 1.0,
|
55 |
+
"resnet_skip_time_act": false,
|
56 |
+
"resnet_time_scale_shift": "default",
|
57 |
+
"reverse_transformer_layers_per_block": null,
|
58 |
+
"sample_size": 64,
|
59 |
+
"time_cond_proj_dim": null,
|
60 |
+
"time_embedding_act_fn": null,
|
61 |
+
"time_embedding_dim": null,
|
62 |
+
"time_embedding_type": "positional",
|
63 |
+
"timestep_post_act": null,
|
64 |
+
"transformer_layers_per_block": [
|
65 |
+
4,
|
66 |
+
6,
|
67 |
+
8,
|
68 |
+
10
|
69 |
+
],
|
70 |
+
"up_block_types": [
|
71 |
+
"CrossAttnUpBlock2D",
|
72 |
+
"CrossAttnUpBlock2D",
|
73 |
+
"CrossAttnUpBlock2D",
|
74 |
+
"CrossAttnUpBlock2D"
|
75 |
+
],
|
76 |
+
"upcast_attention": false,
|
77 |
+
"use_linear_projection": true
|
78 |
+
}
|
unet/diffusion_pytorch_model.fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ddd8d0c467b18d515d1505d21905c36081896edbeb0813f8ffc92dd62d2405b7
|
3 |
+
size 4529095968
|
vae/config.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "AutoencoderKL",
|
3 |
+
"_diffusers_version": "0.33.0.dev0",
|
4 |
+
"_name_or_path": "/home/recoilme/sdxs576/vae",
|
5 |
+
"act_fn": "silu",
|
6 |
+
"block_out_channels": [
|
7 |
+
128,
|
8 |
+
256,
|
9 |
+
512,
|
10 |
+
512
|
11 |
+
],
|
12 |
+
"down_block_types": [
|
13 |
+
"DownEncoderBlock2D",
|
14 |
+
"DownEncoderBlock2D",
|
15 |
+
"DownEncoderBlock2D",
|
16 |
+
"DownEncoderBlock2D"
|
17 |
+
],
|
18 |
+
"force_upcast": false,
|
19 |
+
"in_channels": 3,
|
20 |
+
"latent_channels": 16,
|
21 |
+
"latents_mean": null,
|
22 |
+
"latents_std": null,
|
23 |
+
"layers_per_block": 2,
|
24 |
+
"mid_block_add_attention": false,
|
25 |
+
"norm_num_groups": 32,
|
26 |
+
"out_channels": 3,
|
27 |
+
"sample_size": 1024,
|
28 |
+
"scaling_factor": 0.18215,
|
29 |
+
"shift_factor": 0,
|
30 |
+
"up_block_types": [
|
31 |
+
"UpDecoderBlock2D",
|
32 |
+
"UpDecoderBlock2D",
|
33 |
+
"UpDecoderBlock2D",
|
34 |
+
"UpDecoderBlock2D"
|
35 |
+
],
|
36 |
+
"use_post_quant_conv": true,
|
37 |
+
"use_quant_conv": true
|
38 |
+
}
|
vae/diffusion_pytorch_model.fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a27d33cadf456598eee3cf5638ba3e8e01dd020a5f9c80c1a7b162bf97096701
|
3 |
+
size 163460798
|