Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions cecli/tools/edit_text.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json

from cecli.helpers import responses
from cecli.helpers.hashline import (
ContentHashError,
apply_hashline_operations,
Expand All @@ -12,6 +11,7 @@
apply_change,
format_tool_result,
handle_tool_error,
normalize_json_array,
validate_file_for_edit,
)
from cecli.tools.utils.output import color_markers, tool_footer, tool_header
Expand Down Expand Up @@ -96,6 +96,23 @@ class Tool(BaseTool):
},
}

@classmethod
def _coerce_edits(cls, edits) -> list[dict]:
"""Normalize ``edits`` for display (local models often double-encode arrays)."""
try:
normalized = normalize_json_array(edits, param_name="edits", allow_empty=True)
except ToolError:
return []
out: list[dict] = []
for item in normalized:
if isinstance(item, dict):
out.append(item)
elif isinstance(item, str):
parsed = responses.try_parse_json_value(item)
if isinstance(parsed, dict):
out.append(parsed)
return out

@classmethod
def execute(
cls,
Expand All @@ -117,16 +134,18 @@ def execute(

tool_name = "EditText"
try:
# 1. Validate edits parameter
if not isinstance(edits, list):
raise ToolError("edits parameter must be an array")
edits = normalize_json_array(edits, param_name="edits")

if len(edits) == 0:
raise ToolError("edits array cannot be empty")

# 2. Group edits by file_path
edits_by_file = {}
for i, edit in enumerate(edits):
if not isinstance(edit, dict):
raise ToolError(
f"Edit {i + 1} must be an object, got {type(edit).__name__}"
)
edit_file_path = edit.get("file_path")
if edit_file_path is None:
raise ToolError(f"Edit {i + 1} missing required file_path parameter")
Expand Down Expand Up @@ -370,17 +389,20 @@ def execute(
def format_output(cls, coder, mcp_server, tool_response):
color_start, color_end = color_markers(coder)

try:
params = json.loads(tool_response.function.arguments)
except json.JSONDecodeError:
coder.io.tool_error("Invalid Tool JSON")
params = responses.parse_tool_arguments(tool_response.function.arguments or "")
if "@error" in params:
coder.io.tool_error(f"Invalid Tool JSON: {params['@error']}")
return

tool_header(coder=coder, mcp_server=mcp_server, tool_response=tool_response)

# Group edits by file_path for display
edits_by_file = {}
edits = cls._coerce_edits(params.get("edits", []))

for i, edit in enumerate(params.get("edits", [])):
for i, edit in enumerate(edits):
if not isinstance(edit, dict):
continue
edit_file_path = edit.get("file_path")
if edit_file_path not in edits_by_file:
edits_by_file[edit_file_path] = []
Expand All @@ -397,7 +419,7 @@ def format_output(cls, coder, mcp_server, tool_response):
for edit_index, edit in file_edits:
operation = edit.get("operation", "replace")

if len(params.get("edits", [])) > 1:
if len(edits) > 1:
coder.io.tool_output(
f"{color_start}{OPERATION_NOUNS[operation]}_{edit_index + 1}:{color_end}"
)
Expand Down
13 changes: 8 additions & 5 deletions cecli/tools/utils/base_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod

from cecli.tools.utils.helpers import handle_tool_error, normalize_json_array
from cecli.tools.utils.helpers import ToolError, handle_tool_error, normalize_json_array
from cecli.tools.utils.output import print_tool_response


Expand Down Expand Up @@ -82,6 +82,10 @@ def process_response(cls, coder, params):
)
return handle_tool_error(coder, tool_name, ValueError(error_msg))

for param in cls.LIST_PARAMS:
if param in params:
params[param] = normalize_json_array(params[param], param_name=param)

# Check for repeated invocations if TRACK_INVOCATIONS is enabled
if cls.TRACK_INVOCATIONS:
tool_name = None
Expand Down Expand Up @@ -122,10 +126,6 @@ def process_response(cls, coder, params):
coder, tool_name, ValueError(error_msg), add_traceback=False
)

for param in cls.LIST_PARAMS:
if param in params:
params[param] = normalize_json_array(params[param], param_name=param)

# Add current invocation to history (keeping only last 3)
if params:
cls._invocations[tool_name].append((current_params_tuple, params))
Expand All @@ -134,6 +134,9 @@ def process_response(cls, coder, params):

try:
return cls.execute(coder, **params)
except ToolError as e:
tool_name = (cls.SCHEMA or {}).get("function", {}).get("name", cls.__name__)
return handle_tool_error(coder, tool_name, e, add_traceback=False)
except Exception as e:
return handle_tool_error(coder, cls.SCHEMA.get("function").get("name"), e)

Expand Down
101 changes: 101 additions & 0 deletions tests/tools/test_edit_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""EditText tool — double-encoded edits and format_output safety."""

import json
from types import SimpleNamespace
from unittest.mock import Mock

from cecli.tools import edit_text
from cecli.tools.utils.base_tool import BaseTool


class DummyIO:
def __init__(self):
self.tool_output = Mock()
self.tool_error = Mock()
self.tool_warning = Mock()


class DummyCoder:
def __init__(self):
self.io = DummyIO()
self.pretty = False
self.verbose = False


class _NoTrackTool(BaseTool):
"""Minimal tool mirroring EditText LIST_PARAMS + TRACK_INVOCATIONS=False."""

NORM_NAME = "notrack"
TRACK_INVOCATIONS = False
LIST_PARAMS = ["edits"]
SCHEMA = {
"function": {
"name": "NoTrack",
"parameters": {
"type": "object",
"properties": {
"edits": {"type": "array"},
},
"required": ["edits"],
},
}
}

@classmethod
def execute(cls, coder, edits=None, **kwargs):
if not isinstance(edits, list):
return f"edits type={type(edits).__name__}"
return f"edits len={len(edits)}"


def test_list_params_normalized_when_track_invocations_disabled():
coder = DummyCoder()
edits_json = json.dumps(
[{"file_path": "pubspec.yaml", "operation": "replace", "start_line": "@000", "end_line": "@000"}]
)
result = _NoTrackTool.process_response(coder, {"edits": edits_json})
assert result == "edits len=1"


def test_format_output_accepts_edits_as_json_string():
coder = DummyCoder()
edits_json = json.dumps(
[
{
"file_path": "pubspec.yaml",
"operation": "replace",
"start_line": "@000",
"end_line": "@000",
"text": "name: demo",
}
]
)
args = json.dumps({"edits": edits_json})
tool_response = SimpleNamespace(function=SimpleNamespace(name="EditText", arguments=args))

edit_text.Tool.format_output(
coder,
mcp_server=SimpleNamespace(name="test"),
tool_response=tool_response,
)

output_text = "\n".join(call.args[0] for call in coder.io.tool_output.call_args_list)
assert "pubspec.yaml" in output_text
coder.io.tool_error.assert_not_called()


def test_format_output_string_edits_does_not_crash():
"""Regression: iterating a JSON string used to raise AttributeError in format_output."""
coder = DummyCoder()
edits_json = json.dumps([{"file_path": "a.txt", "operation": "replace"}])
tool_response = SimpleNamespace(
function=SimpleNamespace(name="EditText", arguments=json.dumps({"edits": edits_json}))
)

edit_text.Tool.format_output(
coder,
mcp_server=SimpleNamespace(name="test"),
tool_response=tool_response,
)

coder.io.tool_error.assert_not_called()
Loading