{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "af040d11-fa07-4a2c-80a6-33c1ce556eb0", "metadata": {}, "outputs": [], "source": [ "import transformers, re, os\n", "import gradio as gr\n", "import torch\n", "import datetime, json\n", "\n", "os.makedirs(\"limerick_logs\",exist_ok=True)\n", "\n", "model_name = \"pcalhoun/gpt-j-6b-limericks-finetuned\"" ] }, { "cell_type": "code", "execution_count": null, "id": "d1d83bf7-f84d-4027-b38f-c4d019b241dd", "metadata": {}, "outputs": [], "source": [ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n", "gpt = transformers.GPTJForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16,device_map=\"auto\")\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "id": "028dfd38-0398-47f0-8541-bcb2edd93ac7", "metadata": {}, "outputs": [], "source": [ "tag = {\n", " \"topic_to_limerick\": \"T2L\",\n", " \"topic_to_rhymes\": \"T2R\",\n", " \"rhymes_to_limerick\":\"R2L\",\n", "}\n", "\n", "class EndCriteria(transformers.StoppingCriteria):\n", " def __init__(self, start_length, eof_strings, tokenizer):\n", " self.start_length = start_length\n", " self.eof_strings = eof_strings\n", " self.tokenizer = tokenizer\n", " def __call__(self, input_ids, scores, **kwargs):\n", " decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])\n", " done = []\n", " for decoded_generation in decoded_generations:\n", " done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))\n", " return all(done)\n", "\n", "def logall(*args,**kwargs):\n", " open(\n", " \"limerick_logs/\"+datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+\".txt\",\n", " \"w\"\n", " ).write(json.dumps({\"args\":args,\"kwargs\":kwargs},indent=4))\n", "\n", "def qry(string,\n", " result_length = 400,\n", " top_k=50,\n", " top_p=0.98,\n", " temperature=0.7,\n", " num_return_sequences=14,\n", " do_sample=True,\n", " pad_token_id=50256,\n", " stop_tokens = [\">\"],\n", " **extra_kwargs):\n", " with torch.no_grad():\n", " inputs = tokenizer(string, return_tensors=\"pt\")\n", " input_size = len(inputs[\"input_ids\"])\n", " inputs = inputs.to('cuda')\n", " beam_outputs = gpt.generate(inputs[\"input_ids\"], \n", " stopping_criteria=transformers.StoppingCriteriaList([EndCriteria(\n", " input_size,\n", " stop_tokens,\n", " tokenizer,\n", " ),]), max_length=result_length,\n", " top_k=top_k, top_p=top_p, do_sample=do_sample,\n", " temperature=temperature, pad_token_id=pad_token_id,\n", " num_return_sequences=num_return_sequences,**extra_kwargs)\n", " ret_texts = []\n", " for beam_output in beam_outputs:\n", " ret_texts.append(tokenizer.decode(beam_output, skip_special_tokens=True))\n", " logall(string,*ret_texts)\n", " return ret_texts\n", "\n", "def get_rhymes_strings(topic):\n", " return [n.split(\"=\")[-1].split(\">\")[0] for n in qry( \"<\" + topic + \" =\" + tag[\"topic_to_rhymes\"] + \"=\" )]\n", "\n", "def get_next_limeric_strings(in_string):\n", " return_strings = qry(in_string,top_p=0.95)\n", " for n in range(len(return_strings)):\n", " return_strings[n] = return_strings[n].split(\">\")[0]\n", " return_strings[n] = return_strings[n][len(in_string):].split(\"/\")[0].strip()\n", " return return_strings\n", " \n", "def get_next_options(topic,rhymes=None,lines=[]):\n", " if not rhymes:\n", " return get_rhymes_strings(topic)\n", " assert len(lines) < 5\n", " prompt = \"<\" + topic + \": \" + rhymes + \" =\"+tag[\"rhymes_to_limerick\"]+\"= \"\n", " if len(lines):\n", " prompt+=\" / \".join(lines) + \" / \"\n", " print(prompt)\n", " return get_next_limeric_strings(prompt)\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "5fb525ce-faa6-4af7-853c-ac1129f6ee27", "metadata": {}, "outputs": [], "source": [ "def gen_next(*args):\n", " args=[a for a in args]\n", " is_final = False\n", " if len(args) == 13:\n", " is_final = True\n", " topic = args.pop(0)\n", " rhymes = None\n", " lines = []\n", " if len(args):\n", " rhymes_string = args.pop(0).strip()\n", " rhyme_choice = args.pop(0)\n", " rhymes = rhymes_string.split(\"\\n\")[int(rhyme_choice)].split(\":\")[-1].strip()\n", " while len(args):\n", " lines_string = args.pop(0).strip()\n", " line_choice = args.pop(0)\n", " lines += [lines_string.split(\"\\n\")[int(line_choice)].split(\":\")[-1].strip(),]\n", " if is_final:\n", " return \"\\n\".join(lines)\n", " return \"\\n\".join([str(n)+\": \" + r for n,r in enumerate(get_next_options(topic,rhymes=rhymes,lines=lines))])" ] }, { "cell_type": "code", "execution_count": null, "id": "407602e4-6371-400c-b816-74730efae664", "metadata": {}, "outputs": [], "source": [ "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"To generate a limerick, start by entering a topic. Click the first button to generate a list of words following the AABBA limerick format. \n", "The first numerical input is for choosing which rhyme set to use. Subsequent inputs are for choosing the next line.\n", "\n", "\n", "\n", "\n", "
Some examples:

Scrooge Jeff Bezos, a man with a greed \n", "
For control, took the Amazon's seed \n", "
And with profit the gold \n", "
That it brought, now he's old, \n", "
Has billions, and no real need.

When you're asking the Greeks, \"So what's Greek?\" \n", "
You're asking for mythological speak. \n", "
You may hear some tale \n", "
Of Achilles and fail \n", "
If you're not in the know what to seek.

On your index cards, write down your need, \n", "
And arrange them in order of speed. \n", "
When you're done, you'll recall \n", "
Which one's quicker than all, \n", "
And you'll know which is best, if indeed.

\"\"\")\n", " topic_input = gr.Textbox(label=\"Limerick topic (two to three words)\")\n", " gen_rhymes_button = gr.Button(\"Generate end of line rhyming scheme options\")\n", " rhyme_output = gr.Textbox(lines=3,label=\"Rhyme options\")\n", " gen_rhymes_button.click(\n", " gen_next, inputs=topic_input, outputs=rhyme_output\n", " )\n", " rhyme_input = gr.Number(label=\"Selection number for rhyme line\")\n", " gen_line_one_button = gr.Button(\"Generate line one options\")\n", " line_one_output = gr.Textbox(lines=3,label=\"Line one options\")\n", " gen_line_one_button.click(\n", " gen_next,\n", " inputs=[topic_input,rhyme_output,rhyme_input],\n", " outputs=line_one_output,\n", " )\n", " line_one_input = gr.Number(label=\"Selection number of line one\")\n", " gen_line_two_button = gr.Button(\"Generate line two options\")\n", " line_two_output = gr.Textbox(lines=3,label=\"Line two options\")\n", " gen_line_two_button.click(\n", " gen_next,\n", " inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n", " line_one_input],\n", " outputs=line_two_output,\n", " )\n", " line_two_input = gr.Number(label=\"Selection number of line two\")\n", " gen_line_three_button = gr.Button(\"Generate line three options\")\n", " line_three_output = gr.Textbox(lines=3,label=\"Line three options\")\n", " gen_line_three_button.click(\n", " gen_next,\n", " inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n", " line_one_input,line_two_output,line_two_input],\n", " outputs=line_three_output,\n", " )\n", " line_three_input = gr.Number(label=\"Selection number of line three\")\n", " gen_line_four_button = gr.Button(\"Generate line four options\")\n", " line_four_output = gr.Textbox(lines=3,label=\"Line four options\")\n", " gen_line_four_button.click(\n", " gen_next,\n", " inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n", " line_one_input,line_two_output,line_two_input,\n", " line_three_output,line_three_input],\n", " outputs=line_four_output,\n", " )\n", " line_four_input = gr.Number(label=\"Selection number of line four\")\n", " gen_line_five_button = gr.Button(\"Generate line five options\")\n", " line_five_output = gr.Textbox(lines=3,label=\"Line five options\")\n", " gen_line_five_button.click(\n", " gen_next,\n", " inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n", " line_one_input,line_two_output,line_two_input,\n", " line_three_output,line_three_input,line_four_output,\n", " line_four_input],\n", " outputs=line_five_output,\n", " )\n", " line_five_input = gr.Number(label=\"Selection number of line five\")\n", " final_button = gr.Button(\"Display full limerick\")\n", " final_output = gr.Textbox(lines=5)\n", " final_button.click(\n", " gen_next,\n", " inputs=[topic_input,rhyme_output,rhyme_input,line_one_output,\n", " line_one_input,line_two_output,line_two_input,\n", " line_three_output,line_three_input,line_four_output,\n", " line_four_input,line_five_output,line_five_input],\n", " outputs=final_output,\n", " )\n", " gr.Markdown(\"\"\"(This is a remote-accessible demo of Robert A. Gonsalves' DeepLimericks. The model in use is pretrained GPT-J-6B that was finetuned on two RTX 3090s for the purpose of warming my pantry for a week.)\"\"\".strip())\n", "demo.launch()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }