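"""Gradio Space: converts modern English into "Swiftian" English using a LoRA
adapter (pcalhoun/Llama-3.1-8B-JonathanSwift-lora) on top of meta-llama/Llama-3.1-8B."""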
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel
import gradio as gr
def rst(num):
    return f"<|reserved_special_token_{num}|>"
PRE_PLAIN_CONTEXT_TOKEN = rst(11) + rst(21)
PRE_SWIFT_CONTEXT_TOKEN = rst(12) + rst(22)
PRE_PLAIN_TOKEN = rst(13) + rst(23)
PRE_SWIFT_TOKEN = rst(14) + rst(24)
eos_extras = [rst(n) for n in range(0, 247)]
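# Any reserved special token appearing in the generated text is treated as an
# end-of-output marker; the full set serves as stop strings for the criterion below.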
base_model_repo_and_name = "meta-llama/Llama-3.1-8B"
lora_model_loc = "pcalhoun/Llama-3.1-8B-JonathanSwift-lora"
tokenizer = AutoTokenizer.from_pretrained(base_model_repo_and_name, token=os.environ["HF_TOKEN"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # use eos_token as pad_token if none is set
model = AutoModelForCausalLM.from_pretrained(
    base_model_repo_and_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # passing load_in_8bit as a bare kwarg is deprecated; use a BitsAndBytesConfig
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    token=os.environ["HF_TOKEN"],
)
model = PeftModel.from_pretrained(model, lora_model_loc)
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()
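# Inference runs through the PEFT adapter stacked on the frozen 8-bit base model;
# only the LoRA weights differ from stock Llama-3.1-8B.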
def convert_to_swiftian(context_text, plain_text):
    if plain_text.strip() == "":
        return "Please enter text to convert."
    # Build the prompt: context marker (plus optional context), then the plain
    # text, and finally the Swiftian marker the model should continue from
    prompt = PRE_PLAIN_CONTEXT_TOKEN
    if context_text.strip() != "":
        prompt += " " + context_text.strip()
    prompt += "\n" + PRE_PLAIN_TOKEN + " " + plain_text.strip() + "\n" + PRE_SWIFT_TOKEN
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    input_token_length = input_ids.shape[1]
    class EndOfQuestionCriteria(StoppingCriteria):
        def __init__(self, start_length, eof_strings, tokenizer):
            self.start_length = start_length
            self.eof_strings = eof_strings
            self.tokenizer = tokenizer

        def __call__(self, input_ids, scores, **kwargs):
            decoded_generations = self.tokenizer.batch_decode(
                input_ids[:, self.start_length :], skip_special_tokens=False
            )
            return all(
                any(stop_string in decoded for stop_string in self.eof_strings)
                for decoded in decoded_generations
            )
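    # Stop as soon as every sequence's newly generated text (everything past the
    # prompt) contains one of the reserved-token stop strings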
    stopping_crit = StoppingCriteriaList(
        [
            EndOfQuestionCriteria(
                input_token_length,
                eos_extras,
                tokenizer,
            ),
        ]
    )
    output_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        stopping_criteria=stopping_crit,
        do_sample=True,
        temperature=0.8,
        max_new_tokens=256,
        min_new_tokens=30,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
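    # Decode with skip_special_tokens=False so the PRE_SWIFT_TOKEN marker survives
    # and the generated portion can be split off after it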
    output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
    # Keep only the text generated after the Swiftian marker
    generated_text = output_text.split(PRE_SWIFT_TOKEN)[-1]
    # Strip a trailing reserved-token stop string, if one was generated
    for eos_str in eos_extras:
        if generated_text.endswith(eos_str):
            generated_text = generated_text[: -len(eos_str)].rstrip()
            break
    # Llama-3.1's EOS token is "<|end_of_text|>", not "</s>"; strip whichever is set
    generated_text = generated_text.replace(tokenizer.eos_token, "").strip()
    return generated_text
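# Two-field UI: an optional framing context plus the text to rewrite, with a
# single plain-text output.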
demo = gr.Interface(
fn=convert_to_swiftian,
inputs=[
gr.Textbox(lines=5, placeholder="Enter context (optional)", label="Context (Optional)"),
gr.Textbox(lines=5, placeholder="Enter text to convert", label="Text to Convert")
],
outputs="text",
title="Modern English to Swiftian English Converter",
description="Enter modern English text and optionally context, and click Convert to get the Swiftian English version.",
allow_flagging="never"
)
demo.launch()