Skip to content

feat(torch): add generated operator bases#622

Open
voltjia wants to merge 1 commit into
feat/torch-codegen-optional-overloadsfrom
feat/torch-operator-bases
Open

feat(torch): add generated operator bases#622
voltjia wants to merge 1 commit into
feat/torch-codegen-optional-overloadsfrom
feat/torch-operator-bases

Conversation

@voltjia
Copy link
Copy Markdown
Collaborator

@voltjia voltjia commented May 27, 2026

Summary

  • Add generated C++ operator base headers under src/base/, regenerated from the current feat/torch-codegen-optional-overloads generator output.
  • Keep wrapper generation scalable for the larger base-header set by parallelizing wrapper generation by default and skipping wrappers for operators without platform implementations.
  • Skip device tests cleanly when an operator wrapper is not available in the current build.

Motivation

The PyTorch codegen work needs a checked-in operator-base layer that matches the current generator behavior, including optional-parameter overload support from PR #619. The larger generated base set also exposes wrapper-generation cost and missing-wrapper test behavior that need to be handled before the bases can be validated across all supported platforms.

Closes # N/A — no dedicated issue.

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • docs — documentation only
  • build / ci — build system or CI configuration
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Test Results on Supported Platforms

All rows used a full bare python3 -m pytest -v run, without tests/, --devices, or -n. Each build regenerated PyTorch operator sources first, installed with WITH_TORCH=ON, and smoke-checked representative generated PyTorch operators after install. Build times are from the pip install phase recorded by the local validation runner; pytest times are from the timed pytest command; total time is generate + build + pytest.

Platform Built pytest Result Generate Build Pytest Total Notes / Hardware
NVIDIA Yes 6304 passed, 11538 skipped in 337.64s 2s 1039s 348s 1389s Full bare pytest. PyTorch backend compiled, generated torch-op tests were included, and representative ops reported active platform implementation index 8.
Iluvatar Yes 4804 passed, 11520 skipped in 557.07s 14s 848s 561s 1423s Full bare pytest. PyTorch backend compiled, generated torch-op tests were included, and representative ops reported active platform implementation index 8.
MetaX Yes 5804 passed, 10520 skipped in 354.50s 2s 1386s 371s 1759s Full bare pytest. PyTorch backend compiled, generated torch-op tests were included, and representative ops reported active platform implementation index 8.
Cambricon Yes 3082 passed, 12858 skipped in 921.19s 3s 2266s 929s 3198s Full bare pytest. PyTorch backend compiled, generated torch-op tests were included, and representative ops reported active platform implementation index 8.
Moore Yes 5768 passed, 10574 skipped in 569.91s 6s 2239s 577s 2822s Full bare pytest. PyTorch backend compiled, generated torch-op tests were included, and representative ops reported active platform implementation index 8.
Ascend Yes 4481 passed, 11801 skipped in 533.13s 3s 1111s 549s 1663s Full bare pytest. PyTorch backend compiled, NPU test cases were visible in pytest, generated torch-op tests were included, and representative ops reported active platform implementation index 8. The container exited with code 137 after pytest had already emitted a passing summary.
Full `pytest` output (optional)
NVIDIA:    6304 passed, 11538 skipped in 337.64s (0:05:37)
Iluvatar:  4804 passed, 11520 skipped in 557.07s (0:09:17)
MetaX:     5804 passed, 10520 skipped in 354.50s (0:05:54)
Cambricon: 3082 passed, 12858 skipped in 921.19s (0:15:21)
Moore:     5768 passed, 10574 skipped in 569.91s (0:09:29)
Ascend:    4481 passed, 11801 skipped in 533.13s (0:08:53)

The test counts are higher than earlier PR-body snapshots because this branch now includes the generated PyTorch operator-base set and full bare pytest collects tests/test_torch_ops.py with WITH_TORCH=ON. The platform-to-platform count differences come from vendor PyTorch schema availability and explicit known vendor-kernel skips in the generated torch-op harness.

Benchmark / Performance Impact

Wrapper generation now runs in parallel by default and skips base headers that have no platform implementation in the scanned implementation set. This keeps the generated-base branch practical to build with the larger checked-in src/base/ set; the all-platform validation above records the resulting generate, build, and pytest times.

Notes for Reviewers

This PR is stacked on PR #619 because the regenerated base signatures depend on the optional-parameter hashing and overload handling introduced there. The PR should be retargeted to master after #619 lands, or merged after #619.

The generated base files are intentionally checked in as generator output. File paths are kept flat under src/base/.

The generated bases intentionally omit src/base/all.h, src/base/any.h, and src/base/internal_scaled_mm.h in this PR because their ATen schemas vary across installed PyTorch builds; those forms are better regenerated by the local codegen environment instead of frozen as stable public bases.


Checklist

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master through PR feat(torch): expose optional codegen parameters #619.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • Public API changes (if any) are intentional, documented, and reflected in affected callers/tests.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks (e.g. the `seqlens_k` tensor) (CONTRIBUTING.md §Code/General).
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation — unless the language/framework convention says otherwise (CONTRIBUTING.md §Code/General; §Python).

C++ Specific (if C++ files changed)

  • Code follows the Google C++ Style Guide strictly.
  • clang-format (version 21, per .github/workflows/clang-format.yml) has been run against all modified .h, .cc, .cuh, and .mlu files; the diff is clean.
  • clang-tidy concerns (per .clang-tidy) have been reviewed — no new warnings beyond the existing baseline.
  • Operator parameter order is inputs first, outputs last; attributes are between inputs and outputs; naming follows PyTorch → ONNX → CUDA API precedence (CONTRIBUTING.md §C++).
  • No exceptions are thrown. Error paths use assert with messages that include at least __FILE__, __LINE__, and __func__ (CONTRIBUTING.md §C++).
  • Error and warning message wording follows the LLVM Coding Standards (CONTRIBUTING.md §C++).
  • N/A — Kernel files are named correctly; this PR adds operator bases, not kernels.
  • N/A — Kernel and kernel launcher separation is unchanged; this PR adds operator bases, not kernels.
  • Constructor initializer list order matches member declaration order (CONTRIBUTING.md §C++).
  • Exactly one blank line between classes, between classes and functions, and between functions (CONTRIBUTING.md §C++).
  • Exactly one blank line between members (functions and variables) within a class (CONTRIBUTING.md §C++).
  • Exactly one blank line before and after the contents of a namespace (CONTRIBUTING.md §C++).
  • New operators added via src/base/<op>.h (inheriting Operator<Op>) with platform implementations under src/<category>/<platform>/ inheriting the base (CONTRIBUTING.md §Adding an Operator).
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific (if Python files changed)

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (see .github/workflows/ruff.yml).
  • ruff format --check passes cleanly — if not, run ruff format and commit the result.
  • Comments are complete English sentences, starting with a capital letter and ending with punctuation; Markdown backticks are used for code references (CONTRIBUTING.md §Python).
  • Framework-specific conventions (e.g. lowercase pytest.skip messages without terminal period) are honored where applicable (CONTRIBUTING.md §Python).
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement (CONTRIBUTING.md §Python).
  • N/A — No new docstrings were added.
  • Type hints are added / kept consistent with the surrounding code.

Testing

  • pytest was run locally on every supported platform that this PR can affect, and the results are recorded in the "Test Results" table above (CONTRIBUTING.md §Pull Requests).
  • N/A — Every supported platform was tested.
  • New functionality has matching tests under tests/ following tests/test_add.py / tests/test_gemm.py patterns (CONTRIBUTING.md §Adding an Operator).
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator (e.g. @pytest.mark.parametrize("dtype, rtol, atol", …)), independent parameters use separate decorators ordered by parameter declaration.
  • Where appropriate, pytest.mark.auto_act_and_assert is used and the test returns a Payload whose func and ref share the same calling convention.
  • Default dtype / device parameterization is relied on, or overridden with an explicit pytest.mark.parametrize when necessary.
  • N/A — No new flaky parallel-only test was added.
  • N/A — This is not a bug-fix-only PR.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install .[dev] on at least one affected platform.
  • compile_commands.json still regenerates (CMake option CMAKE_EXPORT_COMPILE_COMMANDS=ON in pyproject.toml — required by the code-lint skill and clang-tidy -p).
  • N/A — No new backend / device was added.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not broken.
  • Both CI workflows (clang-format.yml, ruff.yml) are green locally (or expected to be green on CI).
  • No new runtime dependency was added without updating pyproject.toml's [project.optional-dependencies] (or justified in the PR description).

Documentation

  • N/A — No user workflow, build flag, or developer workflow documentation changed.
  • New operators, new dispatch helpers, or new public utilities are documented (docstring, header comment, or an addition to CONTRIBUTING.md §Some Code Explanations).
  • N/A — No user-visible breaking change is introduced.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, IP addresses, or personal hardware identifiers have been committed.
  • N/A — No third-party code was added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.

@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from 9444f9c to 9864ff2 Compare May 27, 2026 19:51
@voltjia voltjia force-pushed the feat/torch-operator-bases branch from 33e537c to ffc3d68 Compare May 27, 2026 19:54
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from 9864ff2 to c0db647 Compare May 27, 2026 20:27
@voltjia voltjia force-pushed the feat/torch-operator-bases branch from ffc3d68 to d89ce8e Compare May 27, 2026 20:28
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from c0db647 to 3e3e319 Compare May 27, 2026 21:15
@voltjia voltjia force-pushed the feat/torch-operator-bases branch 2 times, most recently from fe50963 to c5a3a38 Compare May 27, 2026 21:51
@voltjia voltjia force-pushed the feat/torch-operator-bases branch from c5a3a38 to 312cd42 Compare May 27, 2026 22:25
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from 3e3e319 to 2a5d6af Compare May 27, 2026 23:33
@voltjia voltjia force-pushed the feat/torch-operator-bases branch 2 times, most recently from 34db70e to f5f6a15 Compare May 28, 2026 03:39
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from d41f01d to 9f591db Compare May 28, 2026 03:55
@voltjia voltjia force-pushed the feat/torch-operator-bases branch from f5f6a15 to ee42c3c Compare May 28, 2026 03:56
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from 9f591db to 70094a1 Compare May 28, 2026 07:41
@voltjia voltjia force-pushed the feat/torch-operator-bases branch from ee42c3c to 9299ffb Compare May 28, 2026 07:44
@voltjia voltjia force-pushed the feat/torch-codegen-optional-overloads branch from 70094a1 to 87e86ab Compare May 28, 2026 08:02
@voltjia voltjia force-pushed the feat/torch-operator-bases branch from 9299ffb to 1c61728 Compare May 28, 2026 08:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants