Update tokenization_minicpm.py
tokenization_minicpm.py (CHANGED, +10 -10)
@@ -4,7 +4,6 @@ import keyword
 import traceback
 import uuid
 from collections import deque
-from copy import deepcopy
 from logging import getLogger
 from typing import Any, Dict, List, Optional, Union
 
@@ -17,6 +16,7 @@ from jsonschema import Draft202012Validator, exceptions, validate
 from transformers import LlamaTokenizerFast
 from transformers.tokenization_utils_base import BatchEncoding
 from transformers.utils import TensorType
+from copy import deepcopy
 
 
 logger = getLogger(__name__)
@@ -148,7 +148,7 @@ class MiniCPMTokenizer(LlamaTokenizerFast):
             tool_calls.append(this_one)
 
         return {
-            "content": content
+            "content": content,
             "tool_calls": [
                 {"type": "function", "function": tool_call, "id": "call_" + uuid.uuid4().hex}
                 for tool_call in tool_calls
@@ -158,13 +158,13 @@ class MiniCPMTokenizer(LlamaTokenizerFast):
         except:
             logger.error(traceback.format_exc())
             return {
-                "content": content
+                "content": content,
                 "role": "assistant",
                 "thought": thought_string,
             }
         else:
            return {
-                "content": sequence
+                "content": sequence,
                 "role": "assistant",
                 "thought": thought_string,
            }
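Note on the two hunks above: the commas added after "content" fix dict literals that previously would not parse, since Python requires a comma between adjacent entries. A minimal sketch with made-up values (content and tool_calls below are placeholders, not the tokenizer's real data):

    # Sketch only: placeholder values standing in for the tokenizer's real data.
    content = "calling the tool now"
    tool_calls = [{"name": "func1", "arguments": {}}]

    # Before the fix the returned dict literal did not parse:
    #     {"content": content "tool_calls": tool_calls}    # SyntaxError
    # With the comma it is a valid two-entry dict:
    message = {"content": content, "tool_calls": tool_calls}
    assert message["content"] == content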
@@ -259,10 +259,11 @@ def message_format(msg, system_suffix="", user_prefix=""):
         content = thought_prefix + content
         msg["content"] = content
     elif msg["role"] == "user":
-
+        if user_prefix != "":
+            msg["content"] = user_prefix + "\n" + msg["content"]
     elif msg["role"] == "system":
         msg["content"] = msg["content"] + "\n" + system_suffix
-        msg["content"] = msg["content"]
+        msg["content"] = msg["content"]
     return msg
 
 
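The new user branch prepends user_prefix only when it is non-empty, so the default user_prefix="" no longer adds a stray leading newline to user messages. A standalone sketch of just that guard (apply_user_prefix is an illustrative helper, not part of the file):

    # Illustrative helper mirroring the guarded user branch; not a function from the file.
    def apply_user_prefix(content: str, user_prefix: str = "") -> str:
        if user_prefix != "":
            content = user_prefix + "\n" + content
        return content

    assert apply_user_prefix("hi") == "hi"                      # empty prefix: content unchanged
    assert apply_user_prefix("hi", "PREFIX") == "PREFIX\nhi"    # prefix goes on its own line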
@@ -361,12 +362,12 @@ func2(params)
 <|tool_call_end|>
 {{answer the user's question directly or ask the user for more information}}
 """
-        tools_string = tools_template.format(tools=tools_string)
+        tools_string = tools_template.format(tools=tools_string)
     else:
         tools_string = ""
 
     if add_to_system:
-        if len(messages) > 0 and messages[0]["role"] != "system" and tools_string
+        if len(messages) > 0 and messages[0]["role"] != "system" and len(tools_string.strip()) > 0:
             messages.insert(0, {"role": "system", "content": ""})
         return [message_format(msg, system_suffix=tools_string, user_prefix="") for msg in messages]
     else:
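The rewritten condition replaces the bare tools_string check with an explicit length test on the stripped string, so an empty system stub is only inserted when there is a real tools block to append. A small sketch of the behavior (needs_system_stub and the sample messages are illustrative only):

    # Illustrative re-statement of the new condition; not a function from the file.
    def needs_system_stub(messages, tools_string):
        return len(messages) > 0 and messages[0]["role"] != "system" and len(tools_string.strip()) > 0

    msgs = [{"role": "user", "content": "hi"}]
    assert needs_system_stub(msgs, "# Tools ...") is True   # real tools text: insert the stub
    assert needs_system_stub(msgs, "   ") is False          # whitespace-only: nothing to append
    assert needs_system_stub([], "# Tools ...") is False    # no messages: nothing to do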
@@ -429,5 +430,4 @@ def resolve_ast_by_type(value):
         output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
     else:
         raise Exception(f"Unsupported AST type: {type(value)}")
-    return output
-
+    return output
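For context, the branch kept in this hunk handles subscript nodes (it reads value.value and value.slice) and rebuilds the source text of an indexing expression from its AST parts; on Python 3.9+ the same round trip can be checked directly with ast.unparse:

    # Requires Python 3.9+ for ast.unparse; "data['key']" is an arbitrary example expression.
    import ast

    node = ast.parse("data['key']", mode="eval").body          # an ast.Subscript node
    rebuilt = ast.unparse(node.value) + "[" + ast.unparse(node.slice) + "]"
    assert rebuilt == "data['key']"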