Maykeye commited on
Commit
4d5aebb
·
1 Parent(s): 61e27ee

Initial commit

Browse files
README.md CHANGED
@@ -1,3 +1,31 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+ This is a first version of recreating roneneldan/TinyStories-1M but using Llama architecture.
5
+
6
+ * Full training process is included in the notebook train.ipynb. Recreating it as simple as downloading
7
+ TinyStoriesV2-GPT4-train.txt and TinyStoriesV2-GPT4-valid.txt in the same folder with the notebook and running
8
+ the cells. Validation content is not used by the script so you put anythin in
9
+
10
+ * Backup directory has a script do_backup that I used to copy weights from remote machine to local.
11
+ Weight are generated too quickly, so by the time script copied weihgt N+1
12
+
13
+ * This is extremely PoC version. Training truncates stories that are longer than context size and doesn't use
14
+ any sliding window to train story not from the start
15
+
16
+ * Training took approximately 9 hours (3 hours per epoch) on 40GB A100. ~30GB VRAM was used
17
+
18
+ * I use tokenizer from open_llama_3b. However I had troubles with it locally(https://github.com/openlm-research/open_llama/issues/69).
19
+ I had no troubles on the cloud machine with preninstalled libraries.
20
+
21
+ * Demo script is demo.py
22
+
23
+ * Validation script is provided: valid.py. use it like `python valid.py path/to/TinyStoriesV2-GPT4-valid.txt [optional-model-id-or-path]`:
24
+ After training I decided that it's not necessary to beat validation into chunks
25
+
26
+ * Also this version uses very stupid caching mechinsm to shuffle stories for training: it keeps cache of N recently loaded chunks
27
+ so if random shuffle asks for a story, it may use cache or load chunk.
28
+ Training dataset is too small, so in next versions I will get rid of it.
29
+
30
+
31
+ from transformers import AutoModelForCausalLM, AutoTokenizer
backup/do_backup.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+ import time
4
+
5
+ copied = []
6
+ while True:
7
+ existing = [x.name for x in Path(".").glob("*.bin")]
8
+ copy_from = [x for x in Path("/home/fella/mnt/selectel/tiny-llama/").glob("*.bin")]
9
+ for file in copy_from:
10
+ if file.name not in existing:
11
+ print(file)
12
+ try:
13
+ shutil.copy(file, file.name)
14
+ copied.append(file.name)
15
+ if len(copied) > 6:
16
+ delete_me = copied.pop(0)
17
+ Path(delete_me).unlink()
18
+ except Exception as e:
19
+ print(f"Skipping {file.name}: {e}")
20
+ pass
21
+ time.sleep(15)
backup/step-1-10300.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6ec1c440a393eb0cce4c215de321efa232fb77a549cc85bd98e64a635ca28d9
3
+ size 9275989
config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "eos_token_id": 2,
7
+ "hidden_act": "silu",
8
+ "hidden_size": 64,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 256,
11
+ "max_position_embeddings": 2048,
12
+ "model_type": "llama",
13
+ "num_attention_heads": 16,
14
+ "num_hidden_layers": 8,
15
+ "pad_token_id": 0,
16
+ "rms_norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "torch_dtype": "bfloat16",
19
+ "transformers_version": "4.30.2",
20
+ "use_cache": true,
21
+ "vocab_size": 32000
22
+ }
demo.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import sys
3
+ import os
4
+
5
+ model_id = os.getcwd()
6
+ if len(sys.argv) > 1:
7
+ model_id = sys.argv[1]
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
10
+ model = AutoModelForCausalLM.from_pretrained(model_id).cuda().bfloat16()
11
+ prompt = "Lily picked up a flower."
12
+ inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to('cuda')
13
+ out = model.generate(**inputs, max_new_tokens=80).ravel()
14
+ out = tokenizer.decode(out)
15
+ print(out)
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.30.2"
7
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae959aaff509d66f9dd85c53f16481463286950a21e0349c7793f8412fc4a094
3
+ size 9269994
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab1b681ec7fc02fed5edd3026687d7a692a918c4dd8e150ca2e3994a6229843b
3
+ size 534194
tokenizer_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "__type": "AddedToken",
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "clean_up_tokenization_spaces": false,
11
+ "eos_token": {
12
+ "__type": "AddedToken",
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "model_max_length": 2048,
20
+ "pad_token": null,
21
+ "sp_model_kwargs": {},
22
+ "tokenizer_class": "LlamaTokenizer",
23
+ "unk_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<unk>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
train.ipynb ADDED
@@ -0,0 +1,777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "f41486ad",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "NVIDIA A100-PCIE-40GB\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "# step 0. Preliminary\n",
19
+ "import torch\n",
20
+ "# check that cuda doesn't crash on us\n",
21
+ "print(torch.cuda.get_device_name())\n",
22
+ "# check that transformers installed\n",
23
+ "import transformers"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "id": "ffd19cfb",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "EPOCHS=3"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 3,
39
+ "id": "3a91ef1f",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "# Step 1. Preparing the training\n",
44
+ "# First ensure that required files are here\n",
45
+ "from pathlib import Path\n",
46
+ "assert Path(\"TinyStoriesV2-GPT4-train.txt\").exists()\n",
47
+ "assert Path(\"TinyStoriesV2-GPT4-valid.txt\").exists()"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 4,
53
+ "id": "56b046d5",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# Then prepare directories\n",
58
+ "Path(\"chunks.txt/train\").mkdir(parents=True, exist_ok=True)\n",
59
+ "Path(\"chunks.tensors/train\").mkdir(parents=True, exist_ok=True)\n",
60
+ "Path(\"chunks.txt/valid\").mkdir(parents=True, exist_ok=True)\n",
61
+ "Path(\"chunks.tensors/valid\").mkdir(parents=True, exist_ok=True)"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 5,
67
+ "id": "1bddb2ee",
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "# Then prepare method to split one text to several\n",
72
+ "from multiprocessing.pool import Pool\n",
73
+ "from tqdm.contrib.concurrent import process_map\n",
74
+ "import os\n",
75
+ "_chunk_me = None\n",
76
+ "def extract_chunk(chunk):\n",
77
+ " split, i, chunk_from, chunk_to = chunk\n",
78
+ " chunk = _chunk_me[chunk_from:chunk_to].strip() \n",
79
+ " name = f\"chunks.txt/{split}/chunk-{i+1}.txt\"\n",
80
+ " with open(name, \"w\") as f:\n",
81
+ " f.write(chunk)\n",
82
+ " return name\n",
83
+ "\n",
84
+ "def split_to_text_chunks(split:str, chunk_size = 16*1024*1024, max_workers=None):\n",
85
+ " global _chunk_me #text is too chunky to pass as argument. storing as global so fork() can take care of it\n",
86
+ " print(f\"reading {split}\")\n",
87
+ " text = _chunk_me = Path(f\"./TinyStoriesV2-GPT4-{split}.txt\").read_text()\n",
88
+ " offsets = [] \n",
89
+ " delimiter = \"<|endoftext|>\"\n",
90
+ " i=0\n",
91
+ " while i < len(text): \n",
92
+ " offsets.append(i)\n",
93
+ " i += chunk_size\n",
94
+ " i = text.find(delimiter, i)\n",
95
+ " if i < 0:\n",
96
+ " break\n",
97
+ " i += len(delimiter)\n",
98
+ " offsets.append(len(text))\n",
99
+ " chunks = [(split, i, start,end) for (i, (start, end)) in enumerate(zip(offsets[:-1], offsets[1:]))]\n",
100
+ " \n",
101
+ " print(\"writing\")\n",
102
+ " process_map(extract_chunk, chunks, max_workers=max_workers)\n",
103
+ " "
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 7,
109
+ "id": "e60017ee",
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "name": "stdout",
114
+ "output_type": "stream",
115
+ "text": [
116
+ "Assuming split has finished already\n"
117
+ ]
118
+ }
119
+ ],
120
+ "source": [
121
+ "# Prepare text of train split\n",
122
+ "if not Path(\"chunks.txt/train/chunk-133.txt\").exists():\n",
123
+ " split_to_text_chunks(\"train\")\n",
124
+ "else:\n",
125
+ " print(\"Assuming split has finished already\")"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 9,
131
+ "id": "e9b7effe",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "name": "stdout",
136
+ "output_type": "stream",
137
+ "text": [
138
+ "Assuming split has finished already\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "# Prepare text of valid split\n",
144
+ "if not Path(\"chunks.txt/valid/chunk-2.txt\").exists():\n",
145
+ " split_to_text_chunks(\"valid\") \n",
146
+ "else:\n",
147
+ " print(\"Assuming split has finished already\")"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 10,
153
+ "id": "b4706f24",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "# Step 2. Prepare OpenLLAMA tokenizer. \n",
158
+ "#Needed to be done once(TODO: add code to load tokenizer?)\n",
159
+ "from transformers import AutoTokenizer\n",
160
+ "import os\n",
161
+ "if not Path('tokenizer.json').exists(): \n",
162
+ " try:\n",
163
+ " tokenizer = AutoTokenizer.from_pretrained(\"openlm-research/open_llama_3b\")\n",
164
+ " except:\n",
165
+ " os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"]=\"python\" \n",
166
+ " tokenizer = AutoTokenizer.from_pretrained(\"openlm-research/open_llama_3b\")\n",
167
+ " tokenizer.save_pretrained(\".\")\n",
168
+ " del os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"]\n",
169
+ "tokenizer = AutoTokenizer.from_pretrained(\".\")"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 11,
175
+ "id": "f9c935b0",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "# Step 3. Preparing to tokenize each text chunk\n",
180
+ "from tqdm.contrib.concurrent import process_map\n",
181
+ "def tokenize_file(filename:Path):\n",
182
+ " text = Path.read_text(filename)\n",
183
+ " stories = text.split(\"<|endoftext|>\")\n",
184
+ " result = []\n",
185
+ " while stories:\n",
186
+ " story = stories.pop(0).strip()\n",
187
+ " tokenized = tokenizer(story, max_length=None).input_ids\n",
188
+ " tokenized.append(tokenizer.eos_token_id)\n",
189
+ " result.append(torch.tensor(tokenized))\n",
190
+ " output_name = str(filename).replace(\".txt\", \".tensors\")\n",
191
+ " torch.save(result, output_name)\n",
192
+ "\n",
193
+ "def tokenize_split(split, max_workers=None):\n",
194
+ " to_process = list(Path(f\"chunks.txt/{split}\").glob(\"*\")) \n",
195
+ " process_map(tokenize_file, to_process, max_workers=max_workers)\n",
196
+ " "
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 12,
202
+ "id": "95257f12",
203
+ "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "name": "stdout",
207
+ "output_type": "stream",
208
+ "text": [
209
+ "Assuming train was tokenized already\n"
210
+ ]
211
+ }
212
+ ],
213
+ "source": [
214
+ "# processing train(this can take several minutes)\n",
215
+ "if not Path(\"chunks.tensors/train/chunk-133.tensors\").exists():\n",
216
+ " tokenize_split(\"train\")\n",
217
+ "else:\n",
218
+ " print(\"Assuming train was tokenized already\")"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 13,
224
+ "id": "bbbe4599",
225
+ "metadata": {},
226
+ "outputs": [
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "Assuming valid was tokenized already\n"
232
+ ]
233
+ }
234
+ ],
235
+ "source": [
236
+ "# processing valid(this can take one minutes)\n",
237
+ "if not Path(\"chunks.tensors/valid/chunk-2.tensors\").exists():\n",
238
+ " tokenize_split(\"valid\")\n",
239
+ "else:\n",
240
+ " print(\"Assuming valid was tokenized already\")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 14,
246
+ "id": "a31a4aa7",
247
+ "metadata": {},
248
+ "outputs": [
249
+ {
250
+ "name": "stdout",
251
+ "output_type": "stream",
252
+ "text": [
253
+ "Resetting [PAD] to [EOS]\n"
254
+ ]
255
+ }
256
+ ],
257
+ "source": [
258
+ "# Step 4. Training. \n",
259
+ "# Step 4.1 Preparing tokenizer and setting pad token if it is not set(it is not set)\n",
260
+ "tokenizer = AutoTokenizer.from_pretrained(\".\")\n",
261
+ "if not tokenizer.pad_token_id:\n",
262
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
263
+ " print(\"Resetting [PAD] to [EOS]\")"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": 18,
269
+ "id": "f677c9c0",
270
+ "metadata": {
271
+ "scrolled": true
272
+ },
273
+ "outputs": [],
274
+ "source": [
275
+ "# Step 4.2. Preparing model\n",
276
+ "from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM\n",
277
+ "\n",
278
+ "tiny_llama = LlamaConfig(\n",
279
+ " hidden_size=64, \n",
280
+ " vocab_size=tokenizer.vocab_size,\n",
281
+ " intermediate_size=256, \n",
282
+ " num_attention_heads=16, \n",
283
+ " num_hidden_layers=8)\n",
284
+ "\n",
285
+ "torch.manual_seed(11010)\n",
286
+ "torch.cuda.manual_seed(11010)\n",
287
+ "model = LlamaForCausalLM(tiny_llama).cuda().bfloat16()"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 16,
293
+ "id": "aad9620b",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "import functools\n",
298
+ "import torch.nn.functional as F\n",
299
+ "from tqdm.contrib.concurrent import process_map\n",
300
+ "from tqdm.auto import tqdm\n",
301
+ "\n",
302
+ "# Step 4.3 Preparing dataset class\n",
303
+ "def get_file_data_len(filename):\n",
304
+ " data = torch.load(filename)\n",
305
+ " return (filename, len(data))\n",
306
+ "from datasets import Dataset\n",
307
+ "\n",
308
+ "CACHE_SIZE = 2000 # There are ~150 train splits. We can fit them in memory, so let's do it\n",
309
+ "\n",
310
+ "class TinyDataset(Dataset):\n",
311
+ " def __init__(self, split: str, populate_cache=True):\n",
312
+ " print(f\"Reading dataset {split} data\")\n",
313
+ " self.file_lens = process_map(\n",
314
+ " get_file_data_len,\n",
315
+ " list(Path(f\"chunks.tensors/{split}\").glob(\"*\")))\n",
316
+ " self.file_lens.sort()\n",
317
+ " if populate_cache:\n",
318
+ " print(\"Populating a cache\")\n",
319
+ " for filename, _ in tqdm(self.file_lens):\n",
320
+ " self.load_tensor_file(filename)\n",
321
+ "\n",
322
+ " @functools.lru_cache(maxsize=CACHE_SIZE)\n",
323
+ " def load_tensor_file(self, filename):\n",
324
+ " return torch.load(filename)\n",
325
+ "\n",
326
+ " def __len__(self):\n",
327
+ " return sum(x[1] for x in self.file_lens)\n",
328
+ "\n",
329
+ " def global_index_to_local(self, i):\n",
330
+ " for (file, length) in self.file_lens:\n",
331
+ " if i < length:\n",
332
+ " return (file, i)\n",
333
+ " i -= length\n",
334
+ " raise IndexError(f\"{i} is out-of-bonds, have {len(self)} sample\")\n",
335
+ "\n",
336
+ " def __getitem__(self, index):\n",
337
+ " if torch.is_tensor(index):\n",
338
+ " index = index.tolist()\n",
339
+ " if isinstance(index, int):\n",
340
+ " filename, local_index = self.global_index_to_local(index)\n",
341
+ " tensors = self.load_tensor_file(filename)\n",
342
+ " return {\n",
343
+ " 'input_ids': tensors[local_index]\n",
344
+ " }\n",
345
+ " if isinstance(index, list):\n",
346
+ " data = []\n",
347
+ " indices = index\n",
348
+ " for index in indices:\n",
349
+ " filename, local_index = self.global_index_to_local(index)\n",
350
+ " tensors = self.load_tensor_file(filename)\n",
351
+ " data.append(tensors[local_index])\n",
352
+ "\n",
353
+ " return {'input_ids': data}\n",
354
+ "\n",
355
+ " raise TypeError(f'Invaldi index type {type(index)}')\n",
356
+ " \n",
357
+ "def batch_collate(data: list[torch.Tensor]):\n",
358
+ " max_len = max(len(datum[\"input_ids\"]) for datum in data)\n",
359
+ " inputs = []\n",
360
+ " attentions = []\n",
361
+ " for row in data:\n",
362
+ " input_ids = row[\"input_ids\"]\n",
363
+ " attention_mask = torch.ones_like(input_ids)\n",
364
+ " attention_mask[-1] = 0 # don't care about EOS\n",
365
+ " # Manual padding\n",
366
+ " to_pad = max_len - len(input_ids)\n",
367
+ " is_left_pad = tokenizer.padding_side == \"left\"\n",
368
+ " padding = (is_left_pad * to_pad, (1 - is_left_pad) * to_pad)\n",
369
+ " input_ids = F.pad(input_ids, padding, value=tokenizer.pad_token_id)\n",
370
+ " attention_mask = F.pad(attention_mask, padding, value=0)\n",
371
+ " inputs.append(input_ids)\n",
372
+ " attentions.append(attention_mask)\n",
373
+ "\n",
374
+ " attention_masks = torch.stack(attentions)\n",
375
+ " input_ids = torch.stack(inputs)\n",
376
+ " labels = input_ids.clone()\n",
377
+ "\n",
378
+ " # disable prediction of the padding\n",
379
+ " labels[attention_masks == 0] = -100\n",
380
+ " # enable prediction of an actual EOS\n",
381
+ " labels[:, -1] = tokenizer.eos_token_id\n",
382
+ "\n",
383
+ " return {\n",
384
+ " 'input_ids': input_ids,\n",
385
+ " 'attention_mask': attention_masks,\n",
386
+ " 'labels': labels\n",
387
+ " }\n",
388
+ "\n",
389
+ "def get_max_story_length(ds): \n",
390
+ " return max(file_len[1] for file_len in ds.file_lens)\n"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": 17,
396
+ "id": "2e828afe",
397
+ "metadata": {},
398
+ "outputs": [
399
+ {
400
+ "name": "stdout",
401
+ "output_type": "stream",
402
+ "text": [
403
+ "Reading dataset train data\n"
404
+ ]
405
+ },
406
+ {
407
+ "data": {
408
+ "application/vnd.jupyter.widget-view+json": {
409
+ "model_id": "8ca542afc1694073af6dcf9ce5f7e13a",
410
+ "version_major": 2,
411
+ "version_minor": 0
412
+ },
413
+ "text/plain": [
414
+ " 0%| | 0/133 [00:00<?, ?it/s]"
415
+ ]
416
+ },
417
+ "metadata": {},
418
+ "output_type": "display_data"
419
+ },
420
+ {
421
+ "name": "stdout",
422
+ "output_type": "stream",
423
+ "text": [
424
+ "Populating a cache\n"
425
+ ]
426
+ },
427
+ {
428
+ "data": {
429
+ "application/vnd.jupyter.widget-view+json": {
430
+ "model_id": "8035a75107e84a54870a8c6f15c4100a",
431
+ "version_major": 2,
432
+ "version_minor": 0
433
+ },
434
+ "text/plain": [
435
+ " 0%| | 0/133 [00:00<?, ?it/s]"
436
+ ]
437
+ },
438
+ "metadata": {},
439
+ "output_type": "display_data"
440
+ },
441
+ {
442
+ "ename": "AssertionError",
443
+ "evalue": "WARNIING: split long stories",
444
+ "output_type": "error",
445
+ "traceback": [
446
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
447
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
448
+ "Cell \u001b[0;32mIn[17], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m tokenizer\u001b[38;5;241m.\u001b[39mpadding_side \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleft\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mright\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 2\u001b[0m train_ds \u001b[38;5;241m=\u001b[39m TinyDataset(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m get_max_story_length(train_ds) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mmodel_max_length, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWARNIING: split long stories\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
449
+ "\u001b[0;31mAssertionError\u001b[0m: WARNIING: split long stories"
450
+ ]
451
+ }
452
+ ],
453
+ "source": [
454
+ "assert tokenizer.padding_side in [\"left\", \"right\"]\n",
455
+ "train_ds = TinyDataset(\"train\")\n",
456
+ "assert get_max_story_length(train_ds) <= tokenizer.model_max_length, \"WARNIING: split long stories\""
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": 19,
462
+ "id": "6412e7c5",
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "from torch.utils.data import DataLoader\n",
467
+ "torch.manual_seed(11010)\n",
468
+ "torch.cuda.manual_seed(11010)\n",
469
+ "train_dl = DataLoader(train_ds, 16, True, collate_fn=batch_collate)"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "code",
474
+ "execution_count": 20,
475
+ "id": "f3ff5a66",
476
+ "metadata": {},
477
+ "outputs": [
478
+ {
479
+ "name": "stderr",
480
+ "output_type": "stream",
481
+ "text": [
482
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mggg4\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
483
+ ]
484
+ },
485
+ {
486
+ "data": {
487
+ "text/html": [
488
+ "Tracking run with wandb version 0.15.5"
489
+ ],
490
+ "text/plain": [
491
+ "<IPython.core.display.HTML object>"
492
+ ]
493
+ },
494
+ "metadata": {},
495
+ "output_type": "display_data"
496
+ },
497
+ {
498
+ "data": {
499
+ "text/html": [
500
+ "Run data is saved locally in <code>/home/mayk/tiny-llama/wandb/run-20230707_181234-rilt4m6f</code>"
501
+ ],
502
+ "text/plain": [
503
+ "<IPython.core.display.HTML object>"
504
+ ]
505
+ },
506
+ "metadata": {},
507
+ "output_type": "display_data"
508
+ },
509
+ {
510
+ "data": {
511
+ "text/html": [
512
+ "Syncing run <strong><a href='https://wandb.ai/ggg4/training-tiny-llama-preview/runs/rilt4m6f' target=\"_blank\">grateful-jazz-4</a></strong> to <a href='https://wandb.ai/ggg4/training-tiny-llama-preview' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
513
+ ],
514
+ "text/plain": [
515
+ "<IPython.core.display.HTML object>"
516
+ ]
517
+ },
518
+ "metadata": {},
519
+ "output_type": "display_data"
520
+ },
521
+ {
522
+ "data": {
523
+ "text/html": [
524
+ " View project at <a href='https://wandb.ai/ggg4/training-tiny-llama-preview' target=\"_blank\">https://wandb.ai/ggg4/training-tiny-llama-preview</a>"
525
+ ],
526
+ "text/plain": [
527
+ "<IPython.core.display.HTML object>"
528
+ ]
529
+ },
530
+ "metadata": {},
531
+ "output_type": "display_data"
532
+ },
533
+ {
534
+ "data": {
535
+ "text/html": [
536
+ " View run at <a href='https://wandb.ai/ggg4/training-tiny-llama-preview/runs/rilt4m6f' target=\"_blank\">https://wandb.ai/ggg4/training-tiny-llama-preview/runs/rilt4m6f</a>"
537
+ ],
538
+ "text/plain": [
539
+ "<IPython.core.display.HTML object>"
540
+ ]
541
+ },
542
+ "metadata": {},
543
+ "output_type": "display_data"
544
+ },
545
+ {
546
+ "data": {
547
+ "text/html": [
548
+ "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/ggg4/training-tiny-llama-preview/runs/rilt4m6f?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
549
+ ],
550
+ "text/plain": [
551
+ "<wandb.sdk.wandb_run.Run at 0x7f6af8170b50>"
552
+ ]
553
+ },
554
+ "execution_count": 20,
555
+ "metadata": {},
556
+ "output_type": "execute_result"
557
+ }
558
+ ],
559
+ "source": [
560
+ "# prepare wandb\n",
561
+ "import wandb\n",
562
+ "wandb.init(\n",
563
+ " project=\"training-tiny-llama-preview\",\n",
564
+ " config={\n",
565
+ " \"architecture\": \"llama\",\n",
566
+ " \"dataset\": \"tiny-stories\",\n",
567
+ " \"epochs\": EPOCHS,\n",
568
+ " } \n",
569
+ ")"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "id": "aed7b7a4",
576
+ "metadata": {},
577
+ "outputs": [],
578
+ "source": []
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 21,
583
+ "id": "166a4a27",
584
+ "metadata": {},
585
+ "outputs": [],
586
+ "source": [
587
+ "from tqdm.auto import tqdm\n",
588
+ "def save_imm(epoch, step, saved=[]):\n",
589
+ " fname = f\"step-{epoch}-{step}.bin\"\n",
590
+ " torch.save(model.state_dict(), f\"step-{epoch}-{step}.bin\")\n",
591
+ " saved.append(fname)\n",
592
+ " if len(saved) > 5:\n",
593
+ " delete_me = saved.pop(0)\n",
594
+ " Path(delete_me).unlink(missing_ok=True)\n",
595
+ "\n",
596
+ "def epoch_step(epoch, opt):\n",
597
+ " for i, batch in enumerate(bar := tqdm(train_dl)):\n",
598
+ " for k in batch:\n",
599
+ " batch[k] = batch[k].to(device=model.lm_head.weight.device)\n",
600
+ " \n",
601
+ " n_batch, n_seq = batch[\"input_ids\"].shape\n",
602
+ " if n_seq > tokenizer.model_max_length:\n",
603
+ " assert tokenizer.padding_side == \"right\", \"Left-pad truncation only supported[as model should not see >2k token anyway]\"\n",
604
+ " batch[\"input_ids\"] = batch[\"input_ids\"][:, -tokenizer.model_max_length]\n",
605
+ " batch[\"labels\"] = batch[\"labels\"][:, -tokenizer.model_max_length]\n",
606
+ " batch[\"attention_mask\"] = batch[\"attention_mask\"][:, -tokenizer.model_max_length]\n",
607
+ " \n",
608
+ " \n",
609
+ " loss = model(**batch).loss\n",
610
+ " loss.backward()\n",
611
+ " opt.step()\n",
612
+ " opt.zero_grad()\n",
613
+ " bar.set_description(f'L:{loss.item():.4f}')\n",
614
+ " wandb.log({\"loss\": loss.item()})\n",
615
+ " if (i+1) % 100 == 0:\n",
616
+ " save_imm(epoch, i+1)\n",
617
+ " \n",
618
+ " torch.save(model.state_dict(), f\"epoch-{epoch}.bin\")\n"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "execution_count": 22,
624
+ "id": "ec4943c7",
625
+ "metadata": {},
626
+ "outputs": [],
627
+ "source": [
628
+ "opt = torch.optim.AdamW(model.parameters(), fused=True)\n"
629
+ ]
630
+ },
631
+ {
632
+ "cell_type": "code",
633
+ "execution_count": null,
634
+ "id": "daae9020",
635
+ "metadata": {},
636
+ "outputs": [
637
+ {
638
+ "data": {
639
+ "application/vnd.jupyter.widget-view+json": {
640
+ "model_id": "f7ab6fe3b99546f49acb0d43888b7ceb",
641
+ "version_major": 2,
642
+ "version_minor": 0
643
+ },
644
+ "text/plain": [
645
+ " 0%| | 0/169865 [00:00<?, ?it/s]"
646
+ ]
647
+ },
648
+ "metadata": {},
649
+ "output_type": "display_data"
650
+ }
651
+ ],
652
+ "source": [
653
+ "for e in range(EPOCHS):\n",
654
+ " epoch_step(e+1, opt)"
655
+ ]
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "execution_count": 45,
660
+ "id": "87988cf5",
661
+ "metadata": {},
662
+ "outputs": [
663
+ {
664
+ "name": "stdout",
665
+ "output_type": "stream",
666
+ "text": [
667
+ " total used free shared buff/cache available\r\n",
668
+ "Mem: 85Gi 1.5Gi 72Gi 8.0Mi 11Gi 83Gi\r\n",
669
+ "Swap: 0B 0B 0B\r\n"
670
+ ]
671
+ }
672
+ ],
673
+ "source": [
674
+ "!free -h"
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "execution_count": 65,
680
+ "id": "e43eb9f3",
681
+ "metadata": {},
682
+ "outputs": [
683
+ {
684
+ "name": "stdout",
685
+ "output_type": "stream",
686
+ "text": [
687
+ "Fri Jul 7 17:44:05 2023 \n",
688
+ "+-----------------------------------------------------------------------------+\n",
689
+ "| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |\n",
690
+ "|-------------------------------+----------------------+----------------------+\n",
691
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
692
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
693
+ "| | | MIG M. |\n",
694
+ "|===============================+======================+======================|\n",
695
+ "| 0 NVIDIA A100-PCI... On | 00000000:05:00.0 Off | 0 |\n",
696
+ "| N/A 30C P0 34W / 250W | 5739MiB / 40960MiB | 0% Default |\n",
697
+ "| | | Disabled |\n",
698
+ "+-------------------------------+----------------------+----------------------+\n",
699
+ " \n",
700
+ "+-----------------------------------------------------------------------------+\n",
701
+ "| Processes: |\n",
702
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
703
+ "| ID ID Usage |\n",
704
+ "|=============================================================================|\n",
705
+ "| 0 N/A N/A 13768 C /opt/conda/bin/python 5736MiB |\n",
706
+ "+-----------------------------------------------------------------------------+\n"
707
+ ]
708
+ }
709
+ ],
710
+ "source": [
711
+ "!nvidia-smi"
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "code",
716
+ "execution_count": 73,
717
+ "id": "0351f57f",
718
+ "metadata": {},
719
+ "outputs": [
720
+ {
721
+ "data": {
722
+ "text/plain": [
723
+ "Parameter containing:\n",
724
+ "tensor([[ 8.3618e-03, 3.8330e-02, -5.9204e-03, ..., 2.0752e-02,\n",
725
+ " 4.4861e-03, 1.2512e-02],\n",
726
+ " [ 3.9978e-03, 2.1118e-02, -3.5645e-02, ..., -1.6846e-02,\n",
727
+ " 5.0659e-03, -3.8818e-02],\n",
728
+ " [-1.6928e-05, -1.2756e-02, -1.1536e-02, ..., -1.6235e-02,\n",
729
+ " 4.8218e-03, -1.4099e-02],\n",
730
+ " ...,\n",
731
+ " [-9.8267e-03, -6.8665e-03, 1.0864e-02, ..., -1.0864e-02,\n",
732
+ " -2.4170e-02, -5.6076e-04],\n",
733
+ " [-9.5749e-04, 7.3853e-03, 4.9438e-03, ..., 1.2390e-02,\n",
734
+ " -2.1606e-02, -9.2163e-03],\n",
735
+ " [ 5.1758e-02, 2.1484e-02, -1.5381e-02, ..., -2.4292e-02,\n",
736
+ " -3.4912e-02, 3.0823e-03]], device='cuda:0', dtype=torch.bfloat16,\n",
737
+ " requires_grad=True)"
738
+ ]
739
+ },
740
+ "execution_count": 73,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": []
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "execution_count": null,
750
+ "id": "ace72db5",
751
+ "metadata": {},
752
+ "outputs": [],
753
+ "source": []
754
+ }
755
+ ],
756
+ "metadata": {
757
+ "kernelspec": {
758
+ "display_name": "Python 3 (ipykernel)",
759
+ "language": "python",
760
+ "name": "python3"
761
+ },
762
+ "language_info": {
763
+ "codemirror_mode": {
764
+ "name": "ipython",
765
+ "version": 3
766
+ },
767
+ "file_extension": ".py",
768
+ "mimetype": "text/x-python",
769
+ "name": "python",
770
+ "nbconvert_exporter": "python",
771
+ "pygments_lexer": "ipython3",
772
+ "version": "3.10.10"
773
+ }
774
+ },
775
+ "nbformat": 4,
776
+ "nbformat_minor": 5
777
+ }
valid.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import sys
4
+ import os
5
+ from pathlib import Path
6
+ from tqdm.auto import tqdm
7
+
8
+ model_id = os.getcwd()
9
+ if len(sys.argv) == 2:
10
+ filename = sys.argv[1]
11
+ elif len(sys.argv) == 3:
12
+ filename = sys.argv[1]
13
+ model_id = sys.argv[2]
14
+ else:
15
+ raise Exception("use valid.py <path-to-text> [model-id]")
16
+
17
+ text = Path(filename).read_text()
18
+ stories = text.split("<|endoftext|>")
19
+ print(len(stories))
20
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
21
+ model = AutoModelForCausalLM.from_pretrained(model_id).cuda().bfloat16()
22
+
23
+ ctx_size = tokenizer.model_max_length
24
+ sliding_window = ctx_size // 2
25
+
26
+ total_loss = 0.0
27
+ measurements = 0
28
+ model.eval()
29
+ for story in (bar := tqdm(stories)):
30
+ story = story.strip()
31
+ tokens = tokenizer(story, add_special_tokens=False).input_ids + [tokenizer.eos_token_id]
32
+ i = 0
33
+ while i < len(tokens):
34
+ current_window = tokens[i:i+ctx_size-1]
35
+ part_tokens = [tokenizer.bos_token_id] + current_window
36
+ input_ids = torch.tensor(part_tokens, device="cuda")[None]
37
+ labels = input_ids.clone()
38
+ if i:
39
+ # disable seen tokens
40
+ labels[:, :-sliding_window] = -100
41
+
42
+ with torch.no_grad():
43
+ loss = model(input_ids, labels=labels).loss
44
+ total_loss += loss.item()
45
+ measurements += 1
46
+
47
+ i += len(current_window)
48
+ bar.set_description(f"L {total_loss/measurements:.4f}")
49
+
50
+ print(f"FINAL LOSS: {total_loss/measurements:.4f}")
51
+
52
+