Fix AMX support using MultiRamp#9122
Conversation
The previous comment reported a time that seemed to have regressed. It was not 8.2ms on main - more like 11
Before: Computing best tile sizes for each type ................................................. bytes, tile width, tile height, bandwidth (GB/s): 1 8 8 20.9997 1 16 8 20.8329 1 8 16 18.5702 1 8 32 17.2463 1 8 64 14.312 2 8 16 19.2047 2 8 8 18.8368 2 16 8 17.0593 2 8 32 17.0591 2 4 8 15.7681 4 8 8 24.9364 4 4 16 22.9699 4 8 16 22.5743 4 4 32 22.255 4 4 8 20.4468 8 8 8 38.4094 8 16 4 28.4167 8 16 8 27.6184 8 8 4 27.6062 8 8 16 26.8693 After: Computing best tile sizes for each type ................................................. bytes, tile width, tile height, bandwidth (GB/s): 1 16 32 34.1921 1 16 16 31.8399 1 8 16 25.575 1 16 64 25.1665 1 32 16 25.0061 2 8 32 28.2635 2 8 16 27.7648 2 16 16 27.2126 2 16 32 23.9034 2 8 8 23.6345 4 8 16 34.5303 4 8 8 28.3653 4 16 8 26.8521 4 8 32 26.084 4 16 16 24.4519 8 8 8 33.7163 8 8 4 29.1339 8 4 16 26.418 8 16 4 25.4663 8 2 8 24.3949
Also better algorithm for innermost containing stmt
We might want to look into this: https://github.github.com/gh-stack/ |
A correctness test that exercises ten of the user-facing error paths in ExtractTileOperations.cpp. Each scenario is the most natural-looking matmul pattern that triggers a particular reject, doubling as a TODO list of cases we'd ideally support but currently don't: - too_large tile_x > 16 - bad_result_type i8 * i8 -> i16 (AMX always accumulates i32/f32) - naive_rhs row-major RHS without VNNI packing - indirect gather-style A(r, row_indices(y)) * B(...) - conv1d 1D conv of a 2D signal (LHS depends on x, k, y) - no_matmul store_in(AMXTile) on a non-matmul Func - widening_16bit i16 * i16 -> i32 (only 8-bit inputs supported) - inconsistent_tiles one Func with two updates at different tile sizes - not_a_matmul_pattern row-sum into AMXTile (no multiply) - scaled_matmul A(r, y) * 3 (RHS hoisted out of the reduce) The harness wraps each scenario in try/catch around Halide::CompileError. Halide is sometimes built without exceptions; in that case the test prints [SKIP] and exits 0 since the catch path can't fire. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| failures += !expect_user_error("too_large", scenario_too_large); | ||
| failures += !expect_user_error("bad_result_type", scenario_bad_result_type); | ||
| failures += !expect_user_error("naive_rhs", scenario_naive_rhs); | ||
| failures += !expect_user_error("indirect", scenario_indirect); | ||
| failures += !expect_user_error("conv1d", scenario_conv1d); | ||
| failures += !expect_user_error("no_matmul", scenario_no_matmul); | ||
| failures += !expect_user_error("widening_16bit", scenario_widening_16bit); | ||
| failures += !expect_user_error("inconsistent_tiles", scenario_inconsistent_tiles); | ||
| failures += !expect_user_error("not_a_matmul_pattern", scenario_not_a_matmul_pattern); | ||
| failures += !expect_user_error("matmul_by_constant", scenario_matmul_by_constant); |
There was a problem hiding this comment.
At a glance, I thought that the string was something that we expected to find in e.what(), but this isn't the case. That would be useful, though.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #9122 +/- ##
=======================================
Coverage ? 69.59%
=======================================
Files ? 255
Lines ? 78264
Branches ? 18722
=======================================
Hits ? 54470
Misses ? 18178
Partials ? 5616 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
It would be great to make that happen on GHA |
|
TODO: Instead of checking memory_type = AMXTile elsewhere in the compiler, add a helper function that checks if a memory type is natively 2D, which will currently only return true for AMXTile Edit: Done |
|
TODO: The exception-using test won't even compile if exceptions are not enabled Edit: Done |
1. Add is_tile_memory_type(MemoryType) helper, and use it at the four sites in StageStridedLoads and Deinterleave that previously hard-coded `mt == AMXTile` checks to opt out of load/store rewrites. Currently the helper just returns `t == AMXTile`, but future natively-2D memory types (e.g. other vendors' matrix accumulators) can be added by extending it in one place. 2. Guard tiled_matmul_errors with #if HALIDE_WITH_EXCEPTIONS so the compilation doesn't fail when exceptions are disabled — matching the pattern used by other exception-using correctness tests (exception.cpp, bad_partition_always_throws.cpp, etc.). The runtime exceptions_enabled() check is still there so the test also handles the rare case of being built with HALIDE_WITH_EXCEPTIONS but linked against a libHalide that wasn't. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Rewrites the AMX support to use MultiRamp. This, I believe, fixes the outstanding bugs in AMX support identified by #8350
Validated by running the AMX tests under SDE.
Future work is generalizing the AMX support to be willing to ingest larger vectors, and automatically slice it up into multiple tile-level operations. More TODO scenarios are in the test tiled_matmul_errors.cpp