import os

import torch
import gradio as gr
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


def rst(num):
    """Return the text form of a Llama 3.1 reserved special token."""
    return "<|reserved_special_token_" + str(num) + "|>"


# Paired reserved tokens mark the segments of the fine-tuning prompt format.
PRE_PLAIN_CONTEXT_TOKEN = rst(11) + rst(21)
PRE_SWIFT_CONTEXT_TOKEN = rst(12) + rst(22)  # part of the format; unused here
PRE_PLAIN_TOKEN = rst(13) + rst(23)
PRE_SWIFT_TOKEN = rst(14) + rst(24)

# Any reserved special token appearing in the generated text acts as a stop signal.
eos_extras = [rst(n) for n in range(0, 247)]

base_model_repo_and_name = "meta-llama/Llama-3.1-8B"
lora_model_loc = "pcalhoun/Llama-3.1-8B-JonathanSwift-lora"

tokenizer = AutoTokenizer.from_pretrained(
    base_model_repo_and_name, token=os.environ["HF_TOKEN"]
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # use eos_token as pad_token if pad_token is not set

model = AutoModelForCausalLM.from_pretrained(
    base_model_repo_and_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    load_in_8bit=True,  # requires bitsandbytes
    token=os.environ["HF_TOKEN"],
)
model = PeftModel.from_pretrained(model, lora_model_loc)
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()


class EndOfQuestionCriteria(StoppingCriteria):
    """Stop generation once any stop string appears in the newly generated text."""

    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the tokens generated after the prompt.
        decoded_generations = self.tokenizer.batch_decode(
            input_ids[:, self.start_length:], skip_special_tokens=False
        )
        return all(
            any(stop_string in decoded for stop_string in self.eof_strings)
            for decoded in decoded_generations
        )


def convert_to_swiftian(context_text, plain_text):
    if plain_text.strip() == "":
        return "Please enter text to convert."

    # Construct the prompt with special tokens, without adding spaces after them.
    prompt = PRE_PLAIN_CONTEXT_TOKEN
    if context_text.strip() != "":
        prompt += " " + context_text.strip()
    prompt += "\n" + PRE_PLAIN_TOKEN + " " + plain_text.strip() + "\n" + PRE_SWIFT_TOKEN

    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    input_token_length = input_ids.shape[1]

    stopping_crit = StoppingCriteriaList(
        [EndOfQuestionCriteria(input_token_length, eos_extras, tokenizer)]
    )

    output_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        stopping_criteria=stopping_crit,
        do_sample=True,
        temperature=0.8,
        max_length=input_token_length + 256,
        min_length=input_token_length + 30,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)

    # Extract the generated Swiftian English text.
    generated_text = output_text.split(PRE_SWIFT_TOKEN)[-1]

    # Remove the trailing reserved-token stop string, if present.
    for eos_str in eos_extras:
        if generated_text.endswith(eos_str):
            generated_text = generated_text[: -len(eos_str)].rstrip()
            break

    # Strip a stray end-of-text token the model may have emitted.
    generated_text = generated_text.replace(tokenizer.eos_token, "").strip()
    return generated_text


demo = gr.Interface(
    fn=convert_to_swiftian,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter context (optional)", label="Context (Optional)"),
        gr.Textbox(lines=5, placeholder="Enter text to convert", label="Text to Convert"),
    ],
    outputs="text",
    title="Modern English to Swiftian English Converter",
    description=(
        "Enter modern English text and optionally context, and click Convert "
        "to get the Swiftian English version."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
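# A minimal sketch of calling the converter directly, bypassing the Gradio UI,
# e.g. for smoke-testing after the model loads. The sample strings below are
# hypothetical inputs, not from the original project:
#
#     result = convert_to_swiftian(
#         "A letter to a friend about life in the city.",
#         "The traffic downtown has become completely unbearable.",
#     )
#     print(result)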