Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
756 changes: 0 additions & 756 deletions bin/generate_answers.ipynb

This file was deleted.

1 change: 0 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ TiledArray/expressions/binary_expr.h
TiledArray/expressions/blk_tsr_engine.h
TiledArray/expressions/blk_tsr_expr.h
TiledArray/expressions/cont_engine.h
TiledArray/expressions/contraction_helpers.h
TiledArray/expressions/expr.h
TiledArray/expressions/expr_engine.h
TiledArray/expressions/expr_trace.h
Expand Down
64 changes: 50 additions & 14 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,23 +673,38 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
auto range_map =
(RangeMap(a, A.array().trange()) | RangeMap(b, B.array().trange()));

// special Hadamard
// Fused broadcast: one operand is ENTIRELY fused (h == its index set) and
// there is no contraction -- C(h,e) = A * B, a per-fused-block scale. The
// general-product expression engine evaluates this natively when the
// fully-fused operand LEADS (it folds to a rank-0 left tile, restored by a
// synthetic unit left-external mode; see ContEngine::
// synthetic_unit_left_external). Order the fully-fused operand first and
// delegate -- no physical replication needed.
if (h.size() == a.size() || h.size() == b.size()) {
TA_ASSERT(!i && e);
bool const small_a = h.size() == a.size();
auto const delta_trng = make_trange(range_map, e);
_ein_call.branch = "fused-broadcast-expression";
const bool a_leads = (h.size() == a.size()); // A entirely fused?
// the engine produces the result in its canonical inner layout
// (inner fused, then left-then-right inner externals, in the order the
// operands are handed to it) and rejects a non-identity inner result
// permutation; evaluate canonically, then apply the inner permutation
// (cf. the generalized-inner-perm-recurse path below). The canonical
// inner-external order tracks the operand order, so mirror the swap.
std::string canon_inner;
if constexpr (IsArrayToT<ArrayC>) {
auto inner_e = a_leads ? (inner.A ^ inner.B) : (inner.B ^ inner.A);
canon_inner = ";" + (std::string)(inner.h + inner_e);
}
std::string canon_layout = std::string(c) + canon_inner;
ArrayC C0;
if (a_leads) // A fully fused -> already leads
C0(canon_layout) = A * B;
else // B fully fused -> put it first (multiplication commutes)
C0(canon_layout) = B * A;
std::string target_layout = std::string(c) + inner.c;
if (target_layout == canon_layout) return C0;
ArrayC C;
if (small_a) {
auto temp = replicate_array(A.array(), delta_trng);
std::string temp_layout = std::string(e) + "," + A.annotation();
C(target_layout) = temp(temp_layout) * B;
} else {
auto temp = replicate_array(B.array(), delta_trng);
std::string temp_layout = std::string(e) + "," + B.annotation();
C(target_layout) = A * temp(temp_layout);
}

C(target_layout) = C0(canon_layout); // inner-permute to requested layout
return C;
}

Expand Down Expand Up @@ -791,7 +806,28 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
}
}

if (!e) { // hadamard reduction
if (!e) { // no external outer index (fused + contracted): a nesting-
// preserving reduction. By default the expression layer handles
// it natively (synthetic unit left-external mode, see
// ContEngine::synthetic_unit_left_external); the legacy local
// kernel below is retained only as the opt-in legacy
// cross-check oracle (cf. the generalized-subworld path).
if (!detail::einsum_legacy_subworld()) {
_ein_call.branch = "no-external-expression";
// evaluate in the engine's canonical inner layout, then apply any
// inner result permutation (cf. the broadcast / inner-perm-recurse)
std::string canon_inner;
if constexpr (IsArrayToT<ArrayC>)
canon_inner = ";" + (std::string)(inner.h + inner.e);
std::string canon_layout = std::string(c) + canon_inner;
ArrayC C0;
C0(canon_layout) = A * B;
std::string target_layout = std::string(c) + inner.c;
if (target_layout == canon_layout) return C0;
ArrayC C;
C(target_layout) = C0(canon_layout);
return C;
}

_ein_call.branch = "hadamard-reduction-local";
const auto _ein_he_t0 = _ein_call.active ? now() : time_point{};
Expand Down
57 changes: 38 additions & 19 deletions src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,18 +677,20 @@ class ContEngine : public BinaryEngine<Derived> {
this->init_perm(target_indices);
general_repermute_ = (outer(target_indices) != outer(indices_));

// A product with NO external (free) outer indices (every outer index
// fused or contracted, e.g. C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b"))
// folds to a GEMM with no free modes, i.e. rank-0 tensors, which the
// tile kernels do not support. Evaluate it with a SYNTHETIC UNIT
// left-external mode instead: the folded product becomes
// (1,K) x (K) -> (1), the exact shape of the (supported) one-sided
// neB == 0 case. The unit mode lives only in the tile op's GemmHelper;
// tranges, shapes and tiles carry the true (external-free) ranks, and
// BatchedContractReduce / SparseShape::gemm_batched detect the
// synthetic mode from the one-rank mismatch and pad their folded views
// Some degenerate folded shapes would carry a rank-0 tensor, which the
// tile kernels do not support (see synthetic_unit_left_external()):
// - a NO-EXTERNAL product (every outer index fused or contracted, e.g.
// C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) folds to a rank-0 RESULT;
// - a FUSED BROADCAST (the left operand is entirely fused, no contraction,
// e.g. C("b,k") = A("b") * B("b,k")) folds to a rank-0 LEFT operand.
// Evaluate both with a SYNTHETIC UNIT left-external mode: the folded
// product becomes (1,K) x (K,N) -> (1,N), a supported shape (the no-
// external case has N == 1, the broadcast has K == 1). The unit mode lives
// only in the tile op's GemmHelper; tranges, shapes and tiles carry the
// true ranks, and BatchedContractReduce / SparseShape::gemm_batched detect
// the synthetic mode from the one-rank mismatch and pad their folded views
// with a unit extent.
const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u;
const unsigned int u = synthetic_unit_left_external();

// the tile op operates on the folded (fused-mode-free) shapes; the
// synthetic unit mode leads the folded left operand, so it is NoTrans
Expand Down Expand Up @@ -749,17 +751,34 @@ class ContEngine : public BinaryEngine<Derived> {
}
}

/// \return 1 if the folded general product needs a SYNTHETIC unit
/// left-external mode, else 0. The folded (fused-mode-free) GEMM cannot host
/// a rank-0 tensor, which arises in two degenerate cases:
/// - rank-0 RESULT: every outer index is fused or contracted (no external),
/// i.e. outer_size(indices_) == n_fused_modes_;
/// - rank-0 LEFT operand: the left argument is entirely fused with no
/// contraction (a fused broadcast / per-fused-block scale), i.e.
/// outer_size(left_indices_) == n_fused_modes_.
/// A unit left-external mode (carried only in the GemmHelper) restores a
/// supported (1,K) x (K,N) -> (1,N) shape in both cases.
unsigned int synthetic_unit_left_external() const {
return (outer_size(indices_) == n_fused_modes_ ||
outer_size(left_indices_) == n_fused_modes_)
? 1u
: 0u;
}

/// Tiled range factory function for a general product

/// \return The result tiled range: the fused mode ranges followed by the
/// left- and right-external mode ranges
trange_type make_trange_general() const {
const unsigned int nh = n_fused_modes_;
const unsigned int nc = op_.gemm_helper().num_contract_ranks();
// the no-external case carries a synthetic unit left-external mode in
// the GemmHelper only (see init_struct_general); the actual tranges do
// not have it
const unsigned int u = (outer_size(indices_) == n_fused_modes_) ? 1u : 0u;
// degenerate folds carry a synthetic unit left-external mode in the
// GemmHelper only (see synthetic_unit_left_external()); the actual tranges
// do not have it
const unsigned int u = synthetic_unit_left_external();
const unsigned int neA = op_.gemm_helper().left_rank() - nc - u;
const unsigned int neB = op_.gemm_helper().right_rank() - nc;

Expand Down Expand Up @@ -812,10 +831,10 @@ class ContEngine : public BinaryEngine<Derived> {
std::shared_ptr<const pmap_interface> pmap) {
const unsigned int nh = n_fused_modes_;
const unsigned int nc = op_.gemm_helper().num_contract_ranks();
// the no-external case carries a synthetic unit left-external mode in
// the GemmHelper only (see init_struct_general); the actual tranges do
// not have it
const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u;
// degenerate folds carry a synthetic unit left-external mode in the
// GemmHelper only (see synthetic_unit_left_external()); the actual tranges
// do not have it
const unsigned int u = synthetic_unit_left_external();
const unsigned int neA = op_.gemm_helper().left_rank() - nc - u;
const unsigned int neB = op_.gemm_helper().right_rank() - nc;

Expand Down
Loading
Loading