Spaces:
Runtime error
| from transformers import AutoTokenizer | |
| from fastchat.conversation import get_conv_template | |
| import os | |
| from utils import sanitize_jinja2 | |
| import difflib | |
def _fastchat_prompt(chat):
    """Render *chat* with FastChat's ``mistral-7b-openorca`` conversation template.

    Messages are mapped by role:

    - ``system`` sets the system message.
    - ``user`` is appended under ``conv.roles[0]``.
    - ``assistant`` is appended under ``conv.roles[1]``.

    A final assistant turn with a ``None`` message is appended. This is meant to
    correspond to ``add_generation_prompt=True`` on the Transformers side.

    Args:
        chat: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The prompt string produced by ``conv.get_prompt()``.

    Raises:
        ValueError: if a message has a role other than system/user/assistant.
    """
    conv = get_conv_template("mistral-7b-openorca")
    speaker_roles = {"user": conv.roles[0], "assistant": conv.roles[1]}
    for message in chat:
        role = message["role"]
        if role == "system":
            conv.set_system_message(message["content"])
        elif role in speaker_roles:
            conv.append_message(speaker_roles[role], message["content"])
        else:
            raise ValueError(f"unsupported chat role: {role!r}")
    # FastChat convention: a None message marks the turn the model should
    # complete, so the prompt ends on an open assistant turn.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def _print_matching_blocks(a, b):
    """Print every block of *a* that difflib matches against *b*.

    The last block reported by ``get_matching_blocks`` is always a zero-length
    sentinel, so an empty final "Matching Sequence" line is expected.
    """
    matcher = difflib.SequenceMatcher(a=a, b=b)
    print("Matching Sequences:")
    for match in matcher.get_matching_blocks():
        print("Match : {}".format(match))
        print("Matching Sequence : {}".format(a[match.a:match.a + match.size]))


def test_llama2_template():
    """Check the Mistral-7B-OpenOrca Jinja2 chat template against FastChat.

    Despite its name, this test exercises the ``mistral-7b-openorca`` template,
    not Llama 2. It proceeds in four steps:

    1. Read ``../templates/mistral-7b-openorca.jinja2``, resolved relative to
       this file.
    2. Sanitize the template with ``utils.sanitize_jinja2`` and install it on
       the ``Open-Orca/Mistral-7B-OpenOrca`` tokenizer.
    3. Render a fixed four-message chat with ``add_generation_prompt=True``.
    4. Assert that the result equals the prompt FastChat builds for the same
       conversation.

    The intermediate prompts and the difflib matching blocks are printed to
    help debug mismatches.

    Requires ``transformers``, ``fastchat``, the local ``utils`` module, and a
    tokenizer that ``from_pretrained`` can load (Hub access or a local cache).

    Raises:
        AssertionError: if the Jinja2-rendered and FastChat prompts differ.
    """
    # Anchor the template path to this file instead of the current working
    # directory. When launched from this file's directory nothing changes, but
    # the test no longer breaks when started from elsewhere.
    template_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        "templates",
        "mistral-7b-openorca.jinja2",
    )
    with open(template_path, "r", encoding="utf-8") as f:
        jinja_lines = f.readlines()
    print("jinja_lines: ", jinja_lines)
    # Sanitize once so the printed template is exactly the one installed below.
    chat_template = sanitize_jinja2(jinja_lines)
    print("sanitized: ", chat_template)

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]

    # NOTE(review): trust_remote_code=True lets the Hub repository execute
    # arbitrary Python locally; confirm this model actually requires it.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path="Open-Orca/Mistral-7B-OpenOrca",
        trust_remote_code=True,
    )

    # Reference layout the template is expected to produce:
    # <|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant
    print("default template")
    print(tokenizer.apply_chat_template(chat, tokenize=False))

    tokenizer.chat_template = chat_template
    prompt_without_generation = tokenizer.apply_chat_template(chat, tokenize=False)
    print()
    print("add_generation_prompt False:")
    print(prompt_without_generation)

    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print()
    print("add_generation_prompt True:")
    print(prompt)

    print("Fastchat template: ")
    fastchat_prompt = _fastchat_prompt(chat)
    print(fastchat_prompt)

    _print_matching_blocks(prompt, fastchat_prompt)
    assert prompt == fastchat_prompt, (
        "Jinja2 template and FastChat prompts differ:\n"
        f"jinja2:   {prompt!r}\n"
        f"fastchat: {fastchat_prompt!r}"
    )
if __name__ == "__main__":
    # Allow running the template comparison directly as a script, outside pytest.
    test_llama2_template()