File size: 21,342 Bytes
9eb70c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
import ast
import json
import re
from collections.abc import Sequence
from typing import Union

import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall, DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("llama_nemotron_xml")
class LlamaNemotronXMLToolParser(ToolParser):
    """Parse XML-style tool calls from Llama-Nemotron output.

    A tool call is a ``<tool_call>...</tool_call>`` block containing a
    ``<tool>NAME</tool>`` element plus one ``<key>value</key>`` element per
    argument. Only non-streaming extraction is implemented.
    """

    # Case-insensitive keyword spellings accepted in the no-schema path.
    # ast.literal_eval only understands the capitalised Python keywords, so
    # these are mapped directly.
    _KEYWORD_LITERALS = {"true": True, "false": False, "none": None}

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming state (unused: extract_tool_calls_streaming raises).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Captures the body of each complete <tool_call>...</tool_call> block.
        self.tool_call_block_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # Captures the tool name from <tool>...</tool> inside a block body.
        self.name_regex = re.compile(r"<tool>(.*?)</tool>", re.DOTALL)
        # Captures <key>value</key> pairs. This also matches the <tool>
        # element, which extract_tool_calls skips explicitly.
        self.param_regex = re.compile(r"<([^/>\s]+)>(.*?)</\1>", re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete ``<tool_call>`` block from *model_output*.

        Args:
            model_output: Full (non-streamed) model completion text.
            request: The originating request; when ``request.tools`` declares
                a JSON-schema ``type`` for a parameter, that type drives the
                conversion of the parameter's text value.

        Returns:
            ``ExtractedToolCallInformation`` whose ``content`` is the stripped
            text before the first ``<tool_call>`` (``None`` if empty). When no
            complete block exists, or parsing raises unexpectedly, the whole
            output is returned as content with no tool calls.
        """
        tool_call_start_index = model_output.find(self.tool_call_start_token)

        if tool_call_start_index == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        content = model_output[:tool_call_start_index].strip()

        try:
            tool_call_blocks = self.tool_call_block_regex.findall(
                model_output[tool_call_start_index:])
            if not tool_call_blocks:
                # Opening tag without a complete block (e.g. truncated
                # generation): keep the text instead of silently dropping it.
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            parsed_tool_calls = []
            for block in tool_call_blocks:
                name_match = self.name_regex.search(block)
                if not name_match:
                    logger.warning("Could not find tool name in XML block: %s", block)
                    continue
                tool_name = name_match.group(1).strip()
                # Resolve the schema once per call rather than once per parameter.
                param_types = self._get_param_types(request, tool_name)

                parsed_arguments = {}
                for match in self.param_regex.finditer(block):
                    param_name = match.group(1).strip()
                    if param_name == "tool":
                        continue  # the name element, not an argument
                    param_value_str = match.group(2).strip()
                    target_type = param_types.get(param_name)
                    if target_type:
                        parsed_arguments[param_name] = self._convert_with_schema_type(
                            param_name, param_value_str, target_type)
                    else:
                        parsed_arguments[param_name] = self._convert_untyped(param_value_str)

                parsed_tool_calls.append(ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=tool_name,
                        arguments=json.dumps(parsed_arguments, ensure_ascii=False),
                    ),
                ))

            return ExtractedToolCallInformation(
                tools_called=len(parsed_tool_calls) > 0,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting XML tool call from response. Response: %s",
                model_output)
            # Fall back to the original model output if parsing fails unexpectedly.
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    @staticmethod
    def _get_param_types(request: ChatCompletionRequest, tool_name: str) -> dict:
        """Map parameter name -> schema ``type`` for the first tool named *tool_name*.

        Returns ``{}`` when the request has no tools, no tool matches, or the
        matching tool has no dict-valued ``parameters["properties"]``.
        Property specs that are not dicts are ignored.
        """
        for tool_def in request.tools or []:
            if tool_def.function.name != tool_name:
                continue
            parameters = tool_def.function.parameters
            properties = parameters.get("properties") if isinstance(parameters, dict) else None
            if not isinstance(properties, dict):
                return {}
            return {
                name: spec.get("type")
                for name, spec in properties.items()
                if isinstance(spec, dict)
            }
        return {}

    @staticmethod
    def _convert_with_schema_type(param_name: str, value: str, target_type):
        """Convert *value* to the JSON-schema *target_type*.

        "integer" -> int, "number" -> float, "boolean" -> ``value.lower() ==
        "true"``, "object"/"array" -> json.loads with an ast.literal_eval
        fallback. "string" and unrecognised types return *value* unchanged.
        On conversion failure a warning is logged and *value* is returned.
        """
        try:
            if target_type == "integer":
                return int(value)
            if target_type == "number":
                return float(value)
            if target_type == "boolean":
                return value.lower() == 'true'
            if target_type in ["object", "array"]:
                try:
                    return json.loads(value)
                except json.JSONDecodeError:
                    # Fallback for non-strict JSON like a Python dict/list repr.
                    return ast.literal_eval(value)
            return value
        except (ValueError, TypeError, SyntaxError, RecursionError) as e:
            logger.warning(
                "Could not convert param '%s' with value '%s' to type '%s'. "
                "Error: %s. Using string value.",
                param_name, value, target_type, e)
            return value

    @classmethod
    def _convert_untyped(cls, value: str):
        """Best-effort conversion of a value that has no schema type.

        Case-insensitive true/false/none become ``True``/``False``/``None``.
        Values that look like Python literals (quoted strings, lists, dicts,
        signed decimals) go through ``ast.literal_eval``. Everything else,
        including literals that fail to evaluate, stays a string.
        """
        lowered = value.lower()
        if lowered in cls._KEYWORD_LITERALS:
            return cls._KEYWORD_LITERALS[lowered]
        if cls._looks_like_literal(value):
            try:
                return ast.literal_eval(value)
            except (ValueError, TypeError, SyntaxError, RecursionError):
                pass
        return value

    @staticmethod
    def _looks_like_literal(value: str) -> bool:
        """Return True if *value* is quoted/bracketed/braced or looks numeric."""
        if value[:1] + value[-1:] in ("''", '""', "[]", "{}"):
            return True
        unsigned = value[1:] if value.startswith("-") else value
        return unsigned.replace(".", "", 1).isdigit()

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not supported by this parser."""
        raise NotImplementedError("Tool calling is not supported in streaming mode!")


@ToolParserManager.register_module("llama_nemotron_json")
class LlamaNemotronJSONToolParser(ToolParser):
    """Parse JSON tool calls wrapped in ``<TOOLCALL>...</TOOLCALL>``.

    The payload of the first block is decoded as a JSON list of
    ``{"name": ..., "arguments": ...}`` objects. Missing outer brackets are
    added first, so a bare object or comma-separated objects also decode.
    Only non-streaming extraction is implemented.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming state (unused: extract_tool_calls_streaming raises).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from the first ``<TOOLCALL>`` block.

        Args:
            model_output: Full (non-streamed) model completion text.
            request: The originating request (not consulted by this parser).

        Returns:
            ``ExtractedToolCallInformation`` whose ``content`` is the stripped
            text before the first ``<TOOLCALL>`` (``None`` if empty). Entries
            missing ``name``/``arguments`` or failing validation are skipped.
            When the tag is absent, unterminated, or the payload is not valid
            JSON, the whole output is returned as content with no tool calls.
        """
        tool_call_start_index = model_output.find(self.tool_call_start_token)
        if tool_call_start_index == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            tool_call_matches = self.tool_call_regex.findall(model_output)
            if not tool_call_matches:
                # Opening tag without a closing tag (e.g. truncated output).
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            str_tool_calls = tool_call_matches[0].strip()
            # Tolerate a bare object / comma-separated objects by wrapping the
            # payload in a JSON list. The closing bracket must be APPENDED.
            if not str_tool_calls.startswith("["):
                str_tool_calls = "[" + str_tool_calls
            if not str_tool_calls.endswith("]"):
                str_tool_calls = str_tool_calls + "]"
            json_tool_calls = json.loads(str_tool_calls)

            tool_calls = []
            for tool_call in json_tool_calls:
                try:
                    arguments = tool_call["arguments"]
                    tool_calls.append(ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tool_call["name"],
                            arguments=json.dumps(arguments, ensure_ascii=False)
                            if isinstance(arguments, dict) else arguments,
                        ),
                    ))
                except (KeyError, TypeError, ValueError):
                    # KeyError: missing field; TypeError: entry is not a dict;
                    # ValueError: model validation of the field values failed.
                    logger.warning("Skipping malformed tool call entry: %s", tool_call)
                    continue

            # Cut at the same (first) tag whose block was parsed above.
            content = model_output[:tool_call_start_index].strip()

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting tool call from response. Response: %s",
                model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not supported by this parser."""
        raise NotImplementedError("Tool calling is not supported in streaming mode!")


@ToolParserManager.register_module("llama_nemotron_pythonic")
class LlamaNemotronPythonicToolParser(ToolParser):
    """Parse Python-call-style tool calls wrapped in ``<TOOLCALL>...</TOOLCALL>``.

    Each non-blank line of the first block is expected to be a single call
    such as ``get_weather(city="Paris", days=3)``. Arguments are parsed with
    ``ast`` (never executed) and then coerced to the JSON-schema types
    declared in ``request.tools`` when available. Only non-streaming
    extraction is implemented.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming state (unused: extract_tool_calls_streaming raises).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
        # Matches one pythonic call per line: function_name(arg1="value1", arg2=123)
        self.function_call_regex = re.compile(r"(\w+)\((.*?)\)$", re.DOTALL)

    def parse_function_arguments(self, args_str: str) -> dict:
        """Parse the text between a call's parentheses into a dict.

        *args_str* is wrapped in a dummy call and parsed with ``ast.parse``;
        nothing is executed. Keyword arguments are keyed by name, positional
        arguments as ``arg_0``, ``arg_1``, ...; ``**kwargs`` entries are
        ignored. Each value goes through ``_ast_node_to_value``.

        Returns:
            The argument dict, or ``{}`` if *args_str* is blank or does not
            parse as a single call's argument list.
        """
        if not args_str.strip():
            return {}

        try:
            parsed = ast.parse(f"dummy_func({args_str})", mode='eval')
        except (SyntaxError, ValueError, RecursionError) as e:
            logger.warning("Failed to parse function arguments '%s': %s", args_str, e)
            return {}

        call_node = parsed.body
        if not isinstance(call_node, ast.Call):
            # e.g. args_str closed the dummy call early and appended more code.
            return {}

        arguments = {
            keyword.arg: self._ast_node_to_value(keyword.value)
            for keyword in call_node.keywords
            if keyword.arg is not None  # skip **kwargs unpacking
        }
        for i, arg in enumerate(call_node.args):
            arguments[f"arg_{i}"] = self._ast_node_to_value(arg)
        return arguments

    @staticmethod
    def _ast_node_to_value(node: ast.AST):
        """Convert an argument AST node to a Python value.

        Literals are evaluated with ``ast.literal_eval``; a bare name becomes
        its identifier string; anything else becomes its source text via
        ``ast.unparse``.
        """
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError):
            if isinstance(node, ast.Name):
                return node.id
            if isinstance(node, ast.Constant):
                return node.value
            return ast.unparse(node)

    @staticmethod
    def _get_schema_properties(request: ChatCompletionRequest, function_name: str) -> dict:
        """Return ``parameters["properties"]`` of the first tool named *function_name*.

        Returns ``{}`` when the request has no tools, no tool matches, or the
        matching tool's parameters/properties are not dicts.
        """
        for tool_def in request.tools or []:
            if tool_def.function.name == function_name:
                parameters = tool_def.function.parameters
                if isinstance(parameters, dict) and isinstance(parameters.get("properties"), dict):
                    return parameters["properties"]
                return {}
        return {}

    @staticmethod
    def _coerce_value(value, target_type):
        """Coerce *value* to JSON-schema *target_type*.

        Raises ValueError/TypeError/OverflowError when conversion is impossible;
        a non-integral float is refused for "integer" instead of being truncated.
        Unrecognised types, and "object"/"array" strings that are not valid
        JSON, return *value* unchanged.
        """
        if target_type == "string":
            return value if isinstance(value, str) else str(value)
        if target_type == "integer":
            if isinstance(value, int):
                return value
            if isinstance(value, float) and not value.is_integer():
                raise ValueError(f"{value!r} is not an integral number")
            return int(value)
        if target_type == "number":
            return value if isinstance(value, (int, float)) else float(value)
        if target_type == "boolean":
            if isinstance(value, bool):
                return value
            if isinstance(value, str):
                return value.lower() in ['true', '1', 'yes']
            return bool(value)
        if target_type in ["object", "array"] and isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value  # keep as string if it is not JSON
        return value

    @classmethod
    def _coerce_to_schema(cls, arguments: dict, schema_properties: dict) -> dict:
        """Return a copy of *arguments* coerced to the declared property types.

        Arguments without a dict-valued property spec are kept as-is; a failed
        conversion is logged and the original value kept.
        """
        coerced = {}
        for arg_name, arg_value in arguments.items():
            param_info = schema_properties.get(arg_name)
            if not isinstance(param_info, dict):
                coerced[arg_name] = arg_value
                continue
            try:
                coerced[arg_name] = cls._coerce_value(arg_value, param_info.get("type"))
            except (ValueError, TypeError, OverflowError) as e:
                logger.warning("Type conversion failed for %s: %s", arg_name, e)
                coerced[arg_name] = arg_value
        return coerced

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract one tool call per line of the first ``<TOOLCALL>`` block.

        Args:
            model_output: Full (non-streamed) model completion text.
            request: The originating request; ``request.tools`` schemas, when
                present, drive argument type coercion.

        Returns:
            ``ExtractedToolCallInformation`` whose ``content`` is the stripped
            text before the first ``<TOOLCALL>`` (``None`` if empty). Lines
            that are not ``name(...)`` calls are skipped with a warning. When
            the tag is absent or unterminated, or parsing raises unexpectedly,
            the whole output is returned as content with no tool calls.
        """
        tool_call_start_index = model_output.find(self.tool_call_start_token)
        if tool_call_start_index == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        content = model_output[:tool_call_start_index].strip()

        try:
            tool_call_matches = self.tool_call_regex.findall(model_output)
            if not tool_call_matches:
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            tool_calls_content = tool_call_matches[0].strip()
            function_lines = [
                stripped for stripped in (line.strip() for line in tool_calls_content.split('\n'))
                if stripped
            ]

            parsed_tool_calls = []
            for func_line in function_lines:
                match = self.function_call_regex.match(func_line)
                if not match:
                    logger.warning("Could not parse function call: %s", func_line)
                    continue

                function_name, args_str = match.group(1), match.group(2)
                parsed_arguments = self._coerce_to_schema(
                    self.parse_function_arguments(args_str),
                    self._get_schema_properties(request, function_name),
                )

                parsed_tool_calls.append(ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(parsed_arguments, ensure_ascii=False),
                    ),
                ))

            return ExtractedToolCallInformation(
                tools_called=len(parsed_tool_calls) > 0,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting pythonic tool call from response. Response: %s",
                model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is not supported by this parser."""
        raise NotImplementedError("Tool calling is not supported in streaming mode!")