Add full tools support to the chat template
#45 — opened by Rocketknight1 (HF staff)
This PR is still in progress! It should work, but it needs testing to verify that it exactly matches the outputs of mistral-common.
This has now been tested and confirmed to match the output from mistral-common!
Test script to confirm:
from transformers import AutoTokenizer
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, ToolMessage
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Compare the Hugging Face chat template (PR #45 revision) against the
# mistral-common v3 reference tokenizer on one tool-calling conversation:
# user question -> assistant tool call -> tool result. The script prints
# whether the rendered prompt text and the token IDs are identical.

# Identifiers shared by both fixtures, so the HF and mistral-common
# conversations describe exactly the same exchange.
TOOL_NAME = "get_current_weather"
TOOL_DESCRIPTION = "Get the current weather"
TOOL_CALL_ID = "abcdef123"
USER_PROMPT = "What's the weather like today in Paris"
TOOL_ARGUMENTS = {"location": "Paris, France"}
TOOL_RESULT = "22.0"

# JSON schema of the tool's parameters. It is defined once and passed to both
# tokenizers so the two sides cannot silently drift apart.
weather_parameters = {
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use. Infer this from the users location.",
        },
    },
    "required": ["location", "format"],
}

# --- Hugging Face side -----------------------------------------------------
hf_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x22B-Instruct-v0.1", revision="pr/45"
)

hf_tool = {
    "type": "function",
    "function": {
        "name": TOOL_NAME,
        "description": TOOL_DESCRIPTION,
        "parameters": weather_parameters,
    },
}

tool_call = {"name": TOOL_NAME, "arguments": TOOL_ARGUMENTS}
test_chat = [
    {"role": "user", "content": USER_PROMPT},
    {
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": tool_call, "id": TOOL_CALL_ID}],
    },
    # The tool result must name the tool that was actually called above.
    {"role": "tool", "name": TOOL_NAME, "tool_call_id": TOOL_CALL_ID, "content": TOOL_RESULT},
]

hf_text = hf_tokenizer.apply_chat_template(test_chat, tokenize=False, tools=[hf_tool])
hf_tokens = hf_tokenizer.apply_chat_template(test_chat, tokenize=True, tools=[hf_tool])

# --- mistral-common reference side -----------------------------------------
mistral_tokenizer = MistralTokenizer.v3()

mistral_tool = Tool(
    function=Function(
        name=TOOL_NAME,
        description=TOOL_DESCRIPTION,
        parameters=weather_parameters,
    )
)

mistral_query = ChatCompletionRequest(
    tools=[mistral_tool],
    messages=[
        UserMessage(content=USER_PROMPT),
        AssistantMessage(
            tool_calls=[
                ToolCall(
                    type="function",
                    function=FunctionCall(name=TOOL_NAME, arguments=TOOL_ARGUMENTS),
                    id=TOOL_CALL_ID,
                )
            ]
        ),
        ToolMessage(content=TOOL_RESULT, tool_call_id=TOOL_CALL_ID),
    ],
    model="test",
)

# Encode once and read both the text and the token IDs from the same result,
# rather than running the encoder twice on the same request.
mistral_encoded = mistral_tokenizer.encode_chat_completion(mistral_query)
# The reference text marks spaces with U+2581 ("▁"); map them back to plain
# spaces so the string is comparable with the HF template output.
mistral_text = mistral_encoded.text.replace("▁", " ")
mistral_tokens = mistral_encoded.tokens

print(hf_text == mistral_text)
print(hf_tokens == mistral_tokens)
pandora-s changed the pull request status to merged.