In [None]:
import transformers, re, os
import gradio as gr
import torch
import datetime, json

# Hugging Face Hub id of the finetuned checkpoint loaded in the next cell.
model_name = "pcalhoun/gpt-j-6b-limericks-finetuned"
# logall() writes one JSON file per query into this directory; create it up front.
os.makedirs("limerick_logs", exist_ok=True)

In [None]:
# Load the finetuned GPT-J weights in half precision; device_map="auto" lets
# transformers/accelerate decide where each layer is placed.
gpt = transformers.GPTJForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Matching tokenizer for the same checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# Release cached-but-unused CUDA memory left over from loading.
torch.cuda.empty_cache()

In [None]:
# Short task codes spliced into prompts between "=" signs, e.g. "<topic =T2R=".
# topic_to_limerick is not referenced by the cells shown here.
tag = dict(
    topic_to_limerick="T2L",
    topic_to_rhymes="T2R",
    rhymes_to_limerick="R2L",
)

class EndCriteria(transformers.StoppingCriteria):
    """Stop generation once EVERY sequence in the batch contains a stop string.

    Only token positions from `start_length` onward are decoded and searched;
    earlier positions are ignored.

    Args:
        start_length: first token index (per row of input_ids) to search.
        eof_strings: sequence of stop strings; any one of them ends a row.
        tokenizer: object whose batch_decode() turns rows of ids into text.
    """

    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def _contains_stop_string(self, text):
        # A row is finished when any configured stop string occurs in its text.
        return any(stop in text for stop in self.eof_strings)

    def __call__(self, input_ids, scores, **kwargs):
        tail_ids = input_ids[:, self.start_length :]
        tail_texts = self.tokenizer.batch_decode(tail_ids)
        return all(self._contains_stop_string(text) for text in tail_texts)

def logall(*args,**kwargs):
    """Write all arguments as JSON to a new timestamped file in limerick_logs/.

    The file contains {"args": [...], "kwargs": {...}} indented by 4 spaces.
    The name uses the local time down to microseconds, so two calls within
    the same second (e.g. concurrent Gradio requests) no longer overwrite
    each other's log.
    """
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    log_path = os.path.join("limerick_logs", stamp + ".txt")
    # `with` closes (and therefore flushes) the file deterministically instead
    # of leaving an open handle for the garbage collector.
    with open(log_path, "w") as log_file:
        json.dump({"args": args, "kwargs": kwargs}, log_file, indent=4)

def qry(string,
        result_length = 400,
        top_k=50,
        top_p=0.98,
        temperature=0.7,
        num_return_sequences=14,
        do_sample=True,
        pad_token_id=50256,
        stop_tokens = (">",),
        **extra_kwargs):
    """Sample completions of the prompt `string` from the finetuned model.

    Args:
        string: prompt text.
        result_length: forwarded to generate() as max_length, i.e. the token
            budget for prompt + completion together.
        top_k, top_p, temperature, do_sample, num_return_sequences,
        pad_token_id: forwarded unchanged to gpt.generate().
        stop_tokens: stop strings; generation ends once every sampled sequence
            has produced at least one of them after the prompt.
        **extra_kwargs: any further gpt.generate() keyword arguments.

    Returns:
        List of decoded sequences (prompt text included, special tokens
        removed). The prompt and the outputs are also logged via logall().
    """
    with torch.no_grad():
        inputs = tokenizer(string, return_tensors="pt")
        # input_ids has shape (1, prompt_len) for a single prompt. len() would
        # return the batch size (1), making EndCriteria scan the prompt itself
        # for stop strings; shape[1] is the actual prompt token count.
        input_size = inputs["input_ids"].shape[1]
        inputs = inputs.to('cuda')
        stopping = transformers.StoppingCriteriaList(
            [EndCriteria(input_size, stop_tokens, tokenizer)]
        )
        generate_kwargs = dict(extra_kwargs)
        # pad_token_id equals the EOS id here, so generate() cannot infer the
        # mask itself; pass the tokenizer's mask unless the caller supplied one.
        generate_kwargs.setdefault("attention_mask", inputs["attention_mask"])
        sequences = gpt.generate(
            inputs["input_ids"],
            stopping_criteria=stopping,
            max_length=result_length,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            temperature=temperature,
            pad_token_id=pad_token_id,
            num_return_sequences=num_return_sequences,
            **generate_kwargs,
        )
        ret_texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    logall(string, *ret_texts)
    return ret_texts

def get_rhymes_strings(topic):
    """Ask the model for candidate rhyme sets for `topic`.

    Prompts with "<topic =T2R=" (the topic_to_rhymes tag). From each sampled
    completion it keeps the text after the last "=" and before the first ">".
    """
    prompt = "<" + topic + " =" + tag["topic_to_rhymes"] + "="
    rhyme_sets = []
    for completion in qry(prompt):
        after_tag = completion.rpartition("=")[2]
        rhyme_sets.append(after_tag.partition(">")[0])
    return rhyme_sets

def get_next_limeric_strings(in_string):
    """Sample candidate next lines for the partial-limerick prompt `in_string`.

    Each completion from qry() (top_p=0.95) is processed in four steps:
    - cut at the first ">";
    - drop the first len(in_string) characters (the echoed prompt);
    - cut at the first "/" line separator;
    - strip surrounding whitespace.
    """
    prompt_length = len(in_string)
    candidates = []
    for completion in qry(in_string, top_p=0.95):
        up_to_end = completion.partition(">")[0]
        new_text = up_to_end[prompt_length:]
        candidates.append(new_text.partition("/")[0].strip())
    return candidates
        
def get_next_options(topic,rhymes=None,lines=None):
    """Return the next batch of choices for the limerick being built.

    With no `rhymes` (None or empty), returns candidate rhyme sets for `topic`.
    Otherwise builds a "<topic: rhymes =R2L= line / line / " prompt from the
    chosen rhyme set and the lines picked so far. It then returns candidates
    for the next line.

    Args:
        topic: user-supplied topic text.
        rhymes: the chosen rhyme-set string, or None/"" to request rhyme sets.
        lines: lines chosen so far (at most four); defaults to none.

    Raises:
        ValueError: if five or more lines are already chosen.
    """
    if not rhymes:
        return get_rhymes_strings(topic)
    # Copy so a caller-supplied (or formerly shared default) list is never aliased.
    lines = list(lines) if lines is not None else []
    if len(lines) >= 5:
        # Explicit check instead of `assert`, which is stripped under python -O.
        raise ValueError("a limerick has five lines; %d are already chosen" % len(lines))
    prompt = "<" + topic + ": " + rhymes + " =" + tag["rhymes_to_limerick"] + "= "
    if lines:
        prompt += " / ".join(lines) + " / "
    print(prompt)
    return get_next_limeric_strings(prompt)
    

In [None]:
def gen_next(*args):
    args=[a for a in args]
    is_final = False
    if len(args) == 13:
        is_final = True
    topic = args.pop(0)
    rhymes = None
    lines = []
    if len(args):
        rhymes_string = args.pop(0).strip()
        rhyme_choice = args.pop(0)
        rhymes = rhymes_string.split("\n")[int(rhyme_choice)].split(":")[-1].strip()
    while len(args):
        lines_string = args.pop(0).strip()
        line_choice = args.pop(0)
        lines += [lines_string.split("\n")[int(line_choice)].split(":")[-1].strip(),]
    if is_final:
        return "\n".join(lines)
    return "\n".join([str(n)+": " + r for n,r in enumerate(get_next_options(topic,rhymes=rhymes,lines=lines))])

In [None]:
def _add_generation_step(button_text, output_label, selector_label, inputs):
    """Add one "generate options, then pick one" row to the enclosing gr.Blocks.

    Creates, in this order:
    - a button;
    - a 3-line textbox that gen_next fills when the button is clicked, with
      `inputs` as the callback inputs;
    - a number box for the user's selection.

    Must be called inside a `with gr.Blocks()` context.

    Returns:
        (button, options_textbox, selection_number)
    """
    step_button = gr.Button(button_text)
    step_output = gr.Textbox(lines=3, label=output_label)
    step_button.click(gen_next, inputs=inputs, outputs=step_output)
    step_selection = gr.Number(label=selector_label)
    return step_button, step_output, step_selection


with gr.Blocks() as demo:
    gr.Markdown("""To generate a limerick, start by entering a topic. Click the first button to generate a list of words following the AABBA limerick format.  
The first numerical input is for choosing which rhyme set to use. Subsequent inputs are for choosing the next line.

<table><tr><th>Some examples:</th></tr>
<tr><td><p style="font-size:12px;">Scrooge Jeff Bezos, a man with a greed  
<br>For control, took the Amazon's seed  
<br>And with profit the gold  
<br>That it brought, now he's old,  
<br>Has billions, and no real need.</p></td>
<td><p style="font-size:12px;">When you're asking the Greeks, "So what's Greek?"  
<br>You're asking for mythological speak.  
<br>You may hear some tale  
<br>Of Achilles and fail  
<br>If you're not in the know what to seek.</p></td>
<td><p style="font-size:12px;">On your index cards, write down your need,  
<br>And arrange them in order of speed.  
<br>When you're done, you'll recall  
<br>Which one's quicker than all,  
<br>And you'll know which is best, if indeed.</p></td></tr></table>""")
    topic_input = gr.Textbox(label="Limerick topic (two to three words)")
    gen_rhymes_button, rhyme_output, rhyme_input = _add_generation_step(
        "Generate end of line rhyming scheme options",
        "Rhyme options",
        "Selection number for rhyme line",
        topic_input,
    )
    # Each later step receives the topic plus every earlier (options, selection)
    # pair in UI order; gen_next decodes the choices from them positionally.
    # New lists are built with `+` so earlier click wirings are never mutated.
    step_inputs = [topic_input, rhyme_output, rhyme_input]
    gen_line_one_button, line_one_output, line_one_input = _add_generation_step(
        "Generate line one options",
        "Line one options",
        "Selection number of line one",
        step_inputs,
    )
    step_inputs = step_inputs + [line_one_output, line_one_input]
    gen_line_two_button, line_two_output, line_two_input = _add_generation_step(
        "Generate line two options",
        "Line two options",
        "Selection number of line two",
        step_inputs,
    )
    step_inputs = step_inputs + [line_two_output, line_two_input]
    gen_line_three_button, line_three_output, line_three_input = _add_generation_step(
        "Generate line three options",
        "Line three options",
        "Selection number of line three",
        step_inputs,
    )
    step_inputs = step_inputs + [line_three_output, line_three_input]
    gen_line_four_button, line_four_output, line_four_input = _add_generation_step(
        "Generate line four options",
        "Line four options",
        "Selection number of line four",
        step_inputs,
    )
    step_inputs = step_inputs + [line_four_output, line_four_input]
    gen_line_five_button, line_five_output, line_five_input = _add_generation_step(
        "Generate line five options",
        "Line five options",
        "Selection number of line five",
        step_inputs,
    )
    step_inputs = step_inputs + [line_five_output, line_five_input]
    final_button = gr.Button("Display full limerick")
    final_output = gr.Textbox(lines=5)
    # 13 inputs in total: gen_next treats exactly this count as "assemble the poem".
    final_button.click(gen_next, inputs=step_inputs, outputs=final_output)
    gr.Markdown("""(This is a remote-accessible demo of Robert A. Gonsalves' DeepLimericks. The model in use is pretrained GPT-J-6B that was finetuned on two RTX 3090s for the purpose of warming my pantry for a week.)""".strip())
demo.launch()