diff --git a/cecli/args.py b/cecli/args.py index eb05b51231e..83286782492 100644 --- a/cecli/args.py +++ b/cecli/args.py @@ -276,7 +276,10 @@ def get_parser(default_config_files, git_root): group.add_argument( "--retries", metavar="RETRIES_JSON", - help="Specify LLM retry configuration as a JSON string", + help=( + 'Specify LLM retry configuration as a JSON/YAML string (e.g., \'{"retry_on_empty": ' + "true}')" + ), default=None, ) diff --git a/cecli/args_formatter.py b/cecli/args_formatter.py index 01b9bc94094..aaa9463c3b3 100644 --- a/cecli/args_formatter.py +++ b/cecli/args_formatter.py @@ -132,6 +132,16 @@ def _format_action(self, action): break switch = switch.lstrip("-") + if switch == "retries": + parts.append(f"## {action.help}") + parts.append("#retries:") + parts.append("# retry-timeout: 60") + parts.append("# retry-backoff-factor: 2.0") + parts.append("# retry-on-unavailable: true") + parts.append("# retry-on-empty: false") + parts.append("") + return "\n".join(parts) + if isinstance(action, argparse._StoreTrueAction): default = False elif isinstance(action, argparse._StoreConstAction): diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index d57cfad9a9d..33a5d148c82 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -1115,7 +1115,7 @@ def _generate_tool_context(self, repetitive_tools): context_parts.append("## File Editing Tools Disabled") context_parts.append( "File editing tools are currently disabled. Use `ReadRange` to determine the" - " current content hash prefixes needed to perform an edit and activate them when" + " current content ID prefixes needed to perform an edit and activate them when" " you are ready to edit a file." ) diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index c28dc866cc6..4d5b26e97df 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -46,6 +46,7 @@ from cecli.helpers.io_proxy import IOProxy from cecli.helpers.observations.service import ObservationService from cecli.helpers.profiler import TokenProfiler +from cecli.helpers.threading import ThreadSafeEvent from cecli.history import ChatSummary from cecli.hooks import HookIntegration from cecli.io import ConfirmGroup, InputOutput @@ -91,6 +92,10 @@ class FinishReasonLength(Exception): pass +class EmptyResponseError(Exception): + pass + + def wrap_fence(name): return f"<{name}>", f"" @@ -420,7 +425,7 @@ def __init__( # Each contains "included" and "excluded" sets that filter from the global singletons self.registered_tools = {"included": set(), "excluded": set()} self.registered_servers = {"included": set(), "excluded": set()} - self.interrupt_event = asyncio.Event() + self.interrupt_event = ThreadSafeEvent() self.uuid = str(generate_unique_id()) if uuid: @@ -1643,6 +1648,7 @@ async def output_task(self, preproc): async def generate(self, user_message, preproc): await asyncio.sleep(0.1) + self.interrupt_event.clear() try: if self.enable_context_compaction: @@ -2402,6 +2408,39 @@ async def format_in_executor(): async for chunk in self.send(messages, tools=self.get_tool_list()): yield chunk break + except EmptyResponseError: + self.io.tool_warning(self.empty_llm_tool_warning()) + + retry_on_empty = False + retries_config = self.get_active_model().retries + if isinstance(retries_config, str): + try: + retries_config = json.loads(retries_config) + except json.JSONDecodeError: + self.io.tool_warning( + f"Could not parse retries config: {retries_config}" + ) + retries_config = {} + if isinstance(retries_config, dict): + retry_on_empty = retries_config.get("retry_on_empty", False) + + if not retry_on_empty: + break + + retry_delay *= 2 + if retry_delay > RETRY_TIMEOUT: + self.io.tool_error("Retry timeout exceeded on empty response.") + break + + self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...") + + _res, interrupted_sleep = await coroutines.interruptible( + asyncio.sleep(retry_delay), self.interrupt_event + ) + if interrupted_sleep: + interrupted = True + break + continue except litellm_ex.exceptions_tuple() as err: ex_info = litellm_ex.get_ex_info(err) @@ -3252,6 +3291,7 @@ async def send(self, messages, model=None, functions=None, tools=None): self.interrupt_event.clear() self.got_reasoning_content = False self.ended_reasoning_content = False + self.empty_response = False self._streaming_buffer_length = 0 self.io.reset_streaming_response() @@ -3302,6 +3342,9 @@ async def send(self, messages, model=None, functions=None, tools=None): else: await self.show_send_output(completion) + if self.empty_response: + raise EmptyResponseError + response, func_err, content_err = self.consolidate_chunks() if response: @@ -3382,7 +3425,8 @@ async def show_send_output(self, completion): and not len(self.partial_response_tool_calls) and not len(self.partial_response_reasoning_content) ): - self.io.tool_warning(self.empty_llm_tool_warning()) + self.empty_response = True + return self.io.assistant_output(show_resp, pretty=self.show_pretty()) @@ -3539,7 +3583,8 @@ async def show_send_output_stream(self, completion): return if not received_content and len(self.partial_response_tool_calls) == 0: - self.io.tool_warning(self.empty_llm_tool_warning()) + self.empty_response = True + return def consolidate_chunks(self): if self.partial_response_consolidated: diff --git a/cecli/commands/core.py b/cecli/commands/core.py index 2ad884fabd3..5242b73397a 100644 --- a/cecli/commands/core.py +++ b/cecli/commands/core.py @@ -1,4 +1,3 @@ -import asyncio import json import re import sys @@ -7,6 +6,7 @@ from cecli.commands.utils.registry import CommandRegistry from cecli.helpers import nested, plugin_manager from cecli.helpers.file_searcher import handle_core_files +from cecli.helpers.threading import ThreadSafeEvent from cecli.repo import ANY_GIT_ERROR @@ -94,7 +94,7 @@ def __init__( self.custom_commands = nested.getter(customizations, "command-paths", []) self._load_custom_commands(self.custom_commands) - self.cmd_running_event = asyncio.Event() + self.cmd_running_event = ThreadSafeEvent() self.cmd_running_event.set() self.last_command_show_notification = True diff --git a/cecli/helpers/conversation/files.py b/cecli/helpers/conversation/files.py index 50a7a3de239..3acfdee9c52 100644 --- a/cecli/helpers/conversation/files.py +++ b/cecli/helpers/conversation/files.py @@ -281,14 +281,17 @@ def update_file_diff(self, fname: str) -> Optional[str]: diff_message = { "role": "user", "content": ( - f"{rel_fname} has been updated. Here is a git diff of the changes to" - f" review:\n\n{diff}" + f"{rel_fname} has been updated. Review this git diff of the changes to" + f" ensure the modifications are intended:\n\n{diff}" ), } assistant_msg = { "role": "assistant", - "content": f"Thank you for sharing this diff of the updates to {rel_fname}.", + "content": ( + f"Thank you for sharing this diff of the updates to {rel_fname}." + " I will review their contents next turn." + ), } ConversationService.get_manager(coder).add_message( diff --git a/cecli/helpers/conversation/integration.py b/cecli/helpers/conversation/integration.py index f649689323d..5e59b7dc6f5 100644 --- a/cecli/helpers/conversation/integration.py +++ b/cecli/helpers/conversation/integration.py @@ -842,7 +842,7 @@ def add_file_context_messages(self, promote_messages=True) -> None: user_msg = { "role": "user", - "content": f"Hash-Prefixed Context For:\n{rel_fname}\n\n{context_content}", + "content": f"ID-Prefixed Context For:\n{rel_fname}\n\n{context_content}", } assistant_msg = { diff --git a/cecli/helpers/coroutines.py b/cecli/helpers/coroutines.py index 3bab125348f..07e27314a5e 100644 --- a/cecli/helpers/coroutines.py +++ b/cecli/helpers/coroutines.py @@ -1,5 +1,7 @@ import asyncio +from cecli.helpers.threading import ThreadSafeEvent + async def interruptible_async_generator(async_generator, interrupt_event): """ @@ -57,7 +59,7 @@ async def interruptible(coroutine, interrupt_event): - If interrupted: (None, True) """ if interrupt_event is None: - interrupt_event = asyncio.Event() + interrupt_event = ThreadSafeEvent() main_task = asyncio.create_task(coroutine) interrupt_task = asyncio.create_task(interrupt_event.wait()) diff --git a/cecli/helpers/hashline.py b/cecli/helpers/hashline.py index 110c8ebab25..cba3057787a 100644 --- a/cecli/helpers/hashline.py +++ b/cecli/helpers/hashline.py @@ -650,6 +650,20 @@ def _apply_start_stitching( # The replacement line matches the line being replaced # Don't stitch to a line in lines_before_range continue + + # Require 2 consecutive matching lines to avoid false positives + # (single boilerplate lines like "import sys" or "def foo():" + # are too likely to be coincidental) + if line_idx + 1 < len(replacement_lines) and match_index + 1 < len( + lines_before_range_normalized + ): + next_repl = replacement_lines[line_idx + 1] + next_repl_stripped = strip_hashline(next_repl) + if not next_repl_stripped.endswith("\n"): + next_repl_stripped += "\n" + if next_repl_stripped != lines_before_range_normalized[match_index + 1]: + continue # Only 1 line matches — likely coincidental + # Found a line that already exists before the range! # This is a non-contiguous match - we need to "stitch" the replacement # at this exact content match to prevent duplicate code structures @@ -694,9 +708,9 @@ def _apply_start_stitching( start_idx = new_start_idx replacement_lines = new_replacement_lines else: - # Can't extend backward due to overlap, but we can still truncate - # the replacement text to avoid duplication - replacement_lines = new_replacement_lines + # Can't extend backward due to overlap with another operation + # Don't truncate without extending — that would silently lose content + continue # Try next line instead # We've found our stitching point, break out of the loop break @@ -772,6 +786,15 @@ def _apply_end_stitching( # Check if this line exists anywhere in lines_after_range_normalized try: match_index = lines_after_range_normalized.index(replacement_line_stripped) + + # Require 2 consecutive matching lines to reduce false positives + if line_idx - 1 >= 0 and match_index - 1 >= 0: + prev_repl = replacement_lines[line_idx - 1] + prev_repl_stripped = strip_hashline(prev_repl) + if not prev_repl_stripped.endswith("\n"): + prev_repl_stripped += "\n" + if prev_repl_stripped != lines_after_range_normalized[match_index - 1]: + continue # Only 1 line matches — likely coincidental # Found a line that already exists after the range! # This is a non-contiguous match - we need to "stitch" the replacement # at this exact content match to prevent duplicate code structures @@ -900,109 +923,6 @@ def _apply_range_shifting(hashed_lines, resolved_ops): return resolved_ops -# Regex configuration -RE_CODE_NOISE = r'(#.*|//.*|/\*[\s\S]*?\*/|"(?:\\.|[^"\\])*"|\'(?:\\.|[^\'\\])*\')' - - -def get_brace_balance(lines_to_check: list[str]) -> int: - """ - Calculates the net curly brace debt of a list of lines. - Automatically strips hashlines, comments, and string literals. - """ - text = "".join(lines_to_check) - clean_code = strip_hashline(text) - clean_code = re.sub(RE_CODE_NOISE, "", clean_code) - return clean_code.count("{") - clean_code.count("}") - - -def _apply_closure_safeguard(hashed_lines, resolved_ops): - """ - Enhanced closure safeguard with dynamic bidirectional search. - """ - # Tune these to adjust how far the 'healing' logic searches - MAX_LOOK_DOWN = 5 - # Note: We'll calculate the actual MAX_LOOK_UP per operation - # to ensure we don't scan past the start_idx. - - for i, resolved in enumerate(resolved_ops): - op = resolved["op"] - if op["operation"] not in {"replace", "delete"}: - continue - - replacement_text = op.get("text", "") or "" - replacement_lines = replacement_text.splitlines(keepends=True) - - # --- PHASE 1: BIDIRECTIONAL STRUCTURAL HEALING --- - if get_brace_balance([replacement_text]) == 0: - start_idx = resolved["start_idx"] - orig_end_idx = resolved["end_idx"] - - if get_brace_balance(hashed_lines[start_idx : orig_end_idx + 1]) != 0: - # Dynamic Search List Generation - # We limit look-up so we don't scan before the start_idx - actual_max_up = orig_end_idx - start_idx - actual_max_down = max(MAX_LOOK_DOWN, orig_end_idx - start_idx) - search_offsets = [] - - # Generate alternating offsets: [1, -1, 2, -2, ... N] - for dist in range(1, max(actual_max_down, actual_max_up) + 1): - if dist <= actual_max_down: - search_offsets.append(dist) - if dist <= actual_max_up: - search_offsets.append(-dist) - - for offset in search_offsets: - candidate_end = orig_end_idx + offset - - # Safety: check bounds and avoid overlapping other ops - if candidate_end < start_idx or candidate_end >= len(hashed_lines): - continue - - if any( - j != i and (other["start_idx"] <= candidate_end <= other["end_idx"]) - for j, other in enumerate(resolved_ops) - ): - continue - - if get_brace_balance(hashed_lines[start_idx : candidate_end + 1]) == 0: - resolved["end_idx"] = candidate_end - break - - # --- PHASE 2: CONTRACTION (Indentation Guard) --- - # Prevents replacing an outer-scope brace if the replacement text already - # includes its own correctly indented closer. - if not replacement_lines: - continue - - last_repl_line = strip_hashline(replacement_lines[-1]) - last_repl_stripped = last_repl_line.strip().rstrip(";,") - - if last_repl_stripped and last_repl_stripped[-1] in "})]": - # Calculate replacement indent - repl_indent = len(last_repl_line) - len(last_repl_line.lstrip(" \t")) - - if resolved["end_idx"] < len(hashed_lines): - end_line = strip_hashline(hashed_lines[resolved["end_idx"]]) - check_end = end_line.strip().rstrip(";,") - - if check_end and check_end[-1] in "})]": - # Calculate indent of the existing brace in the file - file_indent = len(end_line) - len(end_line.lstrip(" \t")) - - # If the file's brace is less indented, it belongs to an outer scope - if file_indent < repl_indent and resolved["end_idx"] > resolved["start_idx"]: - new_end_idx = resolved["end_idx"] - 1 - - # Safety: don't contract into another operation's territory - if not any( - j != i and (other["start_idx"] <= new_end_idx <= other["end_idx"]) - for j, other in enumerate(resolved_ops) - ): - resolved["end_idx"] = new_end_idx - - return resolved_ops - - def _merge_replace_operations(resolved_ops): """ Merge contiguous or overlapping replace operations. @@ -1411,9 +1331,6 @@ def apply_hashline_operations( resolved_ops = _merge_replace_operations(resolved_ops) # Apply content-aware range expansion/shifting for replace operations # resolved_ops = _apply_range_shifting(hashed_lines, resolved_ops) - # Apply closure safeguard for braces/brackets - resolved_ops = _apply_closure_safeguard(hashed_lines, resolved_ops) - # Sort by start_idx descending to apply from bottom to top # When operations have same start_idx, apply in order: insert, replace, delete # This ensures correct behavior when multiple operations target the same line diff --git a/cecli/helpers/hashpos/hashpos.py b/cecli/helpers/hashpos/hashpos.py index dc26801ce26..516052012c9 100644 --- a/cecli/helpers/hashpos/hashpos.py +++ b/cecli/helpers/hashpos/hashpos.py @@ -52,7 +52,7 @@ def generate_private_id(self, text: str) -> str: def generate_public_id(self, text: str, line_idx: int) -> str: """ Generates a 4-char Base64 ID combining modulo buckets and context hash. - Layout: [2-bit b1] [10-bit Hash A] [2-bit b2] [10-bit Hash B] + Layout: [2-bit b1] [2-bit b2] [10-bit Hash A] [10-bit Hash B] """ b1, b2 = self._get_region_bits(line_idx) neighborhood_hash = self._get_neighborhood_hash(line_idx) @@ -62,8 +62,7 @@ def generate_public_id(self, text: str, line_idx: int) -> str: hash_b = neighborhood_hash & 0x3FF # Construct the mixed 24-bit integer - packed = (b1 << 22) | (hash_a << 12) | (b2 << 10) | hash_b - + packed = (b1 << 22) | (b2 << 20) | (hash_a << 10) | hash_b res = "" for _ in range(4): res += self.B64[packed % 64] @@ -79,10 +78,9 @@ def unpack_public_id(self, public_id: str) -> tuple[int, int]: packed |= self.B64.index(char) << (6 * i) b1 = (packed >> 22) & 3 - hash_a = (packed >> 12) & 0x3FF - b2 = (packed >> 10) & 3 + b2 = (packed >> 20) & 3 + hash_a = (packed >> 10) & 0x3FF hash_b = packed & 0x3FF - mod_val = (b1 << 2) | b2 neighborhood_hash = (hash_a << 10) | hash_b @@ -223,6 +221,6 @@ def normalize(hashpos_str: str) -> str: # If no pattern matches, raise error raise ValueError( f"Invalid HashPos format '{hashpos_str}'. " - r"Expected \"{hash_prefix}\" " - r"where hash_prefix is exactly 4 characters from the set [0-9a-zA-Z\~_@]." + r"Expected \"{content ID}\" " + r"where content ID is exactly 4 characters from the set [0-9a-zA-Z\~_@]." ) diff --git a/cecli/helpers/threading.py b/cecli/helpers/threading.py new file mode 100644 index 00000000000..8cd4a70a7d4 --- /dev/null +++ b/cecli/helpers/threading.py @@ -0,0 +1,46 @@ +import asyncio +import threading + + +class ThreadSafeEvent: + def __init__(self): + self._async_event = asyncio.Event() + self._thread_event = threading.Event() + + @staticmethod + def _get_loop(): + """Dynamically resolve the running event loop (not cached).""" + try: + return asyncio.get_running_loop() + except RuntimeError: + return None + + def set(self): + """Can be called from ANY thread or coroutine safely.""" + # Unblock threads + self._thread_event.set() + # Unblock async loop + if loop := self._get_loop(): + loop.call_soon_threadsafe(self._async_event.set) + else: + self._async_event.set() + + def clear(self): + """Can be called from ANY thread or coroutine safely.""" + self._thread_event.clear() + if loop := self._get_loop(): + loop.call_soon_threadsafe(self._async_event.clear) + else: + self._async_event.clear() + + def is_set(self): + """Thread-safe check.""" + return self._thread_event.is_set() + + def thread_wait(self, timeout=None): + """Call this from your background OS Thread.""" + return self._thread_event.wait(timeout=timeout) + + async def wait(self): + """Call this (with await) from your Async Coroutines.""" + await self._async_event.wait() diff --git a/cecli/io.py b/cecli/io.py index 47cbee2eccd..8bf7a3c657e 100644 --- a/cecli/io.py +++ b/cecli/io.py @@ -43,6 +43,7 @@ from cecli.commands import SwitchCoderSignal from cecli.helpers import coroutines +from cecli.helpers.threading import ThreadSafeEvent from cecli.report import update_error_prefix from .dump import dump # noqa: F401 @@ -395,7 +396,7 @@ def __init__( self.linear = False # State tracking for confirmation input - self.confirmation_in_progress_event = asyncio.Event() + self.confirmation_in_progress_event = ThreadSafeEvent() self.confirmation_in_progress_event.set() # Initially set, meaning no confirmation in progress self.confirmation_acknowledgement = False self.confirmation_input_active = False diff --git a/cecli/linter.py b/cecli/linter.py index 9e91d826fd8..434724e2bdf 100644 --- a/cecli/linter.py +++ b/cecli/linter.py @@ -12,6 +12,7 @@ from cecli.dump import dump # noqa: F401 from cecli.helpers.grep_ast import TreeContext, filename_to_lang from cecli.helpers.grep_ast.tsl import get_parser # noqa: E402 +from cecli.helpers.threading import ThreadSafeEvent from cecli.run_cmd import run_cmd_async, run_cmd_subprocess # noqa: F401 # tree_sitter is throwing a FutureWarning @@ -22,7 +23,7 @@ class Linter: def __init__(self, encoding="utf-8", root=None, interrupt_event=None): self.encoding = encoding self.root = root - self.interrupt_event = interrupt_event or asyncio.Event() + self.interrupt_event = interrupt_event or ThreadSafeEvent() self.languages = dict( python=self.py_lint, diff --git a/cecli/mcp/server.py b/cecli/mcp/server.py index fa1fb46ba8d..f148e47bd87 100644 --- a/cecli/mcp/server.py +++ b/cecli/mcp/server.py @@ -42,6 +42,7 @@ def __init__(self, server_config, io=None, verbose=False): self.io = io self.verbose = verbose self.session = None + self._connection_loop: asyncio.AbstractEventLoop | None = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack = AsyncExitStack() @@ -59,10 +60,17 @@ async def connect(self): Returns: ClientSession: The active session if mcp is not disabled """ + current_loop = asyncio.get_running_loop() if self.session is not None: + # Event loop affinity check: streams from stdio_client() are bound + # to the loop that created them. Reconnect if the loop changed. + if self._connection_loop is current_loop: + if self.verbose and self.io: + self.io.tool_output(f"Using existing session for MCP server: {self.name}") + return self.session if self.verbose and self.io: - self.io.tool_output(f"Using existing session for MCP server: {self.name}") - return self.session + self.io.tool_output(f"Reconnecting MCP server {self.name} (event loop changed)") + await self.disconnect() if self.verbose and self.io: self.io.tool_output(f"Establishing new connection to MCP server: {self.name}") @@ -87,6 +95,7 @@ async def connect(self): session = await self.exit_stack.enter_async_context(ClientSession(read, write)) await session.initialize() self.session = session + self._connection_loop = current_loop return session except Exception as e: logging.error(f"Error initializing server {self.name}: {e}") @@ -193,10 +202,15 @@ def _create_transport(self, url, http_client): raise NotImplementedError("Subclasses must implement _create_transport") async def connect(self): + current_loop = asyncio.get_running_loop() if self.session is not None: + if self._connection_loop is current_loop: + if self.verbose and self.io: + self.io.tool_output(f"Using existing session for {self.name}") + return self.session if self.verbose and self.io: - self.io.tool_output(f"Using existing session for {self.name}") - return self.session + self.io.tool_output(f"Reconnecting {self.name} (event loop changed)") + await self.disconnect() if self.verbose and self.io: self.io.tool_output(f"Establishing new connection to {self.name}") @@ -224,6 +238,7 @@ async def connect(self): session = await self.exit_stack.enter_async_context(ClientSession(read, write)) await session.initialize() self.session = session + self._connection_loop = current_loop if oauth_provider.context.oauth_metadata: token_endpoint = oauth_provider._get_token_endpoint() @@ -270,9 +285,13 @@ class SseServer(McpServer): """SSE (Server-Sent Events) MCP server using mcp.client.sse_client.""" async def connect(self): + current_loop = asyncio.get_running_loop() if self.session is not None: - logging.info(f"Using existing session for SSE MCP server: {self.name}") - return self.session + if self._connection_loop is current_loop: + logging.info(f"Using existing session for SSE MCP server: {self.name}") + return self.session + logging.info(f"Reconnecting SSE MCP server {self.name} (event loop changed)") + await self.disconnect() logging.info(f"Establishing new connection to SSE MCP server: {self.name}") try: @@ -285,6 +304,7 @@ async def connect(self): session = await self.exit_stack.enter_async_context(ClientSession(read, write)) await session.initialize() self.session = session + self._connection_loop = current_loop return session except Exception as e: logging.error(f"Error initializing SSE server {self.name}: {e}") diff --git a/cecli/prompts/agent.yml b/cecli/prompts/agent.yml index 6ed6e1566b7..71a3477377e 100644 --- a/cecli/prompts/agent.yml +++ b/cecli/prompts/agent.yml @@ -29,8 +29,8 @@ main_system: | ### 1. FILE FORMAT - File contents will be prefixed with identifiers. Each line starts with a case-sensitive content hash followed by `::`. These are used to target where editing tools will perform edits. - They are algorithmically generated, maintained, and subject to change. Do not search for these content hashes. Focus on the lines they identify. + File contents will be prefixed with identifiers. Each line starts with a case-sensitive content ID followed by `::`. These are used to target where editing tools will perform edits. + They are algorithmically generated, maintained, and subject to change. Do not search for these content IDs. Focus on the lines they identify. **Example File Format :** il9n::#!/usr/bin/env python3 diff --git a/cecli/prompts/hashline.yml b/cecli/prompts/hashline.yml index fd300b1acf5..a5a23fa8594 100644 --- a/cecli/prompts/hashline.yml +++ b/cecli/prompts/hashline.yml @@ -6,7 +6,7 @@ main_system: | Act as an expert software developer. Plan carefully, explain your logic briefly, and execute via LOCATE/CONTENTS blocks. ### 1. FILE FORMAT - Files are provided in "hashline" format. Each line starts with a case-sensitive content hash followed by `::`. + Files are provided in "hashline" format. Each line starts with a case-sensitive content ID followed by `::`. These hashes are used as identifiers for lines when editing. **Example File Format :** diff --git a/cecli/tools/_yield.py b/cecli/tools/_yield.py index b575cfa9efd..4697ab96561 100644 --- a/cecli/tools/_yield.py +++ b/cecli/tools/_yield.py @@ -1,6 +1,7 @@ import asyncio import logging +from cecli.helpers.threading import ThreadSafeEvent from cecli.tools.utils.base_tool import BaseTool from cecli.tools.utils.helpers import ToolError from cecli.tools.utils.output import color_markers, tool_footer, tool_header @@ -65,7 +66,7 @@ async def execute(cls, coder, **kwargs): # the interrupt event, avoiding nested asyncio.wait() calls. interrupt_event = coder.interrupt_event if interrupt_event is None: - interrupt_event = asyncio.Event() + interrupt_event = ThreadSafeEvent() interrupt_task = asyncio.create_task(interrupt_event.wait()) pending = set(active_tasks) | {interrupt_task} diff --git a/cecli/tools/edit_text.py b/cecli/tools/edit_text.py index a8eeabca75f..874c3ca85c2 100644 --- a/cecli/tools/edit_text.py +++ b/cecli/tools/edit_text.py @@ -35,13 +35,13 @@ class Tool(BaseTool): "function": { "name": "EditText", "description": ( - "Edit text in one or more files using content hash markers. " + "Edit text in one or more files using content ID markers. " "Supports replace, delete, and insert operations in a single call. " "Can handle an array of up to 10 edits across multiple files. " "Each edit must include its own file_path and operation type. " - "Use content hash ranges with the start_line and end_line parameters with format " + "Use content ID ranges with the start_line and end_line parameters with format " "`{4 char hash}` (without the braces). For empty files, use `@000` as the " - "content hash references." + "content ID references." ), "parameters": { "type": "object", @@ -74,14 +74,14 @@ class Tool(BaseTool): "start_line": { "type": "string", "description": ( - "Content hash for start line: `{4 char hash}` (without " + "content ID for start line: `{4 char hash}` (without " "the braces)" ), }, "end_line": { "type": "string", "description": ( - "Content hash for end line: `{4 char hash}` (without the" + "content ID for end line: `{4 char hash}` (without the" " braces)" ), }, @@ -248,7 +248,7 @@ def execute( if new_content != original_content: file_successful_edits += len(successful_ops) else: - raise ToolError("Invalid Edit - Update content hash bounds") + raise ToolError("Invalid Edit - Update content ID bounds") if len(failed_ops): for failed_op in failed_ops: @@ -446,7 +446,7 @@ def format_output(cls, coder, mcp_server, tool_response): text=strip_hashline(text), ) except ContentHashError as e: - diff_output = f"Content hash verification failed: {str(e)}" + diff_output = f"content ID verification failed: {str(e)}" except Exception: pass diff --git a/cecli/tools/grep.py b/cecli/tools/grep.py index a925bffc606..cf709995b6c 100644 --- a/cecli/tools/grep.py +++ b/cecli/tools/grep.py @@ -46,24 +46,14 @@ class Tool(BaseTool): }, "use_regex": { "type": "boolean", - "default": False, + "default": True, "description": "Whether to use regex.", }, "case_insensitive": { "type": "boolean", - "default": False, + "default": True, "description": "Whether to perform a case-insensitive search.", }, - "context_before": { - "type": "integer", - "default": 5, - "description": "Number of lines to show before a match.", - }, - "context_after": { - "type": "integer", - "default": 5, - "description": "Number of lines to show after a match.", - }, }, "required": ["pattern"], }, @@ -117,8 +107,8 @@ def execute( pattern = strip_hashline(search_op.get("pattern")) file_pattern = search_op.get("file_pattern", "*") directory = search_op.get("directory", search_op.get("path", ".")) - use_regex = search_op.get("use_regex", False) - case_insensitive = search_op.get("case_insensitive", False) + use_regex = search_op.get("use_regex", True) + case_insensitive = search_op.get("case_insensitive", True) context_before = search_op.get("context_before", 5) context_after = search_op.get("context_after", 5) diff --git a/cecli/tools/read_range.py b/cecli/tools/read_range.py index 413a111b61f..969f5237b42 100644 --- a/cecli/tools/read_range.py +++ b/cecli/tools/read_range.py @@ -17,31 +17,36 @@ class Tool(BaseTool): NORM_NAME = "readrange" TRACK_INVOCATIONS = False VALIDATIONS = { - "show": ["coerce_list"], - "show[]": ["coerce_dict"], + "read": ["coerce_list"], + "read[]": ["coerce_dict"], + "read[].range_start": ["coerce_str"], + "read[].range_end": ["coerce_str"], } SCHEMA = { "type": "function", "function": { "name": "ReadRange", "description": ( - "Get content hash prefixes of content between start and end patterns in files." - " Accepts an array of `show` objects, each with file_path, start_text, end_text." - " These values must be lines of content in the file. They can contain up to 3" - " lines but newlines should generally be avoided. Avoid using generic keywords and" + "Get content ID prefixed content between start and end markers in files." + " This is useful for files you are attempting to edit and for understanding their structure." + " Accepts an array of `read` objects, each with file_path, range_start, range_end." + " They can contain up to 3 lines of content. Avoid using singular generic keywords and" " symbols. Special markers @000 and 000@ represent the file boundaries and can be" - " used for start_text and end_text for the first and last lines of the file" - " respectively. Avoid using both of the special markers together on non-empty" - " files. Line numbers may be used as values but they are discouraged as" - " they shift between edits. Never use content hashes as the start_text and end_text values." - " Do not use the same pattern for the start_text and end_text. It is best to use function" - " names, variable declarations and other meaningful identifiers as start_text and" - " end_text values." + " used for range_start and range_end for the first and last lines of the file" + " respectively. Line numbers may also be used for range lookups." + " It is best to use function names, variable declarations and other meaningful identifiers" + " as range_start and range_end values." + " Do not use both of the special markers together on non-empty file." + " Do not use the same pattern for the range_start and range_end." + " Do not use empty strings for the range_start and range_end." + " Prefer using this tool over cli tools for reading files." + " Calling this tool sequentially on increasingly finer grained searches " + " will help with understanding important structural features." ), "parameters": { "type": "object", "properties": { - "show": { + "read": { "type": "array", "items": { "type": "object", @@ -50,14 +55,14 @@ class Tool(BaseTool): "type": "string", "description": "File path to search in.", }, - "start_text": { + "range_start": { "type": "string", "description": ( "The text marking the beginning of the range." " Use '@000' for the first line on empty files." ), }, - "end_text": { + "range_end": { "type": "string", "description": ( "The text marking the end of the range." @@ -65,12 +70,12 @@ class Tool(BaseTool): ), }, }, - "required": ["file_path", "start_text", "end_text"], + "required": ["file_path", "range_start", "range_end"], }, - "description": "Array of show operations to perform.", + "description": "Array of read operations to perform.", }, }, - "required": ["show"], + "required": ["read"], }, }, } @@ -79,11 +84,11 @@ class Tool(BaseTool): _last_read_turn: Dict[str, int] = {} # abs_path -> turn_count when last read @classmethod - def execute(cls, coder, show, **kwargs): + def execute(cls, coder, read, **kwargs): """ Displays numbered lines from multiple files centered around target locations (patterns or line_numbers), without adding files to context. - Accepts an array of show operations to perform. + Accepts an array of read operations to perform. Uses utility functions for path resolution and error handling. """ from cecli.helpers.conversation import ConversationService @@ -94,70 +99,73 @@ def execute(cls, coder, show, **kwargs): error_outputs = [] try: - # 1. Validate show parameter - if not isinstance(show, list): - show = [show] if isinstance(show, dict) else show + # 1. Validate read parameter + if not isinstance(read, list): + read = [read] if isinstance(read, dict) else read - if len(show) == 0: - raise ToolError("show array cannot be empty") + if len(read) == 0: + raise ToolError("read array cannot be empty") all_outputs = [] already_up_to_details = [] new_context_details = [] + all_outputs_set = set() + new_context_set = set() + already_up_to_set = set() ranges = {} - for show_index, show_op in enumerate(show): - # Extract parameters for this show operation - file_path = show_op.get("file_path") - start_text = show_op.get("start_text") - end_text = show_op.get("end_text") + for read_index, read_op in enumerate(read): + # Extract parameters for this read operation + file_path = read_op.get("file_path") + range_start = read_op.get("range_start") + range_end = read_op.get("range_end") padding = 5 if file_path is None: error_outputs.append( cls.format_error( coder, - f"Show operation {show_index + 1} missing required file_path parameter", + f"read operation {read_index + 1} missing required file_path parameter", None, None, None, - show_index, + read_index, ) ) continue # Validate arguments for this operation - if not is_provided(start_text) or not is_provided(end_text): + if not is_provided(range_start) or not is_provided(range_end): error_outputs.append( cls.format_error( coder, ( - f"Show operation {show_index + 1}: Provide both 'start_text' and" - " 'end_text'." + f"read operation {read_index + 1}: Provide both 'range_start' and" + " 'range_end'." ), file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue - if start_text.count("\n") > 4 or end_text.count("\n") > 4: + if range_start.count("\n") > 4 or range_end.count("\n") > 4: error_outputs.append( cls.format_error( coder, "Patterns must not contain more than 5 lines.", file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue - start_text = strip_hashline(start_text).strip() - end_text = strip_hashline(end_text).strip() + range_start = strip_hashline(range_start).strip() + range_end = strip_hashline(range_end).strip() # 2. Resolve path abs_path, rel_path = resolve_paths(coder, file_path) @@ -168,9 +176,9 @@ def execute(cls, coder, show, **kwargs): coder, f"File not found: {file_path}", file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue @@ -183,9 +191,9 @@ def execute(cls, coder, show, **kwargs): coder, f"Could not read file: {file_path}", file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue @@ -212,25 +220,79 @@ def execute(cls, coder, show, **kwargs): # 4. Determine line range start_line_idx = -1 end_line_idx = -1 + both_structured = False # found_by = "" - if start_text is not None and end_text is not None: - if start_text.isdigit() and end_text.isdigit(): - # Treat both as 1-based line numbers - start_line_num = int(start_text) - end_line_num = int(end_text) - # Clamp to valid range [1, num_lines] - start_line_num = max(1, min(start_line_num, num_lines)) - end_line_num = max(1, min(end_line_num, num_lines)) - if start_line_num > end_line_num: - # Swap so start <= end - start_line_num, end_line_num = end_line_num, start_line_num - start_indices = [start_line_num - 1] - end_indices = [end_line_num - 1] - elif start_text == "@000" or start_text == "000@": - start_indices = [0] + if range_start is not None and range_end is not None: + + def _is_valid_int(s): + try: + int(s) + return True + except ValueError: + return False + + start_is_digit = _is_valid_int(range_start) + end_is_digit = _is_valid_int(range_end) + start_is_special = range_start in ("@000", "000@") + end_is_special = range_end in ("@000", "000@") + both_structured = (start_is_digit or start_is_special) and ( + end_is_digit or end_is_special + ) + start_is_text = not start_is_digit and not start_is_special + end_is_text = not end_is_digit and not end_is_special + mixed_special_search = (start_is_special and end_is_text) or ( + end_is_special and start_is_text + ) + start_indices = [] + end_indices = [] + + if both_structured: + if start_is_digit: + start_line_num = int(range_start) - 1 + start_line_num = max(0, min(start_line_num, num_lines - 1)) + start_indices = [start_line_num] + else: + start_indices = [0] + + if end_is_digit: + end_line_num = int(range_end) - 1 + end_line_num = max(0, min(end_line_num, num_lines - 1)) + end_indices = [end_line_num] + else: + end_indices = [num_lines - 1] + elif mixed_special_search: + if start_is_special: + # Start is special marker, end is text pattern + if range_start == "@000": + start_indices = [0] + else: # 000@ + start_indices = [num_lines - 1] + # Search for end pattern as text + end_pattern_lines = range_end.split("\n") + end_indices = [] + for i in range(len(lines) - len(end_pattern_lines) + 1): + if all( + p_line in lines[i + j] + for j, p_line in enumerate(end_pattern_lines) + ): + end_indices.append(i + len(end_pattern_lines) - 1) + else: + # Start is text pattern, end is special marker + start_pattern_lines = range_start.split("\n") + start_indices = [] + for i in range(len(lines) - len(start_pattern_lines) + 1): + if all( + p_line in lines[i + j] + for j, p_line in enumerate(start_pattern_lines) + ): + start_indices.append(i) + if range_end == "@000": + end_indices = [0] + else: # 000@ + end_indices = [num_lines - 1] else: - start_pattern_lines = start_text.split("\n") + start_pattern_lines = range_start.split("\n") start_indices = [] for i in range(len(lines) - len(start_pattern_lines) + 1): if all( @@ -239,10 +301,7 @@ def execute(cls, coder, show, **kwargs): ): start_indices.append(i) - if end_text == "000@" or end_text == "@000": - end_indices = [num_lines - 1] - else: - end_pattern_lines = end_text.split("\n") + end_pattern_lines = range_end.split("\n") end_indices = [] for i in range(len(lines) - len(end_pattern_lines) + 1): if all( @@ -259,13 +318,13 @@ def execute(cls, coder, show, **kwargs): cls.format_error( coder, ( - f"Start pattern '{start_text}' too broad." + f"Start pattern '{range_start}' too broad." " Refine your search. Be more specific." ), file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue @@ -311,13 +370,13 @@ def execute(cls, coder, show, **kwargs): cls.format_error( coder, ( - f"Start pattern '{start_text}' not found in {file_path}." + f"Start pattern '{range_start}' not found in {file_path}." " Refine your search." ), file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue @@ -327,13 +386,13 @@ def execute(cls, coder, show, **kwargs): cls.format_error( coder, ( - f"End pattern '{end_text}' not found in {file_path}." + f"End pattern '{range_end}' not found in {file_path}." " Refine your search." ), file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue @@ -343,36 +402,58 @@ def execute(cls, coder, show, **kwargs): cls.format_error( coder, ( - f"End pattern '{end_text}' not found after start pattern in" + f"End pattern '{range_end}' not found after start pattern in" f" {file_path}." ), file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue - s_idx, e_idx = best_pair + if best_pair is None: + error_outputs.append( + cls.format_error( + coder, + ( + f"End pattern '{range_end}' not found after start pattern in" + f" {file_path}." + ), + file_path, + range_start, + range_end, + read_index, + ) + ) + continue - # Validate range width when special markers are used - # If too large, use _get_range_preview which tries get_file_stub - # first, falling back to 20 equally-spaced lines for non-code files - if (start_text == "@000" or end_text == "000@") and (e_idx - s_idx > 200): + s_idx, e_idx = best_pair + s_idx, e_idx = cls._extend_range_with_stub( + coder, abs_path, s_idx, e_idx, num_lines + ) + # For structured searches (line numbers, special markers) or mixed searches + # (one special marker, one text pattern), cap large ranges with preview + # Text pattern searches are not subject to capping + if (both_structured or mixed_special_search) and (e_idx - s_idx > 200): preview = cls._get_range_preview( abs_path, coder.io, start_idx=s_idx, end_idx=e_idx, line_numbers=True ) - if show_index > 0: - all_outputs.append("") - all_outputs.append(preview) + + if preview not in all_outputs_set: + all_outputs_set.add(preview) + if len(all_outputs): + all_outputs.append("") + all_outputs.append(preview) + cls._last_invocation[abs_path] = {"start_idx": s_idx, "end_idx": e_idx} continue # Store the found indices for future disambiguation cls._last_invocation[abs_path] = {"start_idx": s_idx, "end_idx": e_idx} - # found_by = f"range '{start_text}' to '{end_text}'" + # found_by = f"range '{range_start}' to '{range_end}'" try: padding_int = int(padding) @@ -390,9 +471,9 @@ def execute(cls, coder, show, **kwargs): coder, "Internal error: Could not determine line range.", file_path, - start_text, - end_text, - show_index, + range_start, + range_end, + read_index, ) ) continue @@ -412,8 +493,8 @@ def execute(cls, coder, show, **kwargs): # hashed_line = context_hashed_lines[i - start_line_idx] # output_lines.append(hashed_line) - # Add separator between multiple show operations - # if show_index > 0: + # Add separator between multiple read operations + # if read_index > 0: # all_outputs.append("") # all_outputs.extend(output_lines) @@ -436,14 +517,14 @@ def execute(cls, coder, show, **kwargs): is_already_up_to_date = False add_to_ranges = False - last_turn = cls._last_read_turn.get(abs_path) + # last_turn = cls._last_read_turn.get(abs_path) if original_context_content and original_context_content == new_context_content: already_up_to_date.append(rel_path) is_already_up_to_date = True - if last_turn is None or coder.turn_count - last_turn < 3 and already_up_to_date: - add_to_ranges = True + # if last_turn is None or coder.turn_count - last_turn < 3 and already_up_to_date: + # add_to_ranges = True else: add_to_ranges = True @@ -463,16 +544,24 @@ def execute(cls, coder, show, **kwargs): and e_idx >= 0 and e_idx < len(hashed_lines) ): - hashed_slice = hashed_lines[s_idx : e_idx + 1] + # hashed_slice = hashed_lines[s_idx : e_idx + 1] if is_already_up_to_date: - already_up_to_details.append( - cls.format_model_response(coder, rel_path, s_idx, e_idx, hashed_slice) + model_response = cls.format_model_response( + coder, rel_path, s_idx, e_idx, hashed_lines, current=True ) + + if model_response not in already_up_to_set: + already_up_to_set.add(model_response) + already_up_to_details.append(model_response) else: - new_context_details.append( - cls.format_model_response(coder, rel_path, s_idx, e_idx, hashed_slice) + model_response = cls.format_model_response( + coder, rel_path, s_idx, e_idx, hashed_lines ) + if model_response not in new_context_set: + new_context_set.add(model_response) + new_context_details.append(model_response) + # Conditionally remove old file context messages # If the file was last read >= 3 turns ago, keep old messages (allow coexistence) # Otherwise, remove them to avoid duplicates @@ -506,6 +595,7 @@ def execute(cls, coder, show, **kwargs): result_parts.append( f"Retrieved context for {len(new_context_details)} operation(s):\n\n" f"{detail_str}\n" + "Full results for these reads will be given in a follow up message.\n" ) if already_up_to_details: coder.io.tool_output( @@ -518,6 +608,7 @@ def execute(cls, coder, show, **kwargs): "Content up to date and available in history from previous read for " f"{len(already_up_to_details)} operation(s):\n\n" f"{detail_str}\n" + "Current contents for these reads available in previous content ID message." ) if already_up_to_date and not new_context_retrieved: result_parts.append( @@ -527,6 +618,7 @@ def execute(cls, coder, show, **kwargs): if all_outputs: result_parts.append("\n".join(all_outputs)) + result_parts.append("\nUse these outlines to refine your search.\n") if error_outputs: coder.io.tool_error(f"Errors encountered for {len(error_outputs)} operation(s)") @@ -546,19 +638,108 @@ def execute(cls, coder, show, **kwargs): return handle_tool_error(coder, tool_name, e) @classmethod - def format_model_response(cls, coder, rel_path, s_idx, e_idx, hashed_slice): + def format_model_response(cls, coder, rel_path, s_idx, e_idx, hashed_lines, current=False): """Format a file's context range as hash-prefixed lines for the model.""" + # Read file content for stub lookups + try: + from cecli.tools.utils.helpers import resolve_paths + + abs_path, _ = resolve_paths(coder, rel_path) + last_turn = cls._last_read_turn[abs_path] or 0 + except Exception: + pass + + lines = [] + + # Try to return structural stub information instead of raw hashed lines + try: + if hashed_lines and current and coder.turn_count - last_turn >= 2: + num_lines = len(hashed_lines) + + start_stub_s, start_stub_e = cls._extend_range_with_stub( + coder, abs_path, s_idx, s_idx, num_lines + ) + end_stub_s, end_stub_e = cls._extend_range_with_stub( + coder, abs_path, e_idx, e_idx, num_lines + ) + + # start_stub_s, start_stub_e = cls._reposition_indices(s_idx, start_stub_s, start_stub_e) + # end_stub_s, end_stub_e = cls._reposition_indices(e_idx, end_stub_s, end_stub_e) + + start_found = start_stub_s != s_idx or start_stub_e != s_idx + end_found = end_stub_s != e_idx or end_stub_e != e_idx + + if end_stub_s != start_stub_s or end_stub_e != start_stub_e: + start_stub_s = end_stub_s + start_stub_e = end_stub_e + start_found = True + end_found = False + + if start_found or end_found: + if start_found: + lines.append( + f"File {rel_path} Snapshot (Lines {start_stub_s + 1} - {start_stub_e + 1}):" + ) + lines.extend(hashed_lines[start_stub_s:start_stub_e]) + + if ( + end_found + and start_stub_s != end_stub_s + and start_stub_e != end_stub_e + and end_stub_e != e_idx + ): + lines.append("...⋮...") + lines.append( + f"File {rel_path} Snapshot (Lines {end_stub_s + 1} - {end_stub_e + 1}):" + ) + lines.extend(hashed_lines[end_stub_s:end_stub_e]) + + lines.append("") + return "\n".join(lines) + except Exception: + pass + lines = [f"File {rel_path} Snapshot (Lines {s_idx + 1} - {e_idx + 1}):"] - total = len(hashed_slice) - if total <= 10: - lines.extend(hashed_slice) + total = e_idx - s_idx + if total <= 15: + lines.extend(hashed_lines[s_idx : e_idx + 1]) else: - lines.extend(hashed_slice[:5]) - lines.append("...") - lines.extend(hashed_slice[-5:]) + lines.extend(hashed_lines[s_idx : s_idx + 5]) + lines.append("...⋮...") + lines.extend(hashed_lines[e_idx - 4 : e_idx + 1]) lines.append("") return "\n".join(lines) + @classmethod + def _reposition_indices( + cls, target_idx: int, start_idx: int, end_idx: int, total_lines: int = 20 + ) -> tuple: + """ + Calculates the clamped start and end indices for a centered window. + Returns a tuple of (slice_start, slice_end) compatible with python slicing. + """ + # 1. Calculate ideal half-window size + half_window = total_lines // 2 + + # 2. Calculate initial left/right bounds + left = target_idx - half_window + right = target_idx + half_window + + # 3. Slide the window if it overflows boundaries + if left < start_idx: + right += start_idx - left + left = start_idx + + if right > end_idx: + left -= right - end_idx + right = end_idx + + # 4. Final safety clamp in case the range itself is smaller than total_lines + left = max(start_idx, left) + + # Return right + 1 so it's ready-to-use for standard Python slicing [start:end] + return left, right + 1 + @classmethod def clear_old_messages(cls, coder): from cecli.helpers.conversation import ConversationService, MessageTag @@ -622,18 +803,18 @@ def format_output(cls, coder, mcp_server, tool_response): coder.io.tool_error("Invalid Tool JSON") return - show_ops = params.get("show", []) - if show_ops: + read_ops = params.get("read", []) + if read_ops: coder.io.tool_output("") - for i, show_op in enumerate(show_ops): - file_path = show_op.get("file_path", "") - start_text = strip_hashline(show_op.get("start_text", "")).strip() - end_text = strip_hashline(show_op.get("end_text", "")).strip() + for i, read_op in enumerate(read_ops): + file_path = read_op.get("file_path", "") + range_start = strip_hashline(read_op.get("range_start", "")).strip() + range_end = strip_hashline(read_op.get("range_end", "")).strip() - # Format as "show: • file_path • start_text • end_text • padding" + # Format as "read: • file_path • range_start • range_end • padding" formatted_query = ( - f"{color_start}range_{i + 1}:{color_end} {file_path} • {start_text} •" - f" {end_text}" + f"{color_start}range_{i + 1}:{color_end} {file_path} • {range_start} •" + f" {range_end}" ) coder.io.tool_output(formatted_query) coder.io.tool_output("") @@ -641,24 +822,24 @@ def format_output(cls, coder, mcp_server, tool_response): tool_footer(coder=coder, tool_response=tool_response, params=params) @classmethod - def format_error(cls, coder, error_text, file_path, start_text, end_text, operation_index): + def format_error(cls, coder, error_text, file_path, range_start, range_end, operation_index): """Format error output for the ReadRange tool.""" - # Truncate start_text to first line with ellipsis if multiline - start_line = (start_text or "N/A").split("\n")[0] - if start_text and start_text.count("\n") > 0: + # Truncate range_start to first line with ellipsis if multiline + start_line = (range_start or "N/A").split("\n")[0] + if range_start and range_start.count("\n") > 0: start_line = start_line + " ..." - # Truncate end_text to first line with ellipsis if multiline - end_line = (end_text or "N/A").split("\n")[0] - if end_text and end_text.count("\n") > 0: + # Truncate range_end to first line with ellipsis if multiline + end_line = (range_end or "N/A").split("\n")[0] + if range_end and range_end.count("\n") > 0: end_line = end_line + " ..." output = [ f"[Operation {operation_index + 1}]", f"file_path: {file_path or 'N/A'}", - f"start_text: {start_line}", - f"end_text: {end_line}", + f"range_start: {start_line}", + f"range_end: {end_line}", "", error_text, ] @@ -669,6 +850,57 @@ def format_error(cls, coder, error_text, file_path, start_text, end_text, operat def on_duplicate_request(cls, coder, **kwargs): coder.edit_allowed = True + @classmethod + def _extend_range_with_stub(cls, coder, abs_path, s_idx, e_idx, num_lines): + """ + Extends the range [s_idx, e_idx] to include the stub result before + and up to the stub result after the specified range. + """ + from cecli.repomap import RepoMap + + try: + if not hasattr(RepoMap, "_stub_instance"): + RepoMap._stub_instance = RepoMap(map_tokens=0, io=coder.io) + rm = RepoMap._stub_instance + rel_fname = rm.get_rel_fname(abs_path) + tags = rm.get_tags(abs_path, rel_fname) + if not tags: + return s_idx, e_idx + + # Get all definition lines, plus import lines for structural context + lois = sorted( + list( + set( + tag.line + for tag in tags + if tag.kind == "def" or tag.specific_kind == "import" + ) + ) + ) + if not lois: + return s_idx, e_idx + + # Find the stub result before or at s_idx + # We want the largest line in lois that is <= s_idx + before_lines = [ln for ln in lois if ln <= s_idx] + new_s_idx = s_idx + if before_lines: + new_s_idx = before_lines[-1] + + # Find the stub result after e_idx + # We want the smallest line in lois that is > e_idx + after_lines = [ln for ln in lois if ln > e_idx] + new_e_idx = e_idx + if after_lines: + new_e_idx = after_lines[0] - 1 + else: + new_e_idx = num_lines - 1 + + return new_s_idx, new_e_idx + except Exception: + # Fallback to original range if anything goes wrong + return s_idx, e_idx + @classmethod def _get_range_preview(cls, abs_path, io, start_idx, end_idx, line_numbers=True): """Get a preview of a large file range between start_idx and end_idx. diff --git a/cecli/tools/validations/validations.py b/cecli/tools/validations/validations.py index 531d295d59b..15fc22ab460 100644 --- a/cecli/tools/validations/validations.py +++ b/cecli/tools/validations/validations.py @@ -67,22 +67,114 @@ def validate_params(cls, params: dict, validations: dict, schema: dict | None = return params for raw_key, method_names in validations.items(): - # Determine whether the key targets list items (trailing "[]") - iterate_over_list = raw_key.endswith("[]") - clean_key = raw_key.rstrip("[]") + segments = cls._parse_validation_key(raw_key) + if not segments: + continue + cls._apply_along_segments(params, segments, method_names) + return params + + @staticmethod + def _parse_validation_key(raw_key: str) -> list[tuple[str, bool]]: + """ + Parse a validation path into a list of (key, iterate) tuples. - # Split on dots to get the navigation path into params - path = clean_key.split(".") if clean_key else [] + Supports the following path shapes: - if not path: - continue + "segment" -> [("segment", False)] + "segment.nested" -> [("segment", False), ("nested", False)] + "segment[]" -> [("segment", True)] + "segment[].nested" -> [("segment", True), ("nested", False)] + "segment.nested[]" -> [("segment", False), ("nested", True)] + "segment[].nested[].n2" -> [("segment", True), ("nested", True), ("n2", False)] - if iterate_over_list: - cls._apply_validations_to_list_items(params, path, method_names) + Any trailing ``[]`` on a path segment marks it for iteration — the + validation will be applied to each item in the list found at that key. + + Returns: + A list of (key, should_iterate) tuples. Returns an empty list + if the key is empty or contains only separators. + """ + if not raw_key: + return [] + + parts = raw_key.split(".") + segments: list[tuple[str, bool]] = [] + for part in parts: + if not part: + continue + if part.endswith("[]"): + segments.append((part[:-2], True)) else: - cls._apply_validations_to_value(params, path, method_names) + segments.append((part, False)) - return params + return segments + + @classmethod + def _apply_along_segments( + cls, params: dict, segments: list[tuple[str, bool]], method_names: list[str] + ) -> None: + """ + Recursively apply *method_names* along the parsed *segments* path. + + Each segment is a ``(key, iterate)`` tuple. When *iterate* is ``True`` + the method expects ``params[key]`` to be a list and either applies the + validations to each item (if this is the last segment) or recurses + into each item's dict (if there are further segments). When *iterate* + is ``False`` the method either applies validations to ``params[key]`` + (last segment) or recurses into the nested dict. + + ``params`` is mutated in place. + """ + if not segments: + return + + key, iterate = segments[0] + remaining = segments[1:] + + if not isinstance(params, dict) or key not in params: + return + + if iterate: + items = params[key] + if not isinstance(items, list): + return + + if not remaining: + # Apply validation methods to each item in the list + new_items: list = [] + for item in items: + for method_name in method_names: + method = getattr(cls, method_name, None) + if method is None: + raise ToolError(f"Unknown validation method: {method_name}") + item = method(item) + if item is None: + break + if item is not None: + new_items.append(item) + params[key] = new_items + else: + # Recurse into each item, applying remaining segments + for item in items: + if isinstance(item, dict): + cls._apply_along_segments(item, remaining, method_names) + else: + if not remaining: + # Apply validation methods to the value at this key + value = params[key] + for method_name in method_names: + method = getattr(cls, method_name, None) + if method is None: + raise ToolError(f"Unknown validation method: {method_name}") + value = method(value) + if value is None: + break + params[key] = value + else: + # Navigate deeper + nested = params[key] + if isinstance(nested, dict): + cls._apply_along_segments(nested, remaining, method_names) @classmethod def _basic_validations(cls, params: object, schema: dict | None = None) -> dict: diff --git a/tests/tools/test_get_lines.py b/tests/tools/test_get_lines.py index 1bbeb0d3b6c..686146a817c 100644 --- a/tests/tools/test_get_lines.py +++ b/tests/tools/test_get_lines.py @@ -54,17 +54,17 @@ def test_pattern_with_zero_line_number_is_allowed(coder_with_file): result = read_range.Tool.execute( coder, - show=[ + read=[ { "file_path": "example.txt", - "start_text": "beta", - "end_text": "beta", + "range_start": "beta", + "range_end": "beta", "padding": 0, } ], ) - # show_numbered_context now returns a new formatted context message + # read_range now returns a new formatted context message assert "Retrieved context for 1 operation(s)" in result coder.io.tool_error.assert_not_called() @@ -74,17 +74,17 @@ def test_empty_pattern_uses_line_number(coder_with_file): result = read_range.Tool.execute( coder, - show=[ + read=[ { "file_path": "example.txt", - "start_text": "beta", - "end_text": "beta", + "range_start": "beta", + "range_end": "beta", "padding": 0, } ], ) - # show_numbered_context now returns a static success message + # read_range now returns a static success message assert "Retrieved context for 1 operation(s)" in result coder.io.tool_error.assert_not_called() @@ -93,18 +93,19 @@ def test_conflicting_pattern_and_line_number_raise(coder_with_file): coder, file_path = coder_with_file # Test that missing start_text raises an error + # Test that missing range_start raises an error result = read_range.Tool.execute( coder, - show=[ + read=[ { "file_path": "example.txt", - "end_text": "beta", + "range_end": "beta", "padding": 0, } ], ) - assert "Provide both 'start_text' and 'end_text'" in result + assert "Provide both 'range_start' and 'range_end'" in result coder.io.tool_error.assert_called() @@ -130,11 +131,11 @@ def test_multiline_pattern_search(coder_with_file): result = read_range.Tool.execute( coder, - show=[ + read=[ { "file_path": "example.txt", - "start_text": "alpha\nbeta", - "end_text": "beta\ngamma", + "range_start": "alpha\nbeta", + "range_end": "beta\ngamma", "padding": 0, } ], @@ -157,11 +158,11 @@ def test_empty_file_includes_edit_hint(tmp_path): conv.get_chunks.return_value.add_file_context_messages = Mock() result = read_range.Tool.execute( coder, - show=[ + read=[ { "file_path": "pubspec.yaml", - "start_text": "@000", - "end_text": "@000", + "range_start": "@000", + "range_end": "@000", } ], ) diff --git a/tests/tools/test_insert_block.py b/tests/tools/test_insert_block.py index 9e5ae2b855e..9171a6cfee4 100644 --- a/tests/tools/test_insert_block.py +++ b/tests/tools/test_insert_block.py @@ -121,7 +121,7 @@ def test_mutually_exclusive_parameters_raise(coder_with_file): ) assert result.startswith("Error in EditText:") - assert "Invalid Edit - Update content hash bounds" in result + assert "Invalid Edit - Update content ID bounds" in result assert file_path.read_text().startswith("first line") coder.io.tool_error.assert_called() diff --git a/tests/tools/test_read_range_execute.py b/tests/tools/test_read_range_execute.py new file mode 100644 index 00000000000..bad0fde5981 --- /dev/null +++ b/tests/tools/test_read_range_execute.py @@ -0,0 +1,550 @@ +""" +Tests for the execute method of read_range.py. + +Focuses on the parsing logic for line numbers, special markers (@000, 000@), +and text strings. Tests cover all combinations of these marker types. +""" + +import os +import sys +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + + +def _safe_relpath(path): + """Wrapper around os.path.relpath that handles cross-drive scenarios on Windows.""" + try: + return os.path.relpath(path) + except ValueError: + # On Windows, os.path.relpath fails when path and cwd are on different drives. + # Fall back to basename which is sufficient for test patches. + return os.path.basename(path) + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def mock_coder(): + """Create a mock coder object with all necessary attributes.""" + coder = MagicMock() + coder.turn_count = 5 + coder.abs_root_path.side_effect = lambda p: os.path.abspath(p) + coder.get_rel_fname.side_effect = lambda p: _safe_relpath(p) + coder.io.tool_output = MagicMock() + coder.io.tool_error = MagicMock() + coder.io.tool_warning = MagicMock() + return coder + + +@pytest.fixture +def mock_file_context(): + """Mock the ConversationService file context operations.""" + file_context = MagicMock() + file_context.get_file_context.return_value = None + file_context.update_file_context.return_value = (1, 10) + file_context.clear_ranges = MagicMock() + file_context.push_range = MagicMock() + return file_context + + +@pytest.fixture +def mock_chunks(): + """Mock the ConversationService chunks operations.""" + chunks = MagicMock() + chunks.add_file_context_messages = MagicMock() + return chunks + + +@pytest.fixture +def mock_manager(): + """Mock the ConversationService manager operations.""" + manager = MagicMock() + manager.get_tag_messages.return_value = [] + return manager + + +def create_test_file(content): + """Create a temporary file with the given content and return the path.""" + tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) + tmp.write(content) + tmp.close() + return tmp.name + + +# ============================================================================= +# Test Class +# ============================================================================= + + +class TestReadRangeExecute: + """Tests for Tool.execute() parsing logic.""" + + # Class-level patches that apply to all tests + @pytest.fixture(autouse=True) + def setup_patches(self): + self.patches = [] + yield + for p in self.patches: + p.stop() + + def _setup(self, mock_coder, mock_file_context, mock_chunks, mock_manager, file_content=""): + """Set up mocks and create a test file with given content.""" + self.coder = mock_coder + self.test_file = create_test_file(file_content) + self.coder.io.read_text.return_value = file_content + + # Patch ConversationService - it's imported locally in execute(), + # so we patch at the source module + mock_cs = MagicMock() + mock_cs.get_files.return_value = mock_file_context + mock_cs.get_chunks.return_value = mock_chunks + mock_cs.get_manager.return_value = mock_manager + cs_patch = patch("cecli.helpers.conversation.ConversationService", mock_cs) + cs_patch.start() + self.patches.append(cs_patch) + + # Patch strip_hashline to be identity + sh_patch = patch("cecli.tools.read_range.strip_hashline", side_effect=lambda x: x) + sh_patch.start() + self.patches.append(sh_patch) + + # Patch hashline to be identity + hl_patch = patch("cecli.tools.read_range.hashline", side_effect=lambda x: x) + hl_patch.start() + self.patches.append(hl_patch) + + # Patch resolve_paths + rp_patch = patch( + "cecli.tools.read_range.resolve_paths", + return_value=(self.test_file, _safe_relpath(self.test_file)), + ) + rp_patch.start() + self.patches.append(rp_patch) + + # Patch is_provided + ip_patch = patch( + "cecli.tools.read_range.is_provided", + side_effect=lambda v, **kw: v is not None and v != "", + ) + ip_patch.start() + self.patches.append(ip_patch) + + # Reset class-level state on Tool + from cecli.tools.read_range import Tool + + self.Tool = Tool + Tool._last_invocation = {} + Tool._last_read_turn = {} + + def _teardown(self): + """Clean up temporary file.""" + if hasattr(self, "test_file") and os.path.exists(self.test_file): + os.unlink(self.test_file) + + # ========================================================================= + # Line Number Parsing (both_structured, both digits) + # ========================================================================= + + def test_both_digits_valid_range( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test: range_start='5', range_end='10' -> lines 5-10 (1-based).""" + content = "\n".join(f"line{i}" for i in range(1, 11)) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "5", "range_end": "10"}] + result = self.Tool.execute(self.coder, show) + assert "Snapshot" in result + assert "line5" in result + assert "line10" in result + finally: + self._teardown() + + def test_both_digits_same_line(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test: range_start='1', range_end='1' -> just line 0.""" + content = "\n".join(f"line{i}" for i in range(1, 11)) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "1", "range_end": "1"}] + result = self.Tool.execute(self.coder, show) + assert "line1" in result + finally: + self._teardown() + + def test_both_digits_out_of_bounds( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test: range_start='1', range_end='100' -> clamp to valid range.""" + content = "\n".join(f"line{i}" for i in range(1, 11)) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "1", "range_end": "100"}] + result = self.Tool.execute(self.coder, show) + assert "line1" in result + assert "line10" in result + finally: + self._teardown() + + def test_both_digits_inverted_order( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test: range_start='10', range_end='5': inverted matching swaps.""" + content = "\n".join(f"line{i}" for i in range(1, 11)) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "10", "range_end": "5"}] + result = self.Tool.execute(self.coder, show) + # Inverted: start=[9], end=[4], only one each -> swap to (4, 9) + assert result is not None + finally: + self._teardown() + + # ========================================================================= + # Special Marker Parsing (both_structured, both special) + # ========================================================================= + + def test_special_start_end(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test: @000 to 000@ -> first to last line.""" + content = "\n".join([f"line{i}" for i in range(1, 6)]) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "@000", "range_end": "000@"}] + result = self.Tool.execute(self.coder, show) + assert "line1" in result + assert "line5" in result + finally: + self._teardown() + + def test_special_start_at_000(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test: @000 to @000 -> first line only.""" + content = "\n".join([f"line{i}" for i in range(1, 6)]) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "@000", "range_end": "@000"}] + result = self.Tool.execute(self.coder, show) + assert "line1" in result + finally: + self._teardown() + + def test_special_end_at_000(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test: 000@ to 000@ -> last line only.""" + content = "\n".join([f"line{i}" for i in range(1, 6)]) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "000@", "range_end": "000@"}] + result = self.Tool.execute(self.coder, show) + assert "line5" in result + finally: + self._teardown() + + # ========================================================================= + # Mixed Digit + Special (both_structured) + # ========================================================================= + + def test_special_start_digit_end( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test: @000 to '3' -> first to line 3 (1-based).""" + content = "line1\nline2\nline3\nline4\nline5" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "@000", "range_end": "3"}] + result = self.Tool.execute(self.coder, show) + assert "line1" in result + assert "line3" in result + finally: + self._teardown() + + def test_digit_start_special_end( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test: '2' to 000@ -> line 1 to last.""" + content = "line1\nline2\nline3\nline4\nline5" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "2", "range_end": "000@"}] + result = self.Tool.execute(self.coder, show) + assert "line2" in result + assert "line5" in result + finally: + self._teardown() + + # ========================================================================= + # Text Pattern Parsing + # ========================================================================= + + def test_both_text_patterns(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test text patterns that exist in the file.""" + content = ( + "def foo():\n return 1\n\ndef bar():\n return 2\n\ndef baz():\n return 3\n" + ) + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [ + { + "file_path": self.test_file, + "range_start": "def foo():", + "range_end": "def bar():", + } + ] + result = self.Tool.execute(self.coder, show) + assert "Snapshot" in result + assert "def foo()" in result + assert "def bar()" in result + finally: + self._teardown() + + def test_text_pattern_not_found(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test text pattern that doesn't exist -> error.""" + content = "line1\nline2\nline3" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [ + { + "file_path": self.test_file, + "range_start": "nonexistent_pattern", + "range_end": "also_nonexistent", + } + ] + result = self.Tool.execute(self.coder, show) + assert "Errors" in result or "not found" in result + finally: + self._teardown() + + def test_text_pattern_multiline(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test multiline text patterns.""" + content = "def foo():\n return 1\n\ndef bar():\n return 2\n" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "def foo", "range_end": "def bar"}] + result = self.Tool.execute(self.coder, show) + assert "Snapshot" in result + finally: + self._teardown() + + # ========================================================================= + # Mixed Special + Text (mixed_special_search) + # ========================================================================= + + def test_special_start_text_end(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test: @000 to text 'debug_mode'. + + NOTE: This may expose a bug in mixed_special_search where indices + get overwritten after the if/else block. + """ + content = "header\nconfig_value = 42\ndebug_mode = True\nfooter" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "@000", "range_end": "debug_mode"}] + result = self.Tool.execute(self.coder, show) + # Should find '@000' at start and 'debug_mode' as text + print(f"\n[special_start_text_end] result: {result[:300]}") + assert result is not None + finally: + self._teardown() + + def test_text_start_special_end(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test: text 'config_value' to 000@. + + NOTE: This may expose a bug in mixed_special_search where indices + get overwritten after the if/else block. + """ + content = "header\nconfig_value = 42\ndebug_mode = True\nfooter" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [ + {"file_path": self.test_file, "range_start": "config_value", "range_end": "000@"} + ] + result = self.Tool.execute(self.coder, show) + print(f"\n[text_start_special_end] result: {result[:300]}") + assert result is not None + finally: + self._teardown() + + # ========================================================================= + # Edge Cases + # ========================================================================= + + def test_empty_file(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test with an empty file.""" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, "") + try: + show = [{"file_path": self.test_file, "range_start": "@000", "range_end": "000@"}] + result = self.Tool.execute(self.coder, show) + assert "empty" in result.lower() + finally: + self._teardown() + + def test_single_line_file(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test with a single line file.""" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, "only_line") + try: + show = [{"file_path": self.test_file, "range_start": "1", "range_end": "1"}] + result = self.Tool.execute(self.coder, show) + assert "only_line" in result + finally: + self._teardown() + + def test_file_not_found(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test with a non-existent file.""" + mock_coder.io.read_text.return_value = None + # We need abs_path to pass os.path.exists but read_text to return None + abs_path = "/nonexistent/path.py" + mock_coder.abs_root_path.return_value = abs_path + + rp_patch = patch( + "cecli.tools.read_range.resolve_paths", return_value=(abs_path, "nonexistent/path.py") + ) + rp_patch.start() + self.patches.append(rp_patch) + + from cecli.tools.read_range import Tool + + show = [{"file_path": "nonexistent/path.py", "range_start": "1", "range_end": "10"}] + result = Tool.execute(mock_coder, show) + assert "not found" in result or "Errors" in result + + def test_missing_parameters(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test with missing range_start and range_end (empty strings).""" + from cecli.tools.read_range import Tool + + show = [{"file_path": "some_file.py", "range_start": "", "range_end": ""}] + result = Tool.execute(mock_coder, show) + assert "Provide both" in result or "Errors" in result + + def test_multiple_show_operations( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test multiple show operations in one call.""" + content1 = "line1_1\nline1_2\nline1_3\nline1_4\nline1_5" + content2 = "line2_1\nline2_2\nline2_3\nline2_4\nline2_5" + test_file1 = create_test_file(content1) + test_file2 = create_test_file(content2) + + def resolve_side_effect(coder, file_path): + if "file1" in file_path: + return (test_file1, "file1.py") + return (test_file2, "file2.py") + + rp_patch = patch("cecli.tools.read_range.resolve_paths", side_effect=resolve_side_effect) + rp_patch.start() + + sh_patch = patch("cecli.tools.read_range.strip_hashline", side_effect=lambda x: x) + sh_patch.start() + + hl_patch = patch("cecli.tools.read_range.hashline", side_effect=lambda x: x) + hl_patch.start() + + ip_patch = patch( + "cecli.tools.read_range.is_provided", + side_effect=lambda v, **kw: v is not None and v != "", + ) + ip_patch.start() + + mock_cs = MagicMock() + mock_cs.get_files.return_value = mock_file_context + mock_cs.get_chunks.return_value = mock_chunks + mock_cs.get_manager.return_value = mock_manager + cs_patch = patch("cecli.helpers.conversation.ConversationService", mock_cs) + cs_patch.start() + + mock_coder.io.read_text.side_effect = [content1, content2] + + try: + from cecli.tools.read_range import Tool + + Tool._last_invocation = {} + Tool._last_read_turn = {} + + show = [ + {"file_path": "file1.py", "range_start": "1", "range_end": "3"}, + {"file_path": "file2.py", "range_start": "2", "range_end": "4"}, + ] + result = Tool.execute(mock_coder, show) + assert "line1_1" in result + assert "line2_2" in result + finally: + for p in [cs_patch, sh_patch, hl_patch, rp_patch, ip_patch]: + p.stop() + os.unlink(test_file1) + os.unlink(test_file2) + + # ========================================================================= + # Multiple Matches / Disambiguation + # ========================================================================= + + def test_few_matches(self, mock_coder, mock_file_context, mock_chunks, mock_manager): + """Test with ≤5 matches where each pattern appears once.""" + content = """def func_a(): + pass + +def func_b(): + pass + +def func_c(): + pass + +def func_d(): + pass + +def func_e(): + pass + +def func_f(): + pass +""" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [ + { + "file_path": self.test_file, + "range_start": "def func_a", + "range_end": "def func_c", + } + ] + result = self.Tool.execute(self.coder, show) + assert "Snapshot" in result + finally: + self._teardown() + + def test_too_many_matches_without_history( + self, mock_coder, mock_file_context, mock_chunks, mock_manager + ): + """Test with >5 matches without history -> should report 'too broad'.""" + content = """def func_a(): + pass + +def func_b(): + pass + +def func_c(): + pass + +def func_d(): + pass + +def func_e(): + pass + +def func_f(): + pass +""" + self._setup(mock_coder, mock_file_context, mock_chunks, mock_manager, content) + try: + show = [{"file_path": self.test_file, "range_start": "def", "range_end": "def"}] + result = self.Tool.execute(self.coder, show) + assert "too broad" in result.lower() + finally: + self._teardown() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short", "-s"]) diff --git a/tests/tools/validations.py b/tests/tools/validations.py index 95f58bc183b..9f651e75586 100644 --- a/tests/tools/validations.py +++ b/tests/tools/validations.py @@ -461,3 +461,209 @@ def test_list_iteration_on_non_list_does_nothing(self): {"items[]": ["coerce_dict"]}, ) assert result == {"items": "not a list"} + + +# ========================================================================= +# path parsing tests +# ========================================================================= + + +class TestPathParsing: + """Path resolution: segment, segment.nested, segment[], segment[].nested, segment.nested[], and complex.""" + + # ---- "segment" - single path ---- + + def test_single_path_segment(self): + """A simple key should resolve to a top-level param value.""" + params = {"delegations": '[{"name": "a1"}]'} + result = ToolValidations.validate_params( + params, + {"delegations": ["coerce_list"]}, + ) + assert result == {"delegations": [{"name": "a1"}]} + + # ---- "segment.nested" - nested path ---- + + def test_nested_path_segment(self): + """A dot-separated key should resolve to a nested param value.""" + params = {"outer": {"inner": '[{"name": "a1"}]'}} + result = ToolValidations.validate_params( + params, + {"outer.inner": ["coerce_list"]}, + ) + assert result == {"outer": {"inner": [{"name": "a1"}]}} + + def test_nested_path_deep(self): + """Deeply nested dot-separated key should resolve correctly.""" + params = {"a": {"b": {"c": '{"x": 1}'}}} + result = ToolValidations.validate_params( + params, + {"a.b.c": ["coerce_dict"]}, + ) + assert result == {"a": {"b": {"c": {"x": 1}}}} + + # ---- "segment[]" - iterate over list items at segment ---- + + def test_segment_bracket_iterates_list_items(self): + """A key with trailing [] should apply validation to each list item.""" + params = { + "items": [ + '{"name": "a1", "prompt": "do x"}', + '{"name": "a2", "prompt": "do y"}', + ] + } + result = ToolValidations.validate_params( + params, + {"items[]": ["coerce_dict"]}, + ) + assert result == { + "items": [ + {"name": "a1", "prompt": "do x"}, + {"name": "a2", "prompt": "do y"}, + ] + } + + # ---- "segment[].nested" - iterate then access sub-key ---- + + def test_segment_bracket_nested_key(self): + """segment[].nested: iterate over segment items, apply validation to each item's .nested.""" + params = { + "items": [ + {"nested": '{"a": 1}'}, + {"nested": '{"b": 2}'}, + ] + } + result = ToolValidations.validate_params( + params, + {"items[].nested": ["coerce_dict"]}, + ) + assert result == { + "items": [ + {"nested": {"a": 1}}, + {"nested": {"b": 2}}, + ] + } + + def test_segment_bracket_nested_skips_missing_keys(self): + """segment[].nested: items missing the nested key should be left alone.""" + params = { + "items": [ + {"nested": '{"a": 1}'}, + {"other": "value"}, + ] + } + result = ToolValidations.validate_params( + params, + {"items[].nested": ["coerce_dict"]}, + ) + # The item without 'nested' should remain unchanged + assert result == { + "items": [ + {"nested": {"a": 1}}, + {"other": "value"}, + ] + } + + def test_segment_bracket_nested_not_a_list(self): + """segment[].nested: if segment is not a list, params should be left unchanged.""" + params = {"items": "not a list"} + result = ToolValidations.validate_params( + params, + {"items[].nested": ["coerce_dict"]}, + ) + assert result == {"items": "not a list"} + + # ---- "segment.nested[]" - navigate then iterate ---- + + def test_nested_dot_bracket_iterates_list(self): + """segment.nested[]: navigate to segment.nested, then iterate over list items.""" + params = { + "group": { + "items": [ + '{"name": "a1"}', + '{"name": "a2"}', + ] + } + } + result = ToolValidations.validate_params( + params, + {"group.items[]": ["coerce_dict"]}, + ) + assert result == { + "group": { + "items": [ + {"name": "a1"}, + {"name": "a2"}, + ] + } + } + + # ---- "segment[].nested[].nested2" - complex ---- + + def test_complex_nested_iteration(self): + """segment[].nested[].nested2: iterate, descend, iterate, access sub-key.""" + params = { + "items": [ + { + "nested": [ + {"nested2": '{"a": 1}'}, + {"nested2": '{"b": 2}'}, + ] + }, + { + "nested": [ + {"nested2": '{"c": 3}'}, + ] + }, + ] + } + result = ToolValidations.validate_params( + params, + {"items[].nested[].nested2": ["coerce_dict"]}, + ) + assert result == { + "items": [ + { + "nested": [ + {"nested2": {"a": 1}}, + {"nested2": {"b": 2}}, + ] + }, + { + "nested": [ + {"nested2": {"c": 3}}, + ] + }, + ] + } + + # ---- edge cases ---- + + def test_complex_missing_intermediate_key(self): + """Complex path: missing intermediate key should leave params unchanged.""" + params = {"items": [{"nested": [{"nested2": "value"}]}]} + result = ToolValidations.validate_params( + params, + {"items[].missing[].nested2": ["coerce_dict"]}, + ) + # "missing" doesn't exist, so nothing happens + assert result == {"items": [{"nested": [{"nested2": "value"}]}]} + + def test_complex_middle_not_a_list(self): + """Complex path: if an intermediate [] target is not a list, params left unchanged.""" + params = {"items": [{"nested": "not a list"}]} + result = ToolValidations.validate_params( + params, + {"items[].nested[].nested2": ["coerce_dict"]}, + ) + # "nested" is not a list, so the second [] iteration can't happen + assert result == {"items": [{"nested": "not a list"}]} + + def test_complex_empty_inner_list(self): + """Complex path: an empty inner list should remain empty.""" + params = {"items": [{"nested": []}]} + result = ToolValidations.validate_params( + params, + {"items[].nested[].nested2": ["coerce_dict"]}, + ) + assert result == {"items": [{"nested": []}]}