Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ codesorter follows `semantic versioning <https://semver.org/>`_.
Unreleased
************

**Fixed**

- Keep a module-level assignment that calls a local definition after the names the call
needs at runtime. ``APP = App()`` only references ``App`` syntactically, but
instantiating it runs ``App.__init__``, which may read a module-level function defined
later in the file. Such an assignment was previously hoisted above that function and
raised ``NameError`` at import. The runtime references reachable through a called
class or function (transitively) are now treated as dependencies of the calling
assignment.

********************
0.2.7 (2026/06/15)
********************
Expand Down
83 changes: 83 additions & 0 deletions codesorter/sort_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ class SortCodeCommand(VisitorBasedCodemodCommand, m.MatcherDecoratableTransforme

METADATA_DEPENDENCIES = (md.ScopeProvider, md.QualifiedNameProvider)

@staticmethod
def _called_local_names(node: cst.SimpleStatementLine) -> set[str]:
    """Collect the bare names that are called anywhere inside ``node``.

    Only calls whose callee is a plain name match, so ``APP = App()`` yields
    ``{"App"}``. A call made through an attribute, such as ``mod.factory()``, does
    not match the pattern and is left out; such calls are assumed to target an
    imported or external object that imposes no in-module ordering.

    """
    callees: set[str] = set()
    for match in m.findall(node, m.Call(func=m.Name())):
        call = cst.ensure_type(match, cst.Call)
        callee = cst.ensure_type(call.func, cst.Name)
        callees.add(callee.value)
    return callees

@staticmethod
def _is_sortable(member: cst.CSTNode, *, sort_constants: bool) -> bool:
"""Return True if ``member`` is a class, function, or sortable constant."""
Expand All @@ -205,6 +219,14 @@ def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
self.original_nodes: dict[str, cst.CSTNode] = {}
self.dependencies: defaultdict[str, set[str]] = defaultdict(set)
# ``body_globals[name]`` holds every module-global name referenced anywhere in a
# node's body, including deferred references inside function and method bodies
# that run only when the node is called. ``called_names[name]`` holds the local
# names a constant assignment invokes (``App`` in ``APP = App()``). Together they
# let ``_fold_runtime_dependencies`` order an import-time call after everything it
# transitively reaches at runtime.
self.body_globals: defaultdict[str, set[str]] = defaultdict(set)
self.called_names: defaultdict[str, set[str]] = defaultdict(set)
# When ``from __future__ import annotations`` is active every annotation is a
# lazy string, so a name used only in an annotation imposes no runtime ordering.
# ``_lazy_annotation_names`` holds the id of every such annotation Name node.
Expand Down Expand Up @@ -296,6 +318,32 @@ def _add_edge(earlier: int, later: int) -> None:
_add_edge(earlier, later)
return dependents, indegree

def _fold_runtime_dependencies(self) -> None:
"""Order an import-time call after everything the call reaches at runtime.

``X = factory()`` only references ``factory`` syntactically, but executing it at
import runs ``factory``'s body, so ``X`` must also follow every module-global
that body uses (and, transitively, whatever those callees use). Without this an
assignment can be hoisted above a function it needs and fail with ``NameError``
at import. Definitions gain no edges from this, so it never forges a cycle
between a class or function and its siblings.

"""
runtime = {name: set(globals_) for name, globals_ in self.body_globals.items()}
changed = True
while changed:
changed = False
for reachable in runtime.values():
additions = {
name for dependency in tuple(reachable) for name in runtime.get(dependency, ())
} - reachable
if additions:
reachable.update(additions)
changed = True
for name, callees in self.called_names.items():
for callee in callees:
self.dependencies[name].update(runtime.get(callee, ()))

def _get_dependencies( # noqa: C901
self,
node: _Sortable,
Expand Down Expand Up @@ -546,11 +594,42 @@ def _reorder_segment(
def _resolve_dependents(self, node: _Sortable) -> None:
    """Record ``node``'s dependencies and the module globals its body references.

    Each direct dependency reported by ``_get_dependencies`` is stored under the
    node's sortable name, together with whatever is already recorded for that
    dependency. The result of ``_runtime_global_names`` is stored in
    ``self.body_globals`` under the same name.

    """
    direct, _ = self._get_dependencies(node)
    key = _sortable_name(node)
    self.body_globals[key] = self._runtime_global_names(node)
    for dep in direct:
        self.dependencies[key].add(dep)
        # Inherit everything already recorded for the dependency itself.
        self.dependencies[key].update(self.dependencies[dep])

def _runtime_global_names(self, node: _Sortable) -> set[str]:
    """Collect the module-global names referenced anywhere inside ``node``.

    Every ``Name`` node under ``node`` is examined, so names used in nested function
    and method bodies are included as well. A name is kept when the scope it is
    looked up through yields at least one assignment that is neither a builtin nor
    an import and that lives in the global scope. Skipped are: the node's own name,
    names registered in ``self._lazy_annotation_names``, names without scope
    metadata, and names the scope cannot resolve.

    """
    own_name = _sortable_name(node)
    collected: set[str] = set()
    for match in m.findall(node, m.Name()):
        reference = cst.ensure_type(match, cst.Name)
        identifier = reference.value
        if id(reference) in self._lazy_annotation_names or identifier == own_name:
            continue
        scope = self.get_metadata(md.ScopeProvider, reference, None)
        if scope is None:
            continue
        try:
            candidates = scope[identifier]
        except KeyError:
            continue
        # One qualifying global assignment is enough to count the name.
        if any(
            not isinstance(candidate, (md.BuiltinAssignment, md.ImportAssignment))
            and isinstance(candidate.scope, md.GlobalScope)
            for candidate in candidates
        ):
            collected.add(identifier)
    return collected

def _sorted_body(
self,
body: Sequence[cst.BaseStatement],
Expand Down Expand Up @@ -645,9 +724,12 @@ def leave_Module(
updated_node: cst.Module,
) -> cst.Module:
"""Sort the module-level definitions before returning the rewritten module."""
self._fold_runtime_dependencies()
updated_node = updated_node.with_changes(body=self._sorted_body(updated_node.body, sort_constants=True))
updated_node = cst.ensure_type(updated_node.visit(KeywordArgumentSorter()), cst.Module)
self.original_nodes = {}
self.body_globals.clear()
self.called_names.clear()
return updated_node

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
Expand Down Expand Up @@ -693,6 +775,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> bool:
if _constant_name(node) is not None:
self.original_nodes[_gen_unique_name(node)] = node
self._resolve_dependents(node)
self.called_names[_sortable_name(node)] = self._called_local_names(node)
return False


Expand Down
19 changes: 19 additions & 0 deletions tests/test_files/runtime_call_dependency_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

CONFIG = {"name": "default"}


def on_start() -> None:
"""A module-level function referenced from Server.__init__ at instantiation time."""
return None


class Server:
"""Constructed at import time; __init__ reads a module-level function and constant."""

def __init__(self) -> None:
self.handler = on_start
self.name = CONFIG["name"]


SERVER = Server()
19 changes: 19 additions & 0 deletions tests/test_files/runtime_call_dependency_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

CONFIG = {"name": "default"}


class Server:
"""Constructed at import time; __init__ reads a module-level function and constant."""

def __init__(self) -> None:
self.handler = on_start
self.name = CONFIG["name"]


def on_start() -> None:
"""A module-level function referenced from Server.__init__ at instantiation time."""
return None


SERVER = Server()
17 changes: 17 additions & 0 deletions tests/test_sort_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,23 @@ def test_pytest_fixtures(self, test_files):

assert expected_code == result.code

def test_runtime_call_dependency(self, test_files):
    """Check that an import-time call stays below the names it needs at runtime.

    Syntactically ``SERVER = Server()`` references only ``Server``, yet creating the
    instance executes ``Server.__init__``, which reads the module-level function
    ``on_start``. Functions sort after assignments by category, so the topological
    sort has to keep ``SERVER`` below ``on_start``; hoisting it would raise
    ``NameError`` at import.

    """
    source, expected = test_files
    command = SortCodeCommand(CodemodContext())
    module = cst.parse_module(source)
    result = command.transform_module(module)

    assert expected == result.code

def test_staticmethod(self, test_files):
"""Test that static methods are sorted correctly."""
input_code, expected_code = test_files
Expand Down