File size: 11,345 Bytes
8691e45
 
 
 
 
 
 
 
 
 
 
 
 
 
eb4a450
 
8691e45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb4a450
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af040d11-fa07-4a2c-80a6-33c1ce556eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Core deps: transformers/torch for the model, gradio for the web UI,\n",
    "# datetime/json for per-request logging.\n",
    "import transformers, re, os\n",
    "import gradio as gr\n",
    "import torch\n",
    "import datetime, json\n",
    "\n",
    "# Every generation request gets logged to its own timestamped file here.\n",
    "os.makedirs(\"limerick_logs\",exist_ok=True)\n",
    "\n",
    "# Hugging Face hub id of GPT-J-6B finetuned to write limericks.\n",
    "model_name = \"pcalhoun/gpt-j-6b-limericks-finetuned\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1d83bf7-f84d-4027-b38f-c4d019b241dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load tokenizer and model once at startup. fp16 halves memory for the\n",
    "# 6B-parameter model; device_map=\"auto\" lets transformers/accelerate place\n",
    "# the layers on the available device(s).\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n",
    "gpt = transformers.GPTJForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16,device_map=\"auto\")\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028dfd38-0398-47f0-8541-bcb2edd93ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short tags naming the three prompt formats the finetuned model knows.\n",
    "tag = {\n",
    "    \"topic_to_limerick\": \"T2L\",\n",
    "    \"topic_to_rhymes\": \"T2R\",\n",
    "    \"rhymes_to_limerick\":\"R2L\",\n",
    "}\n",
    "\n",
    "class EndCriteria(transformers.StoppingCriteria):\n",
    "  \"\"\"Stop generation once every sequence in the batch has produced one of\n",
    "  the eof_strings in its newly generated text (past start_length tokens).\"\"\"\n",
    "  def __init__(self, start_length, eof_strings, tokenizer):\n",
    "    self.start_length = start_length\n",
    "    self.eof_strings = eof_strings\n",
    "    self.tokenizer = tokenizer\n",
    "  def __call__(self, input_ids, scores, **kwargs):\n",
    "    # Only inspect tokens generated after the prompt.\n",
    "    decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])\n",
    "    return all(\n",
    "        any(stop_string in decoded for stop_string in self.eof_strings)\n",
    "        for decoded in decoded_generations)\n",
    "\n",
    "def logall(*args,**kwargs):\n",
    "    \"\"\"Dump this call's args/kwargs to a timestamped JSON file in limerick_logs/.\"\"\"\n",
    "    path = \"limerick_logs/\"+datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+\".txt\"\n",
    "    # Context manager so the log file handle is actually closed\n",
    "    # (the previous open(...).write(...) leaked the handle).\n",
    "    with open(path, \"w\") as f:\n",
    "        f.write(json.dumps({\"args\":args,\"kwargs\":kwargs},indent=4))\n",
    "\n",
    "def qry(string,\n",
    "        result_length = 400,\n",
    "        top_k=50,\n",
    "        top_p=0.98,\n",
    "        temperature=0.7,\n",
    "        num_return_sequences=14,\n",
    "        do_sample=True,\n",
    "        pad_token_id=50256,\n",
    "        stop_tokens = (\">\",),  # tuple, not list: avoids a mutable default\n",
    "        **extra_kwargs):\n",
    "    \"\"\"Sample num_return_sequences continuations of `string`, stopping once\n",
    "    every sequence contains one of stop_tokens; returns the decoded texts\n",
    "    (prompt included) and logs prompt + outputs via logall.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        inputs = tokenizer(string, return_tensors=\"pt\")\n",
    "        # Prompt length in tokens (sequence dim). len(inputs[\"input_ids\"]) was\n",
    "        # the batch size (1), which made EndCriteria scan the prompt text too.\n",
    "        input_size = inputs[\"input_ids\"].shape[1]\n",
    "        inputs = inputs.to('cuda')\n",
    "        beam_outputs = gpt.generate(inputs[\"input_ids\"], \n",
    "                        stopping_criteria=transformers.StoppingCriteriaList([EndCriteria(\n",
    "                                input_size,\n",
    "                                stop_tokens,\n",
    "                                tokenizer,\n",
    "                            ),]), max_length=result_length,\n",
    "                        top_k=top_k, top_p=top_p, do_sample=do_sample,\n",
    "                        temperature=temperature, pad_token_id=pad_token_id,\n",
    "                        num_return_sequences=num_return_sequences,**extra_kwargs)\n",
    "        ret_texts = [tokenizer.decode(beam_output, skip_special_tokens=True)\n",
    "                     for beam_output in beam_outputs]\n",
    "    logall(string,*ret_texts)\n",
    "    return ret_texts\n",
    "\n",
    "def get_rhymes_strings(topic):\n",
    "    \"\"\"Ask the model for candidate AABBA end-of-line rhyme word sets for `topic`.\"\"\"\n",
    "    return [n.split(\"=\")[-1].split(\">\")[0] for n in qry( \"<\" + topic + \" =\" + tag[\"topic_to_rhymes\"] + \"=\" )]\n",
    "\n",
    "def get_next_limeric_strings(in_string):\n",
    "    \"\"\"Generate candidates for the next limerick line after prompt `in_string`.\"\"\"\n",
    "    return_strings = qry(in_string,top_p=0.95)\n",
    "    for n in range(len(return_strings)):\n",
    "        # Drop the end marker, the echoed prompt, then keep only the next line.\n",
    "        return_strings[n] = return_strings[n].split(\">\")[0]\n",
    "        return_strings[n] = return_strings[n][len(in_string):].split(\"/\")[0].strip()\n",
    "    return return_strings\n",
    "\n",
    "def get_next_options(topic,rhymes=None,lines=None):\n",
    "    \"\"\"First call (no rhymes yet): return rhyme-set options for `topic`.\n",
    "    Later calls: return next-line options given chosen rhymes and prior lines.\"\"\"\n",
    "    lines = lines or []  # None sentinel avoids a shared mutable default\n",
    "    if not rhymes:\n",
    "        return get_rhymes_strings(topic)\n",
    "    assert len(lines) < 5\n",
    "    prompt = \"<\" + topic + \": \" + rhymes + \" =\"+tag[\"rhymes_to_limerick\"]+\"= \"\n",
    "    if lines:\n",
    "        prompt+=\" / \".join(lines) + \" / \"\n",
    "    print(prompt)\n",
    "    return get_next_limeric_strings(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fb525ce-faa6-4af7-853c-ac1129f6ee27",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_next(*args):\n",
    "    \"\"\"Shared Gradio callback for every button in the wizard.\n",
    "\n",
    "    args is (topic,) followed by alternating (options_text, choice_number)\n",
    "    pairs: one pair for the rhyme scheme, then one pair per chosen line.\n",
    "    With all 13 inputs present it returns the finished five-line limerick;\n",
    "    otherwise it returns a numbered list of options for the next step.\n",
    "    \"\"\"\n",
    "    args = list(args)  # plain copy so pop() below consumes our own list\n",
    "    is_final = len(args) == 13  # topic + rhyme pair + five line pairs\n",
    "    topic = args.pop(0)\n",
    "    rhymes = None\n",
    "    lines = []\n",
    "    if args:\n",
    "        rhymes_string = args.pop(0).strip()\n",
    "        rhyme_choice = args.pop(0)\n",
    "        # Options boxes hold \"0: ...\" rows; keep the chosen row's payload.\n",
    "        rhymes = rhymes_string.split(\"\\n\")[int(rhyme_choice)].split(\":\")[-1].strip()\n",
    "    while args:\n",
    "        lines_string = args.pop(0).strip()\n",
    "        line_choice = args.pop(0)\n",
    "        lines.append(lines_string.split(\"\\n\")[int(line_choice)].split(\":\")[-1].strip())\n",
    "    if is_final:\n",
    "        return \"\\n\".join(lines)\n",
    "    return \"\\n\".join(str(n)+\": \" + r for n,r in enumerate(get_next_options(topic,rhymes=rhymes,lines=lines)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "407602e4-6371-400c-b816-74730efae664",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradio UI: a linear wizard. Each step is (button -> options textbox ->\n",
    "# number selector); every button calls gen_next with all inputs gathered so\n",
    "# far, so component creation order below defines both layout and protocol.\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"\"\"To generate a limerick, start by entering a topic. Click the first button to generate a list of words following the AABBA limerick format.  \n",
    "The first numerical input is for choosing which rhyme set to use. Subsequent inputs are for choosing the next line.\n",
    "\n",
    "<table><tr><th>Some examples:</th></tr>\n",
    "<tr><td><p style=\"font-size:12px;\">Scrooge Jeff Bezos, a man with a greed  \n",
    "<br>For control, took the Amazon's seed  \n",
    "<br>And with profit the gold  \n",
    "<br>That it brought, now he's old,  \n",
    "<br>Has billions, and no real need.</p></td>\n",
    "<td><p style=\"font-size:12px;\">When you're asking the Greeks, \"So what's Greek?\"  \n",
    "<br>You're asking for mythological speak.  \n",
    "<br>You may hear some tale  \n",
    "<br>Of Achilles and fail  \n",
    "<br>If you're not in the know what to seek.</p></td>\n",
    "<td><p style=\"font-size:12px;\">On your index cards, write down your need,  \n",
    "<br>And arrange them in order of speed.  \n",
    "<br>When you're done, you'll recall  \n",
    "<br>Which one's quicker than all,  \n",
    "<br>And you'll know which is best, if indeed.</p></td></tr></table>\"\"\")\n",
    "    # Step 1: topic entry and rhyme-scheme generation (gen_next with 1 input).\n",
    "    topic_input = gr.Textbox(label=\"Limerick topic (two to three words)\")\n",
    "    gen_rhymes_button = gr.Button(\"Generate end of line rhyming scheme options\")\n",
    "    rhyme_output = gr.Textbox(lines=3,label=\"Rhyme options\")\n",
    "    gen_rhymes_button.click(\n",
    "        gen_next, inputs=topic_input, outputs=rhyme_output\n",
    "    )\n",
    "    # Steps 2-6 repeat the same pattern: pick the previous options row by\n",
    "    # number, then generate candidates for the next line.\n",
    "    rhyme_input = gr.Number(label=\"Selection number for rhyme line\")\n",
    "    gen_line_one_button = gr.Button(\"Generate line one options\")\n",
    "    line_one_output = gr.Textbox(lines=3,label=\"Line one options\")\n",
    "    gen_line_one_button.click(\n",
    "        gen_next,\n",
    "        inputs=[topic_input,rhyme_output,rhyme_input],\n",
    "        outputs=line_one_output,\n",
    "    )\n",
    "    line_one_input = gr.Number(label=\"Selection number of line one\")\n",
    "    gen_line_two_button = gr.Button(\"Generate line two options\")\n",
    "    line_two_output = gr.Textbox(lines=3,label=\"Line two options\")\n",
    "    gen_line_two_button.click(\n",
    "        gen_next,\n",
    "        inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n",
    "                line_one_input],\n",
    "        outputs=line_two_output,\n",
    "    )\n",
    "    line_two_input = gr.Number(label=\"Selection number of line two\")\n",
    "    gen_line_three_button = gr.Button(\"Generate line three options\")\n",
    "    line_three_output = gr.Textbox(lines=3,label=\"Line three options\")\n",
    "    gen_line_three_button.click(\n",
    "        gen_next,\n",
    "        inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n",
    "                line_one_input,line_two_output,line_two_input],\n",
    "        outputs=line_three_output,\n",
    "    )\n",
    "    line_three_input = gr.Number(label=\"Selection number of line three\")\n",
    "    gen_line_four_button = gr.Button(\"Generate line four options\")\n",
    "    line_four_output = gr.Textbox(lines=3,label=\"Line four options\")\n",
    "    gen_line_four_button.click(\n",
    "        gen_next,\n",
    "        inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n",
    "                line_one_input,line_two_output,line_two_input,\n",
    "                line_three_output,line_three_input],\n",
    "        outputs=line_four_output,\n",
    "    )\n",
    "    line_four_input = gr.Number(label=\"Selection number of line four\")\n",
    "    gen_line_five_button = gr.Button(\"Generate line five options\")\n",
    "    line_five_output = gr.Textbox(lines=3,label=\"Line five options\")\n",
    "    gen_line_five_button.click(\n",
    "        gen_next,\n",
    "        inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n",
    "                line_one_input,line_two_output,line_two_input,\n",
    "                line_three_output,line_three_input,line_four_output,\n",
    "                line_four_input],\n",
    "        outputs=line_five_output,\n",
    "    )\n",
    "    line_five_input = gr.Number(label=\"Selection number of line five\")\n",
    "    # Final step: with all 13 inputs, gen_next returns the joined limerick.\n",
    "    final_button = gr.Button(\"Display full limerick\")\n",
    "    final_output = gr.Textbox(lines=5)\n",
    "    final_button.click(\n",
    "        gen_next,\n",
    "        inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n",
    "                line_one_input,line_two_output,line_two_input,\n",
    "                line_three_output,line_three_input,line_four_output,\n",
    "                line_four_input,line_five_output,line_five_input],\n",
    "        outputs=final_output,\n",
    "    )\n",
    "    gr.Markdown(\"\"\"(This is a remote-accessible demo of Robert A. Gonsalves' DeepLimericks. The model in use is pretrained GPT-J-6B that was finetuned on two RTX 3090s for the purpose of warming my pantry for a week.)\"\"\".strip())\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}