import ast
import json
import logging
import math
import re
from typing import Any, List, Optional

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
    StreamingParseResult,
    ToolCallItem,
    _GetInfoFunc,
)

logger = logging.getLogger(__name__)


class Qwen3CoderDetector(BaseFormatDetector):
    """Detect and parse XML-style tool calls.

    The format handled here is::

        <tool_call>
        <function=NAME>
        <parameter=KEY>
        VALUE
        </parameter>
        </function>
        </tool_call>

    Two entry points are provided: ``detect_and_parse`` for a complete
    response and ``parse_streaming_increment`` for incremental chunks.
    Parameter values are raw text; they are converted to JSON-compatible
    Python values using the ``type`` declared in the tool's schema.
    """

    def __init__(self):
        super().__init__()

        # Sentinel tokens.
        # NOTE(review): verify these literals and the regexes below against the
        # model's chat template; every parsing path in this class keys off them.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"

        # Regexes used by the one-shot (non-streaming) parser.
        # One capture group: findall() yields the text inside each tool call.
        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # Two capture groups: a properly closed function, or an unterminated
        # one that runs to the end of the text.
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        # A parameter ends at its closing tag, at the next parameter, at the
        # end of the function, or at the end of the text.
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
            re.DOTALL,
        )

        # Streaming state.
        # The base class already initializes ``_buffer``; it is used directly.

        # Index of the next unprocessed character in ``_buffer``.
        self.parsed_pos: int = 0
        # Number of parameters already emitted for the current tool; decides
        # whether the next JSON fragment needs a leading comma.
        self.current_tool_param_count: int = 0
        # Whether the opening '{' has been emitted for the current tool.
        self.json_started: bool = False
        # Whether the cursor is between the tool-call start and end tokens;
        # text seen there is structural whitespace and is not surfaced.
        self.is_inside_tool_call: bool = False
        # Name of the function whose parameters are currently being streamed.
        self.current_func_name: Optional[str] = None

    def has_tool_call(self, text: str) -> bool:
        """Return True if ``text`` contains the tool-call start token."""
        return self.tool_call_start_token in text

    @staticmethod
    def _strip_wrapping_newlines(value: str) -> str:
        """Drop at most one leading and one trailing newline from ``value``."""
        if value.startswith("\n"):
            value = value[1:]
        if value.endswith("\n"):
            value = value[:-1]
        return value

    def _get_arguments_config(
        self, func_name: str, tools: Optional[list[Tool]]
    ) -> dict:
        """Extract argument configuration for a function.

        Args:
            func_name: Name of the function being called.
            tools: Tool definitions to search; ``None`` yields ``{}``.

        Returns:
            The ``properties`` mapping of the matching tool's parameters when
            present, the parameters dict itself when it has no ``properties``
            key, and ``{}`` otherwise (including when no tool matches, which
            also logs a warning).
        """
        if tools is None:
            return {}

        for config in tools:
            # Entries that do not look like a function tool are skipped.
            try:
                config_type = config.type
                config_function = config.function
                config_function_name = config_function.name
            except AttributeError:
                continue

            if config_type != "function" or config_function_name != func_name:
                continue

            try:
                params = config_function.parameters
            except AttributeError:
                return {}

            if not isinstance(params, dict):
                return {}
            if "properties" in params:
                return params["properties"]
            return params

        logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
        return {}

    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        """Convert parameter value based on its type in the schema.

        Args:
            param_value: Raw text of the parameter.
            param_name: Parameter name as written by the model.
            param_config: Mapping of parameter name to its schema (as returned
                by ``_get_arguments_config``).
            func_name: Function name, used only in log messages.

        Returns:
            ``None`` for the literal ``null`` (case-insensitive); otherwise the
            value converted according to the schema ``type``. Whenever a
            conversion is not possible the raw string is returned, except for
            booleans, which fall back to ``False``.
        """
        # Handle null value for any type
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            # Only warn when a schema exists but lacks this parameter.
            if param_config != {}:
                logger.warning(
                    f"Parsed parameter '{param_name}' is not defined in the tool "
                    f"parameters for tool '{func_name}', directly returning the string value."
                )
            return param_value

        schema = param_config[param_name]
        if isinstance(schema, dict) and "type" in schema:
            param_type = str(schema["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value

        if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except ValueError:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

        if param_type.startswith(("num", "float")):
            try:
                number = float(param_value)
            except ValueError:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

            # NaN / Infinity are accepted by float() but json.dumps() renders
            # them as bare `NaN` / `Infinity`, which is not valid JSON. Treat
            # them like any other non-numeric text.
            if not math.isfinite(number):
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a finite float in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

            # A literal written without '.' or an exponent is returned as an
            # int when it is integral.
            written_as_integer = (
                "." not in param_value and "e" not in param_value.lower()
            )
            if written_as_integer and number.is_integer():
                try:
                    # Parse the literal directly: going through float would
                    # lose precision for integers above 2**53.
                    return int(param_value)
                except ValueError:
                    return int(number)
            return number

        if param_type in ["boolean", "bool", "binary"]:
            lowered = param_value.lower()
            if lowered not in ["true", "false"]:
                logger.warning(
                    f"Parsed value '{lowered}' of parameter '{param_name}' is not a boolean (`true` or `false`) in tool '{func_name}', degenerating to false."
                )
            return lowered == "true"

        if (
            param_type in ["object", "array", "arr"]
            or param_type.startswith("dict")
            or param_type.startswith("list")
        ):
            # Best effort: any failure falls through to literal_eval below.
            try:
                return json.loads(param_value)
            except Exception:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
                    f"'{func_name}', will try other methods to parse it."
                )

        # Unknown types, and containers that were not valid JSON, get one more
        # chance as a Python literal. literal_eval only evaluates literals, so
        # this is safe on model output; any failure keeps the raw string.
        try:
            return ast.literal_eval(param_value)
        except Exception:
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
            )
            return param_value

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """One-shot parsing for non-streaming scenarios.

        Args:
            text: Complete model output.
            tools: Tool definitions used to type-convert parameter values.

        Returns:
            A result whose ``calls`` holds one ``ToolCallItem`` per parsed
            function (``parameters`` is a JSON object string) and whose
            ``normal_text`` is the text preceding the first tool call. If the
            start token is absent, or parsing raises, the whole ``text`` is
            returned as ``normal_text`` with no calls.
        """
        if self.tool_call_start_token not in text:
            return StreamingParseResult(normal_text=text)

        calls = []
        try:
            raw_tool_calls = self.tool_call_regex.findall(text)
            if not raw_tool_calls and self.tool_call_prefix in text:
                # Fallback: no closed tool-call pair was found (e.g. the end
                # token is missing), so scan the whole text for functions.
                raw_tool_calls = [text]

            tool_idx = 0
            for tool_content in raw_tool_calls:
                for func_match in self.tool_call_function_regex.findall(tool_content):
                    # Group 0: closed function; group 1: unterminated function.
                    func_body = func_match[0] or func_match[1]
                    if ">" not in func_body:
                        continue

                    name_end = func_body.index(">")
                    func_name = func_body[:name_end]
                    params_str = func_body[name_end + 1 :]

                    param_config = self._get_arguments_config(func_name, tools)
                    parsed_params = {}

                    for p_match in self.tool_call_parameter_regex.findall(params_str):
                        if ">" not in p_match:
                            continue
                        p_idx = p_match.index(">")
                        p_name = p_match[:p_idx]
                        p_val = self._strip_wrapping_newlines(p_match[p_idx + 1 :])

                        parsed_params[p_name] = self._convert_param_value(
                            p_val, p_name, param_config, func_name
                        )

                    calls.append(
                        ToolCallItem(
                            tool_index=tool_idx,
                            name=func_name,
                            parameters=json.dumps(parsed_params, ensure_ascii=False),
                        )
                    )
                    tool_idx += 1

            # Normal text is whatever precedes the first tool call. The guard
            # at the top guarantees the start token is present, so the prefix
            # lookup is only a defensive fallback.
            start_idx = text.find(self.tool_call_start_token)
            if start_idx == -1:
                start_idx = text.find(self.tool_call_prefix)
            normal_text = text[:start_idx] if start_idx > 0 else ""

            return StreamingParseResult(normal_text=normal_text, calls=calls)

        except Exception as e:
            # Top-level boundary: never let a malformed response crash the
            # caller; surface the raw text instead.
            logger.error(f"Error in detect_and_parse: {e}")
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """Cursor-based streaming parser.

        ``new_text`` is appended to the internal buffer, then the buffer is
        consumed tag by tag from ``parsed_pos``. Parsing stops as soon as the
        remaining text is an incomplete tag or parameter, and resumes on the
        next call.

        Args:
            new_text: Newly generated text chunk.
            tools: Tool definitions used to type-convert parameter values.

        Returns:
            A result whose ``calls`` carries, in order: one item with the
            function name (empty ``parameters``) per function start, then
            argument fragments ("{", ``"key": value`` pieces, "}") that
            concatenate to a JSON object; ``normal_text`` carries text found
            outside tool calls.
        """
        self._buffer += new_text

        # Guard against empty buffer
        if not self._buffer:
            return StreamingParseResult()

        calls: List[ToolCallItem] = []
        normal_text_chunks: List[str] = []

        # Tags a dangling "<..." fragment could still grow into.
        known_tags = (
            self.tool_call_start_token,
            self.tool_call_end_token,
            self.tool_call_prefix,
            self.function_end_token,
            self.parameter_prefix,
            self.parameter_end_token,
        )

        while True:
            # Unprocessed remainder of the buffer.
            current_slice = self._buffer[self.parsed_pos :]
            if not current_slice:
                break

            # -------------------------------------------------------
            # 1. Tool call start: <tool_call>
            # -------------------------------------------------------
            if current_slice.startswith(self.tool_call_start_token):
                self.parsed_pos += len(self.tool_call_start_token)
                self.is_inside_tool_call = True
                continue

            # -------------------------------------------------------
            # 2. Function name: <function=name>
            # -------------------------------------------------------
            if current_slice.startswith(self.tool_call_prefix):
                end_angle = current_slice.find(">")
                if end_angle == -1:
                    # Incomplete tag; wait for more text.
                    break

                func_name = current_slice[len(self.tool_call_prefix) : end_angle]

                # Start a new tool and reset the per-tool JSON state.
                self.current_tool_id += 1
                self.current_tool_name_sent = True
                self.current_tool_param_count = 0
                self.json_started = False
                self.current_func_name = func_name

                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=func_name,
                        parameters="",
                    )
                )
                self.parsed_pos += end_angle + 1
                continue

            # -------------------------------------------------------
            # 3. Parameter: <parameter=name>value...
            # -------------------------------------------------------
            if current_slice.startswith(self.parameter_prefix):
                name_end = current_slice.find(">")
                if name_end != -1:
                    value_start_idx = name_end + 1
                    rest_of_slice = current_slice[value_start_idx:]

                    # A parameter value ends at whichever comes first:
                    #   - its closing tag (consumed together with the value),
                    #   - the next parameter tag (left in place), or
                    #   - the function end tag (left in place).
                    candidates = []
                    for terminator, consumed in (
                        (self.parameter_end_token, len(self.parameter_end_token)),
                        (self.parameter_prefix, 0),
                        (self.function_end_token, 0),
                    ):
                        found_at = rest_of_slice.find(terminator)
                        if found_at != -1:
                            candidates.append((found_at, consumed))

                    if candidates:
                        end_pos, end_token_len = min(candidates, key=lambda c: c[0])

                        param_name = current_slice[
                            len(self.parameter_prefix) : name_end
                        ]
                        raw_value = self._strip_wrapping_newlines(
                            rest_of_slice[:end_pos]
                        )

                        # Open the JSON object before the first parameter.
                        if not self.json_started:
                            calls.append(
                                ToolCallItem(
                                    tool_index=self.current_tool_id, parameters="{"
                                )
                            )
                            self.json_started = True

                        param_config = self._get_arguments_config(
                            self.current_func_name, tools
                        )
                        converted_val = self._convert_param_value(
                            raw_value, param_name, param_config, self.current_func_name
                        )

                        # Each fragment is a complete `"key": value` pair so the
                        # concatenation of all fragments stays valid JSON.
                        json_key_val = f"{json.dumps(param_name)}: {json.dumps(converted_val, ensure_ascii=False)}"
                        if self.current_tool_param_count > 0:
                            fragment = f", {json_key_val}"
                        else:
                            fragment = json_key_val

                        calls.append(
                            ToolCallItem(
                                tool_index=self.current_tool_id, parameters=fragment
                            )
                        )
                        self.current_tool_param_count += 1

                        # Advance past the tag, the value and (if any) the
                        # consumed closing tag.
                        self.parsed_pos += (name_end + 1) + end_pos + end_token_len
                        continue

                # Incomplete parameter tag or value; wait for more text.
                break

            # -------------------------------------------------------
            # 4. Function end: </function>
            # -------------------------------------------------------
            if current_slice.startswith(self.function_end_token):
                # A function without parameters still yields "{}".
                if not self.json_started:
                    calls.append(
                        ToolCallItem(tool_index=self.current_tool_id, parameters="{")
                    )
                    self.json_started = True

                calls.append(
                    ToolCallItem(tool_index=self.current_tool_id, parameters="}")
                )
                self.parsed_pos += len(self.function_end_token)
                self.current_func_name = None
                continue

            # -------------------------------------------------------
            # 5. Tool call end: </tool_call>
            # -------------------------------------------------------
            if current_slice.startswith(self.tool_call_end_token):
                self.parsed_pos += len(self.tool_call_end_token)
                self.is_inside_tool_call = False
                continue

            # -------------------------------------------------------
            # 6. Handling content / whitespace / normal text
            # -------------------------------------------------------
            # Nothing above matched, so the slice starts with plain text, with
            # whitespace between two tags, or with a '<' that is either a
            # truncated known tag (e.g. "<fun") or an ordinary character.
            # Truncated tags must not be emitted as text.
            # NOTE(review): confirm this fallthrough against the upstream
            # implementation.
            next_open_angle = current_slice.find("<")

            if next_open_angle == -1:
                # The whole remainder is text. Inside a tool call it is only
                # structural whitespace and is dropped.
                if not self.is_inside_tool_call:
                    normal_text_chunks.append(current_slice)
                self.parsed_pos += len(current_slice)
                continue

            if next_open_angle == 0:
                if any(tag.startswith(current_slice) for tag in known_tags):
                    # Could still become a known tag; wait for more text.
                    break
                # Just a literal '<'.
                if not self.is_inside_tool_call:
                    normal_text_chunks.append("<")
                self.parsed_pos += 1
                continue

            # Text followed by a '<': consume only the text and let the next
            # iteration classify the '<'.
            if not self.is_inside_tool_call:
                normal_text_chunks.append(current_slice[:next_open_angle])
            self.parsed_pos += next_open_angle

        # Drop the consumed prefix so the buffer does not grow without bound.
        if self.parsed_pos > 0:
            self._buffer = self._buffer[self.parsed_pos :]
            self.parsed_pos = 0

        normal_text = "".join(normal_text_chunks) if normal_text_chunks else ""
        return StreamingParseResult(calls=calls, normal_text=normal_text)

    def supports_structural_tag(self) -> bool:
        """Return False: structural tags are not supported for this format."""
        return False

    def structure_info(self) -> _GetInfoFunc:
        """Not implemented for this format.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError