{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "af040d11-fa07-4a2c-80a6-33c1ce556eb0", "metadata": {}, "outputs": [], "source": [ "import transformers, re, os\n", "import gradio as gr\n", "import torch\n", "import datetime, json\n", "\n", "os.makedirs(\"limerick_logs\",exist_ok=True)\n", "\n", "model_name = \"pcalhoun/gpt-j-6b-limericks-finetuned\"" ] }, { "cell_type": "code", "execution_count": null, "id": "d1d83bf7-f84d-4027-b38f-c4d019b241dd", "metadata": {}, "outputs": [], "source": [ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n", "gpt = transformers.GPTJForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16,device_map=\"auto\")\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "id": "028dfd38-0398-47f0-8541-bcb2edd93ac7", "metadata": {}, "outputs": [], "source": [ "tag = {\n", " \"topic_to_limerick\": \"T2L\",\n", " \"topic_to_rhymes\": \"T2R\",\n", " \"rhymes_to_limerick\":\"R2L\",\n", "}\n", "\n", "class EndCriteria(transformers.StoppingCriteria):\n", " def __init__(self, start_length, eof_strings, tokenizer):\n", " self.start_length = start_length\n", " self.eof_strings = eof_strings\n", " self.tokenizer = tokenizer\n", " def __call__(self, input_ids, scores, **kwargs):\n", " decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])\n", " done = []\n", " for decoded_generation in decoded_generations:\n", " done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))\n", " return all(done)\n", "\n", "def logall(*args,**kwargs):\n", " open(\n", " \"limerick_logs/\"+datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+\".txt\",\n", " \"w\"\n", " ).write(json.dumps({\"args\":args,\"kwargs\":kwargs},indent=4))\n", "\n", "def qry(string,\n", " result_length = 400,\n", " top_k=50,\n", " top_p=0.98,\n", " temperature=0.7,\n", " num_return_sequences=14,\n", " do_sample=True,\n", " pad_token_id=50256,\n", " stop_tokens = [\">\"],\n", " **extra_kwargs):\n", " with torch.no_grad():\n", " inputs = tokenizer(string, return_tensors=\"pt\")\n", " input_size = len(inputs[\"input_ids\"])\n", " inputs = inputs.to('cuda')\n", " beam_outputs = gpt.generate(inputs[\"input_ids\"], \n", " stopping_criteria=transformers.StoppingCriteriaList([EndCriteria(\n", " input_size,\n", " stop_tokens,\n", " tokenizer,\n", " ),]), max_length=result_length,\n", " top_k=top_k, top_p=top_p, do_sample=do_sample,\n", " temperature=temperature, pad_token_id=pad_token_id,\n", " num_return_sequences=num_return_sequences,**extra_kwargs)\n", " ret_texts = []\n", " for beam_output in beam_outputs:\n", " ret_texts.append(tokenizer.decode(beam_output, skip_special_tokens=True))\n", " logall(string,*ret_texts)\n", " return ret_texts\n", "\n", "def get_rhymes_strings(topic):\n", " return [n.split(\"=\")[-1].split(\">\")[0] for n in qry( \"<\" + topic + \" =\" + tag[\"topic_to_rhymes\"] + \"=\" )]\n", "\n", "def get_next_limeric_strings(in_string):\n", " return_strings = qry(in_string,top_p=0.95)\n", " for n in range(len(return_strings)):\n", " return_strings[n] = return_strings[n].split(\">\")[0]\n", " return_strings[n] = return_strings[n][len(in_string):].split(\"/\")[0].strip()\n", " return return_strings\n", " \n", "def get_next_options(topic,rhymes=None,lines=[]):\n", " if not rhymes:\n", " return get_rhymes_strings(topic)\n", " assert len(lines) < 5\n", " prompt = \"<\" + topic + \": \" + rhymes + \" =\"+tag[\"rhymes_to_limerick\"]+\"= \"\n", " if len(lines):\n", " prompt+=\" / 
\".join(lines) + \" / \"\n", " print(prompt)\n", " return get_next_limeric_strings(prompt)\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "5fb525ce-faa6-4af7-853c-ac1129f6ee27", "metadata": {}, "outputs": [], "source": [ "def gen_next(*args):\n", " args=[a for a in args]\n", " is_final = False\n", " if len(args) == 13:\n", " is_final = True\n", " topic = args.pop(0)\n", " rhymes = None\n", " lines = []\n", " if len(args):\n", " rhymes_string = args.pop(0).strip()\n", " rhyme_choice = args.pop(0)\n", " rhymes = rhymes_string.split(\"\\n\")[int(rhyme_choice)].split(\":\")[-1].strip()\n", " while len(args):\n", " lines_string = args.pop(0).strip()\n", " line_choice = args.pop(0)\n", " lines += [lines_string.split(\"\\n\")[int(line_choice)].split(\":\")[-1].strip(),]\n", " if is_final:\n", " return \"\\n\".join(lines)\n", " return \"\\n\".join([str(n)+\": \" + r for n,r in enumerate(get_next_options(topic,rhymes=rhymes,lines=lines))])" ] }, { "cell_type": "code", "execution_count": null, "id": "407602e4-6371-400c-b816-74730efae664", "metadata": {}, "outputs": [], "source": [ "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"To generate a limerick, start by entering a topic. Click the first button to generate a list of words following the AABBA limerick format. \n", "The first numerical input is for choosing which rhyme set to use. Subsequent inputs are for choosing the next line.\n", "\n", "
    "Some examples:\n",
    "\n",
    "|   |   |   |\n",
    "|---|---|---|\n",
    "| Scrooge Jeff Bezos, a man with a greed | When you're asking the Greeks, \"So what's Greek?\" | On your index cards, write down your need, |\n",