johnowhitaker commited on
Commit
9c23187
·
verified ·
1 Parent(s): d8382f2

Upload MB_dLLM_sample.ipynb

Browse files
Files changed (1) hide show
  1. MB_dLLM_sample.ipynb +244 -0
MB_dLLM_sample.ipynb ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "metadata": {
23
+ "id": "Q4qAMMPkQhfY"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "import os, random, itertools, math, torch\n",
28
+ "from torch.utils.data import DataLoader\n",
29
+ "from transformers import (\n",
30
+ " AutoTokenizer, AutoModelForMaskedLM,\n",
31
+ " get_cosine_schedule_with_warmup\n",
32
+ ")\n",
33
+ "from torch.optim import AdamW\n",
34
+ "from datasets import load_dataset\n",
35
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "source": [
41
+ "model_id = \"johnowhitaker/modernbert-diffusion\"\n",
42
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
43
+ "SEP_ID, CLS_ID, MASK_ID = tokenizer.sep_token_id, tokenizer.cls_token_id, tokenizer.mask_token_id\n",
44
+ "model = AutoModelForMaskedLM.from_pretrained(model_id, device_map=device)\n",
45
+ "model.eval();"
46
+ ],
47
+ "metadata": {
48
+ "id": "e4kbDTS3Qo_a"
49
+ },
50
+ "execution_count": null,
51
+ "outputs": []
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "source": [
56
+ "# Single forward pass:\n",
57
+ "prompt = \"User: Which is the best programming language? \" + tokenizer.sep_token + \" Assistant:\"\n",
58
+ "prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
59
+ "ans_len = 12\n",
60
+ "ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]\n",
61
+ "with torch.no_grad():\n",
62
+ " outs = model(input_ids=torch.tensor([ids]).to(device)).logits\n",
63
+ "print(outs.shape)\n",
64
+ "out_ids = outs[0].argmax(dim=-1).tolist()\n",
65
+ "print(tokenizer.decode(out_ids))"
66
+ ],
67
+ "metadata": {
68
+ "colab": {
69
+ "base_uri": "https://localhost:8080/"
70
+ },
71
+ "id": "Y7ZwaE3IQzJT",
72
+ "outputId": "bd8a6d10-41c3-4531-d244-32094e71b1d3"
73
+ },
74
+ "execution_count": 3,
75
+ "outputs": [
76
+ {
77
+ "output_type": "stream",
78
+ "name": "stdout",
79
+ "text": [
80
+ "torch.Size([1, 28, 50368])\n",
81
+ "[CLS]User: Which is the best programming language? \n",
82
+ " Assistant: Python, Python,,,,,, is Python..[SEP]\n"
83
+ ]
84
+ }
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "source": [
90
+ "# In a loop, keeping the most confident\n",
91
+ "prompt = \"User: Which is the best programming language? \" + tokenizer.sep_token + \" Assistant:\"\n",
92
+ "prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
93
+ "ans_len = 32\n",
94
+ "ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]\n",
95
+ "for i in range(ans_len):\n",
96
+ " with torch.no_grad():\n",
97
+ " outs = model(input_ids=torch.tensor([ids]).to(device)).logits\n",
98
+ " out_probs = torch.softmax(outs[0], dim=-1)\n",
99
+ " mask_locs = (torch.tensor(ids) == MASK_ID).nonzero(as_tuple=True)[0]\n",
100
+ " new_probs = torch.zeros_like(out_probs)\n",
101
+ " new_probs[mask_locs] = out_probs[mask_locs]\n",
102
+ " max_probs, max_locs = new_probs.max(dim=-1)\n",
103
+ " max_loc = max_probs.argmax(dim=-1)\n",
104
+ " ids[max_loc] = new_probs[max_loc].argmax().item()\n",
105
+ "print(tokenizer.decode(ids))"
106
+ ],
107
+ "metadata": {
108
+ "colab": {
109
+ "base_uri": "https://localhost:8080/"
110
+ },
111
+ "id": "wadlDG2DUUjX",
112
+ "outputId": "06317b7c-7f71-4621-e0b6-c173df0839b7"
113
+ },
114
+ "execution_count": 24,
115
+ "outputs": [
116
+ {
117
+ "output_type": "stream",
118
+ "name": "stdout",
119
+ "text": [
120
+ "[CLS]User: Which is the best programming language? [SEP] Assistant:[SEP] is the best programming language?\n",
121
+ "\n",
122
+ "A: Python is the best programming language. It is simple, powerful, and has a wide range of useful features.[SEP]\n"
123
+ ]
124
+ }
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "source": [
130
+ "# Wrapping that in a function\n",
131
+ "def sample(q, ans_len=32):\n",
132
+ " prompt = f\"User: {q} \" + tokenizer.sep_token + \" Assistant:\"\n",
133
+ " prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
134
+ " ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]\n",
135
+ " for i in range(ans_len):\n",
136
+ " with torch.no_grad():\n",
137
+ " outs = model(input_ids=torch.tensor([ids]).to(device)).logits\n",
138
+ " out_probs = torch.softmax(outs[0], dim=-1)\n",
139
+ " mask_locs = (torch.tensor(ids) == MASK_ID).nonzero(as_tuple=True)[0]\n",
140
+ " new_probs = torch.zeros_like(out_probs)\n",
141
+ " new_probs[mask_locs] = out_probs[mask_locs]\n",
142
+ " max_probs, max_locs = new_probs.max(dim=-1)\n",
143
+ " max_loc = max_probs.argmax(dim=-1)\n",
144
+ " ids[max_loc] = new_probs[max_loc].argmax().item()\n",
145
+ " return tokenizer.decode(ids)"
146
+ ],
147
+ "metadata": {
148
+ "id": "FAj0rtmhYcjF"
149
+ },
150
+ "execution_count": 25,
151
+ "outputs": []
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "source": [
156
+ "sample(\"Tell me a fun fact about cows\")"
157
+ ],
158
+ "metadata": {
159
+ "colab": {
160
+ "base_uri": "https://localhost:8080/",
161
+ "height": 52
162
+ },
163
+ "id": "HAS20X0oZhw5",
164
+ "outputId": "4f157101-1652-4c25-b67e-b957512bf632"
165
+ },
166
+ "execution_count": 26,
167
+ "outputs": [
168
+ {
169
+ "output_type": "execute_result",
170
+ "data": {
171
+ "text/plain": [
172
+ "\"[CLS]User: Tell me a fun fact about cows [SEP] Assistant:[SEP], here's a fun fact about cows:\\n\\nThe fact is that cows are the most intelligent animals in the world. They can think and make decisions.[SEP]\""
173
+ ],
174
+ "application/vnd.google.colaboratory.intrinsic+json": {
175
+ "type": "string"
176
+ }
177
+ },
178
+ "metadata": {},
179
+ "execution_count": 26
180
+ }
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "source": [
186
+ "sample(\"Tell me a funny joke about lemons\")"
187
+ ],
188
+ "metadata": {
189
+ "colab": {
190
+ "base_uri": "https://localhost:8080/",
191
+ "height": 52
192
+ },
193
+ "id": "f0S3ZQLNUUnU",
194
+ "outputId": "ddfc0e47-bbb1-496b-8177-5d796b8bd9af"
195
+ },
196
+ "execution_count": 30,
197
+ "outputs": [
198
+ {
199
+ "output_type": "execute_result",
200
+ "data": {
201
+ "text/plain": [
202
+ "'[CLS]User: Tell me a funny joke about lemons [SEP] Assistant:[SEP]\\'s a funny joke about lemons: \"I have a lemonade stand, and I\\'m going to sell lemons.\"\\n Assistant: That\\'s funny.[SEP]'"
203
+ ],
204
+ "application/vnd.google.colaboratory.intrinsic+json": {
205
+ "type": "string"
206
+ }
207
+ },
208
+ "metadata": {},
209
+ "execution_count": 30
210
+ }
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "source": [
216
+ "sample(\"Which OS is best?\")"
217
+ ],
218
+ "metadata": {
219
+ "colab": {
220
+ "base_uri": "https://localhost:8080/",
221
+ "height": 52
222
+ },
223
+ "id": "KugOpLPHaQSA",
224
+ "outputId": "43767abf-5a3e-48e0-c14b-b180f7ba9a14"
225
+ },
226
+ "execution_count": 31,
227
+ "outputs": [
228
+ {
229
+ "output_type": "execute_result",
230
+ "data": {
231
+ "text/plain": [
232
+ "\"[CLS]User: Which OS is best? [SEP] Assistant:[SEP], I don't know. I haven't used them personally. I'm sure there are some that are better than others, but I can't tell you.[SEP]\""
233
+ ],
234
+ "application/vnd.google.colaboratory.intrinsic+json": {
235
+ "type": "string"
236
+ }
237
+ },
238
+ "metadata": {},
239
+ "execution_count": 31
240
+ }
241
+ ]
242
+ }
243
+ ]
244
+ }