File size: 11,345 Bytes
8691e45 eb4a450 8691e45 eb4a450 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "af040d11-fa07-4a2c-80a6-33c1ce556eb0",
"metadata": {},
"outputs": [],
"source": [
"import transformers, re, os\n",
"import gradio as gr\n",
"import torch\n",
"import datetime, json\n",
"\n",
"os.makedirs(\"limerick_logs\",exist_ok=True)\n",
"\n",
"model_name = \"pcalhoun/gpt-j-6b-limericks-finetuned\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1d83bf7-f84d-4027-b38f-c4d019b241dd",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n",
"gpt = transformers.GPTJForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16,device_map=\"auto\")\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "028dfd38-0398-47f0-8541-bcb2edd93ac7",
"metadata": {},
"outputs": [],
"source": [
"# Prompt tags understood by the finetuned model: each value selects a task mode\n",
"# the model was trained on (topic->limerick, topic->rhymes, rhymes->limerick).\n",
"tag = {\n",
"    \"topic_to_limerick\": \"T2L\",\n",
"    \"topic_to_rhymes\": \"T2R\",\n",
"    \"rhymes_to_limerick\": \"R2L\",\n",
"}\n",
"\n",
"class EndCriteria(transformers.StoppingCriteria):\n",
"    \"\"\"Stop generation once every sequence in the batch contains a stop string.\n",
"\n",
"    Only text generated after the first `start_length` tokens (the prompt) is\n",
"    scanned, so stop strings occurring inside the prompt are ignored.\n",
"    \"\"\"\n",
"    def __init__(self, start_length, eof_strings, tokenizer):\n",
"        self.start_length = start_length  # number of prompt tokens to skip\n",
"        self.eof_strings = eof_strings    # e.g. [\">\"] -- the model's end-of-sample marker\n",
"        self.tokenizer = tokenizer\n",
"    def __call__(self, input_ids, scores, **kwargs):\n",
"        decoded = self.tokenizer.batch_decode(input_ids[:, self.start_length:])\n",
"        return all(\n",
"            any(stop in text for stop in self.eof_strings) for text in decoded\n",
"        )\n",
"\n",
"def logall(*args, **kwargs):\n",
"    \"\"\"Write a timestamped JSON log of one query and its results.\"\"\"\n",
"    path = \"limerick_logs/\" + datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + \".txt\"\n",
"    # Context manager closes the handle; the original leaked the open file object.\n",
"    with open(path, \"w\") as f:\n",
"        f.write(json.dumps({\"args\": args, \"kwargs\": kwargs}, indent=4))\n",
"\n",
"def qry(string,\n",
"        result_length=400,\n",
"        top_k=50,\n",
"        top_p=0.98,\n",
"        temperature=0.7,\n",
"        num_return_sequences=14,\n",
"        do_sample=True,\n",
"        pad_token_id=50256,\n",
"        stop_tokens=[\">\"],\n",
"        **extra_kwargs):\n",
"    \"\"\"Sample `num_return_sequences` completions of `string` from the model.\n",
"\n",
"    Generation stops early once every sample contains one of `stop_tokens`.\n",
"    Every query and its results are logged via `logall`.\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        inputs = tokenizer(string, return_tensors=\"pt\")\n",
"        # Number of prompt tokens: input_ids is shaped [batch, seq_len].\n",
"        # The original used len(inputs[\"input_ids\"]), i.e. the batch size (1),\n",
"        # which made EndCriteria scan the prompt text as well as the completion.\n",
"        input_size = inputs[\"input_ids\"].shape[1]\n",
"        inputs = inputs.to('cuda')\n",
"        beam_outputs = gpt.generate(inputs[\"input_ids\"],\n",
"            stopping_criteria=transformers.StoppingCriteriaList([EndCriteria(\n",
"                input_size,\n",
"                stop_tokens,\n",
"                tokenizer,\n",
"            )]), max_length=result_length,\n",
"            top_k=top_k, top_p=top_p, do_sample=do_sample,\n",
"            temperature=temperature, pad_token_id=pad_token_id,\n",
"            num_return_sequences=num_return_sequences, **extra_kwargs)\n",
"    ret_texts = [tokenizer.decode(out, skip_special_tokens=True) for out in beam_outputs]\n",
"    logall(string, *ret_texts)\n",
"    return ret_texts\n",
"\n",
"def get_rhymes_strings(topic):\n",
"    \"\"\"Ask the model for candidate AABBA end-rhyme word sets for `topic`.\"\"\"\n",
"    return [n.split(\"=\")[-1].split(\">\")[0] for n in qry(\"<\" + topic + \" =\" + tag[\"topic_to_rhymes\"] + \"=\")]\n",
"\n",
"def get_next_limeric_strings(in_string):\n",
"    \"\"\"Generate continuations of a partial limerick and keep only the next line.\"\"\"\n",
"    return_strings = qry(in_string, top_p=0.95)\n",
"    for n in range(len(return_strings)):\n",
"        # Drop everything after the end marker, then the echoed prompt,\n",
"        # then anything past the next line separator \"/\".\n",
"        return_strings[n] = return_strings[n].split(\">\")[0]\n",
"        return_strings[n] = return_strings[n][len(in_string):].split(\"/\")[0].strip()\n",
"    return return_strings\n",
"\n",
"def get_next_options(topic, rhymes=None, lines=None):\n",
"    \"\"\"Return rhyme-set options when no rhymes are chosen yet, else next-line options.\n",
"\n",
"    `lines` now defaults to None rather than a shared mutable list (the classic\n",
"    mutable-default-argument pitfall); passing [] behaves identically.\n",
"    \"\"\"\n",
"    if not rhymes:\n",
"        return get_rhymes_strings(topic)\n",
"    lines = lines or []\n",
"    assert len(lines) < 5, \"a limerick only has five lines\"\n",
"    prompt = \"<\" + topic + \": \" + rhymes + \" =\" + tag[\"rhymes_to_limerick\"] + \"= \"\n",
"    if lines:\n",
"        prompt += \" / \".join(lines) + \" / \"\n",
"    print(prompt)\n",
"    return get_next_limeric_strings(prompt)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fb525ce-faa6-4af7-853c-ac1129f6ee27",
"metadata": {},
"outputs": [],
"source": [
"def gen_next(*args):\n",
" args=[a for a in args]\n",
" is_final = False\n",
" if len(args) == 13:\n",
" is_final = True\n",
" topic = args.pop(0)\n",
" rhymes = None\n",
" lines = []\n",
" if len(args):\n",
" rhymes_string = args.pop(0).strip()\n",
" rhyme_choice = args.pop(0)\n",
" rhymes = rhymes_string.split(\"\\n\")[int(rhyme_choice)].split(\":\")[-1].strip()\n",
" while len(args):\n",
" lines_string = args.pop(0).strip()\n",
" line_choice = args.pop(0)\n",
" lines += [lines_string.split(\"\\n\")[int(line_choice)].split(\":\")[-1].strip(),]\n",
" if is_final:\n",
" return \"\\n\".join(lines)\n",
" return \"\\n\".join([str(n)+\": \" + r for n,r in enumerate(get_next_options(topic,rhymes=rhymes,lines=lines))])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "407602e4-6371-400c-b816-74730efae664",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks() as demo:\n",
"    # Instructional header plus three example limericks rendered as an HTML table.\n",
"    gr.Markdown(\"\"\"To generate a limerick, start by entering a topic. Click the first button to generate a list of words following the AABBA limerick format. \n",
"The first numerical input is for choosing which rhyme set to use. Subsequent inputs are for choosing the next line.\n",
"\n",
"<table><tr><th>Some examples:</th></tr>\n",
"<tr><td><p style=\"font-size:12px;\">Scrooge Jeff Bezos, a man with a greed \n",
"<br>For control, took the Amazon's seed \n",
"<br>And with profit the gold \n",
"<br>That it brought, now he's old, \n",
"<br>Has billions, and no real need.</p></td>\n",
"<td><p style=\"font-size:12px;\">When you're asking the Greeks, \"So what's Greek?\" \n",
"<br>You're asking for mythological speak. \n",
"<br>You may hear some tale \n",
"<br>Of Achilles and fail \n",
"<br>If you're not in the know what to seek.</p></td>\n",
"<td><p style=\"font-size:12px;\">On your index cards, write down your need, \n",
"<br>And arrange them in order of speed. \n",
"<br>When you're done, you'll recall \n",
"<br>Which one's quicker than all, \n",
"<br>And you'll know which is best, if indeed.</p></td></tr></table>\"\"\")\n",
"\n",
"    # Step 1: topic in, candidate AABBA rhyme sets out.\n",
"    topic_input = gr.Textbox(label=\"Limerick topic (two to three words)\")\n",
"    gen_rhymes_button = gr.Button(\"Generate end of line rhyming scheme options\")\n",
"    rhyme_output = gr.Textbox(lines=3, label=\"Rhyme options\")\n",
"    gen_rhymes_button.click(gen_next, inputs=topic_input, outputs=rhyme_output)\n",
"    rhyme_input = gr.Number(label=\"Selection number for rhyme line\")\n",
"\n",
"    # Steps 2-6: each line's button feeds gen_next every earlier options box and\n",
"    # selection so the callback can rebuild the partial limerick.  The original\n",
"    # wired six near-identical blocks by hand; build them in a loop instead,\n",
"    # growing the input chain by (options box, selection) each step.\n",
"    chain = [topic_input, rhyme_output, rhyme_input]\n",
"    for ordinal in (\"one\", \"two\", \"three\", \"four\", \"five\"):\n",
"        line_button = gr.Button(\"Generate line \" + ordinal + \" options\")\n",
"        line_output = gr.Textbox(lines=3, label=\"Line \" + ordinal + \" options\")\n",
"        line_button.click(gen_next, inputs=list(chain), outputs=line_output)\n",
"        line_input = gr.Number(label=\"Selection number of line \" + ordinal)\n",
"        chain += [line_output, line_input]\n",
"\n",
"    # Final step: all 13 inputs present, so gen_next returns the finished poem.\n",
"    final_button = gr.Button(\"Display full limerick\")\n",
"    final_output = gr.Textbox(lines=5)\n",
"    final_button.click(gen_next, inputs=list(chain), outputs=final_output)\n",
"    gr.Markdown(\"\"\"(This is a remote-accessible demo of Robert A. Gonsalves' DeepLimericks. The model in use is pretrained GPT-J-6B that was finetuned on two RTX 3090s for the purpose of warming my pantry for a week.)\"\"\".strip())\n",
"demo.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |