Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 132 additions & 52 deletions src/TiledArray/dist_eval/contraction_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,65 @@ class Summa
const ordinal_type
right_stride_local_; ///< stride for local right row iterators

// 3-d (h-grouped) grid information. The world's first
// proc_h_ * proc_h_stride ranks are partitioned into proc_h_ contiguous
// groups; slab h belongs to group h % proc_h_, and each group runs its
// own 2-d SUMMA grid (proc_grid_ is this rank's GROUP-LOCAL grid,
// constructed over the group's rank interval). proc_h_ == 1 is the
// ordinary shared-grid batched contraction.
const ordinal_type proc_h_; ///< Number of slab (h) groups
const ordinal_type
proc_h_stride_; ///< World ranks per slab group (the
///< group of slab h spans world ranks
///< [(h % proc_h_) * proc_h_stride_, ...))
const ordinal_type first_slab_; ///< This rank's group's first slab (== its
///< group index), or nh_ if this rank is
///< in no group (idle for this eval)
const ordinal_type my_slabs_; ///< Number of slabs of this rank's group

/// \return the world rank that owns result tile \p i: the within-group
/// owner (from the group-local process grid) shifted by the world-rank
/// offset of the group that owns \p i's slab. For proc_h_ == 1 the offset
/// is 0 and this is the ordinary cyclic owner.
ProcessID result_tile_owner(const ordinal_type i) const {
const ordinal_type source_index = DistEvalImpl_::perm_index_to_source(i);
// owner is independent of slab index *within a group*
const ordinal_type slab_index = source_index % result_slab_size_;
const ordinal_type tile_row = slab_index / proc_grid_.cols();
const ordinal_type tile_col = slab_index % proc_grid_.cols();
const ordinal_type proc_row = tile_row % proc_grid_.proc_rows();
const ordinal_type proc_col = tile_col % proc_grid_.proc_cols();
const ProcessID within_group = proc_row * proc_grid_.proc_cols() + proc_col;
// shift by the offset of the group that owns this tile's slab
const ordinal_type slab = source_index / result_slab_size_;
const ordinal_type group = (proc_h_ > 1ul) ? (slab % proc_h_) : 0ul;
return ProcessID(group * proc_h_stride_) + within_group;
}

/// \return the slab index of SUMMA step \p s
ordinal_type step_h(const ordinal_type s) const { return s / k_; }
/// \return the within-slab inner-dimension index of SUMMA step \p s
ordinal_type step_k(const ordinal_type s) const { return s % k_; }

/// \return the smallest SUMMA step >= \p s that belongs to one of this
/// rank's group's slabs, or nsteps_ if there is none
ordinal_type next_step(ordinal_type s) const {
if (proc_h_ == 1ul) return std::min(s, nsteps_);
if (first_slab_ >= nh_) return nsteps_; // not in any group
while (s < nsteps_ && (step_h(s) % proc_h_) != first_slab_)
s = (step_h(s) + 1ul) * k_; // jump to the start of the next slab
return std::min(s, nsteps_);
}

/// \return this rank's group-local ordinal of slab \p h (which must
/// belong to this rank's group)
ordinal_type slab_ord(const ordinal_type h) const {
return (h - first_slab_) / proc_h_;
}

/// \return the number of SUMMA steps of this rank's group's slabs
ordinal_type my_steps() const { return my_slabs_ * k_; }

typedef Future<typename right_type::eval_type>
right_future; ///< Future to a right-hand argument tile
typedef Future<typename left_type::eval_type>
Expand Down Expand Up @@ -225,7 +279,7 @@ class Summa
/// \param end The end of the row or column range
/// \param stride The row or column index stride
/// \param k The broadcast group index
/// \param max_group_size The maximum number of processes in the result
/// \param max_proc_h_stride The maximum number of processes in the result
/// group, which is equal to the number of process in this process row or
/// column as defined by \c proc_grid_.
/// \param key_offset The key that will be used to identify the process group
Expand All @@ -238,21 +292,21 @@ class Summa
const std::vector<bool>& process_mask,
ordinal_type index, const ordinal_type end,
const ordinal_type stride,
const ordinal_type max_group_size,
const ordinal_type max_proc_h_stride,
const ordinal_type k, const ordinal_type key_offset,
const ProcMap& proc_map) const {
// Generate the list of processes in rank_row
std::vector<ProcessID> proc_list(max_group_size, -1);
std::vector<ProcessID> proc_list(max_proc_h_stride, -1);

// Flag the root processes of the broadcast, which may not be included
// by shape.
ordinal_type p = k % max_group_size;
ordinal_type p = k % max_proc_h_stride;
proc_list[p] = proc_map(p);
ordinal_type count = 1ul;

// Flag all processes that have non-zero tiles
for (p = 0ul; (index < end) && (count < max_group_size);
index += stride, p = (p + 1u) % max_group_size) {
for (p = 0ul; (index < end) && (count < max_proc_h_stride);
index += stride, p = (p + 1u) % max_proc_h_stride) {
if ((proc_list[p] != -1) || (shape.is_zero(index)) || !process_mask.at(p))
continue;

Expand Down Expand Up @@ -734,7 +788,7 @@ class Summa
// broadcast root (i.e. within-slab k congruent to rank_col mod Pcols)
const ordinal_type Pcols = proc_grid_.proc_cols();

for (; s < end; ++s) {
for (s = next_step(s); s < end; s = next_step(s + 1ul)) {
const ordinal_type k = step_k(s);
if (k % Pcols != static_cast<ordinal_type>(proc_grid_.rank_col()))
continue;
Expand Down Expand Up @@ -787,7 +841,7 @@ class Summa
// broadcast root (i.e. within-slab k congruent to rank_row mod Prows)
const ordinal_type Prows = proc_grid_.proc_rows();

for (; s < end; ++s) {
for (s = next_step(s); s < end; s = next_step(s + 1ul)) {
const ordinal_type k = step_k(s);
if (k % Prows != static_cast<ordinal_type>(proc_grid_.rank_row()))
continue;
Expand Down Expand Up @@ -847,9 +901,9 @@ class Summa
/// \return The first step, greater than or equal to \c s with non-zero
/// tiles, or \c nsteps_ if none is found.
ordinal_type iterate_row(ordinal_type s) const {
// Iterate over steps until a non-zero tile is found or the end of the
// matrix is reached.
for (; s < nsteps_; ++s) {
// Iterate over this rank's group's steps until a non-zero tile is found
// or the end of the matrix is reached.
for (s = next_step(s); s < nsteps_; s = next_step(s + 1ul)) {
// Search for non-zero tiles in row k of slab h of right
ordinal_type i =
step_h(s) * right_slab_size_ + step_k(s) * proc_grid_.cols();
Expand All @@ -871,9 +925,9 @@ class Summa
/// \return The first step, greater than or equal to \c s, that contains
/// a non-zero tile. If no non-zero tile is not found, return \c nsteps_.
ordinal_type iterate_col(ordinal_type s) const {
// Iterate over steps until a non-zero tile is found or the end of the
// matrix is reached.
for (; s < nsteps_; ++s) {
// Iterate over this rank's group's steps until a non-zero tile is found
// or the end of the matrix is reached.
for (s = next_step(s); s < nsteps_; s = next_step(s + 1ul)) {
// Search column k of slab h for non-zero tiles
const ordinal_type base = step_h(s) * left_slab_size_;
for (ordinal_type i = base + left_start_local_ + step_k(s);
Expand Down Expand Up @@ -963,11 +1017,17 @@ class Summa
return tile_count;
} else {
// Construct static broadcast groups for dense arguments
// (key space [0, 2*nsteps_) is reserved for the sparse per-step groups)
const madness::DistributedID col_did(DistEvalImpl_::id(), 2ul * nsteps_);
// (key space [0, 2*nsteps_) is reserved for the sparse per-step groups,
// whose keys h*k_ and h*k_+nsteps_ are disjoint across h-groups; the
// two static keys are offset PAST that range and made group-unique so
// that two different groups' single-grid static groups never claim the
// same DistributedID with inconsistent membership)
const std::size_t static_key_base = 2ul * nsteps_ + 2ul * first_slab_;
const madness::DistributedID col_did(DistEvalImpl_::id(),
static_key_base);
col_group_ = proc_grid_.make_col_group(col_did);
const madness::DistributedID row_did(DistEvalImpl_::id(),
2ul * nsteps_ + 1ul);
static_key_base + 1ul);
row_group_ = proc_grid_.make_row_group(row_did);

#ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
Expand All @@ -985,12 +1045,12 @@ class Summa
#endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE

// Allocate memory for the reduce pair tasks (one per local result tile
// per slab).
// per slab of this rank's group).
std::allocator<ReducePairTask<op_type>> alloc;
reduce_tasks_ = alloc.allocate(nh_ * proc_grid_.local_size());
reduce_tasks_ = alloc.allocate(my_slabs_ * proc_grid_.local_size());

// Iterate over all local tiles
const ordinal_type n = nh_ * proc_grid_.local_size();
const ordinal_type n = my_slabs_ * proc_grid_.local_size();
for (ordinal_type t = 0ul; t < n; ++t) {
// Initialize the reduction task
ReducePairTask<op_type>* MADNESS_RESTRICT const reduce_task =
Expand Down Expand Up @@ -1019,9 +1079,9 @@ class Summa
if (k_ == 0) return 0;

// Allocate memory for the reduce pair tasks (one per local result tile
// per slab).
// per slab of this rank's group).
std::allocator<ReducePairTask<op_type>> alloc;
reduce_tasks_ = alloc.allocate(nh_ * proc_grid_.local_size());
reduce_tasks_ = alloc.allocate(my_slabs_ * proc_grid_.local_size());

// Initialize iteration variables
const ordinal_type col_stride = // The stride to iterate down a column
Expand All @@ -1030,10 +1090,11 @@ class Summa
proc_grid_.proc_cols();

// Iterate over all local tiles, slab by slab (the block-cyclic phase
// restarts at every slab: the owner of tile (h,i,j) does not depend on h)
// restarts at every slab: within a group, the owner of tile (h,i,j)
// does not depend on h)
ordinal_type tile_count = 0ul;
ReducePairTask<op_type>* MADNESS_RESTRICT reduce_task = reduce_tasks_;
for (ordinal_type h = 0ul; h < nh_; ++h) {
for (ordinal_type h = first_slab_; h < nh_; h += proc_h_) {
const ordinal_type slab_base = h * result_slab_size_;
ordinal_type row_start =
slab_base + proc_grid_.rank_row() * proc_grid_.cols();
Expand Down Expand Up @@ -1103,7 +1164,7 @@ class Summa

// Iterate over all local tiles, slab by slab
ReducePairTask<op_type>* reduce_task = reduce_tasks_;
for (ordinal_type h = 0ul; h < nh_; ++h) {
for (ordinal_type h = first_slab_; h < nh_; h += proc_h_) {
const ordinal_type slab_base = h * result_slab_size_;
ordinal_type row_start =
slab_base + proc_grid_.rank_row() * proc_grid_.cols();
Expand All @@ -1126,7 +1187,7 @@ class Summa

// Deallocate the memory for the reduce pair tasks.
std::allocator<ReducePairTask<op_type>>().deallocate(
reduce_tasks_, nh_ * proc_grid_.local_size());
reduce_tasks_, my_slabs_ * proc_grid_.local_size());
}

/// Set the result tiles and destroy reduce tasks
Expand All @@ -1145,7 +1206,7 @@ class Summa

// Iterate over all local tiles, slab by slab
ReducePairTask<op_type>* reduce_task = reduce_tasks_;
for (ordinal_type h = 0ul; h < nh_; ++h) {
for (ordinal_type h = first_slab_; h < nh_; h += proc_h_) {
const ordinal_type slab_base = h * result_slab_size_;
ordinal_type row_start =
slab_base + proc_grid_.rank_row() * proc_grid_.cols();
Expand Down Expand Up @@ -1177,7 +1238,7 @@ class Summa
}
// Deallocate the memory for the reduce pair tasks.
std::allocator<ReducePairTask<op_type>>().deallocate(
reduce_tasks_, nh_ * proc_grid_.local_size());
reduce_tasks_, my_slabs_ * proc_grid_.local_size());

#ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE
ss << "}\n";
Expand Down Expand Up @@ -1229,9 +1290,10 @@ class Summa
const std::vector<col_datum>& col,
const std::vector<row_datum>& row,
madness::TaskInterface* const task) {
// The reduce tasks of slab h occupy
// [h * local_size, (h+1) * local_size)
const ordinal_type slab_offset = step_h(s) * proc_grid_.local_size();
// The reduce tasks of this group's slab h occupy
// [slab_ord(h) * local_size, (slab_ord(h)+1) * local_size)
const ordinal_type slab_offset =
slab_ord(step_h(s)) * proc_grid_.local_size();

// Iterate over the row
for (ordinal_type i = 0ul; i < col.size(); ++i) {
Expand Down Expand Up @@ -1266,9 +1328,10 @@ class Summa
const std::vector<col_datum>& col,
const std::vector<row_datum>& row,
madness::TaskInterface* const task) {
// The reduce tasks of slab h occupy
// [h * local_size, (h+1) * local_size)
const ordinal_type slab_offset = step_h(s) * proc_grid_.local_size();
// The reduce tasks of this group's slab h occupy
// [slab_ord(h) * local_size, (slab_ord(h)+1) * local_size)
const ordinal_type slab_offset =
slab_ord(step_h(s)) * proc_grid_.local_size();

// Iterate over the row
for (ordinal_type i = 0ul; i < col.size(); ++i) {
Expand Down Expand Up @@ -1452,8 +1515,12 @@ class Summa

template <typename Derived>
void make_next_step_tasks(Derived* task, ordinal_type depth) {
// Set the depth to be no greater than the maximum number steps
if (depth > owner_->nsteps_) depth = owner_->nsteps_;
// Set the depth to be no greater than the number of SUMMA steps this
// rank's group actually executes. In the 2-d (proc_h_ == 1) case this is
// nsteps_ (my_slabs_ == nh_); in the 3-d (proc_h_ > 1) case my_steps() <
// nsteps_, and clamping to nsteps_ would pre-spawn surplus step tasks
// that all resolve to the terminating step (k_ == nsteps_).
if (depth > owner_->my_steps()) depth = owner_->my_steps();

// Spawn n=depth step tasks
for (; depth > 0ul; --depth) {
Expand Down Expand Up @@ -1540,13 +1607,14 @@ class Summa
public:
DenseStepTask(const std::shared_ptr<Summa_>& owner,
const ordinal_type depth)
: StepTask(owner, owner->nsteps_ + 1ul), k_(0) {
: StepTask(owner, owner->my_steps() + 1ul), k_(owner->next_step(0ul)) {
StepTask::make_next_step_tasks(this, depth);
StepTask::spawn_get_row_col_tasks(k_);
if (k_ < owner_->nsteps_) StepTask::spawn_get_row_col_tasks(k_);
Comment on lines 1608 to +1612
}

DenseStepTask(DenseStepTask* const parent, const int ndep)
: StepTask(parent, ndep), k_(parent->k_ + 1ul) {
: StepTask(parent, ndep),
k_(parent->owner_->next_step(parent->k_ + 1ul)) {
// Spawn tasks to get k-th row and column tiles
if (k_ < owner_->nsteps_) StepTask::spawn_get_row_col_tasks(k_);
}
Expand Down Expand Up @@ -1671,7 +1739,8 @@ class Summa
const trange_type trange, const shape_type& shape,
const std::shared_ptr<const pmap_interface>& pmap, const Perm& perm,
const op_type& op, const ordinal_type k, const ProcGrid& proc_grid,
const ordinal_type nh = 1ul)
const ordinal_type nh = 1ul, const ordinal_type proc_h = 1ul,
const ordinal_type proc_h_stride = 0ul)
: DistEvalImpl_(world, trange, shape, pmap, outer(perm)),
left_(left),
right_(right),
Expand All @@ -1685,6 +1754,11 @@ class Summa
left_slab_size_(left.size() / nh),
right_slab_size_(right.size() / nh),
result_slab_size_(proc_grid.rows() * proc_grid.cols()),
proc_h_(proc_h),
proc_h_stride_(proc_h_stride),
first_slab_(compute_first_slab(world, nh, proc_h, proc_h_stride)),
my_slabs_(first_slab_ < nh ? (nh - first_slab_ + proc_h - 1ul) / proc_h
: 0ul),
reduce_tasks_(NULL),
left_start_local_(proc_grid_.rank_row() * k),
left_end_(left.size() / nh),
Expand All @@ -1703,6 +1777,19 @@ class Summa
TA_ASSERT(nh_ > 0);
TA_ASSERT(left.size() % nh_ == 0);
TA_ASSERT(right.size() % nh_ == 0);
TA_ASSERT(proc_h_ > 0);
TA_ASSERT(proc_h_ == 1ul || proc_h_stride > 0ul);
TA_ASSERT(proc_h_ <= nh_);
}

/// \return this rank's group's first slab (== its group index), or
/// \p nh if this rank is outside the grouped rank interval
static ordinal_type compute_first_slab(World& world, const ordinal_type nh,
const ordinal_type proc_h,
const ordinal_type proc_h_stride) {
if (proc_h == 1ul) return 0ul;
const auto rank = ordinal_type(world.rank());
return (rank < proc_h * proc_h_stride) ? (rank / proc_h_stride) : nh;
}

virtual ~Summa() {}
Expand All @@ -1717,18 +1804,11 @@ class Summa
TA_ASSERT(TensorImpl_::is_local(i));
TA_ASSERT(!TensorImpl_::is_zero(i));

const ordinal_type source_index = DistEvalImpl_::perm_index_to_source(i);

// Compute tile coordinate in tile grid (the owner of a tile is
// independent of its slab index)
const ordinal_type slab_index = source_index % result_slab_size_;
const ordinal_type tile_row = slab_index / proc_grid_.cols();
const ordinal_type tile_col = slab_index % proc_grid_.cols();
// Compute process coordinate of tile in the process grid
const ordinal_type proc_row = tile_row % proc_grid_.proc_rows();
const ordinal_type proc_col = tile_col % proc_grid_.proc_cols();
// Compute the process that owns tile
const ProcessID source = proc_row * proc_grid_.proc_cols() + proc_col;
// The process that owns tile i: the within-group cyclic owner shifted by
// the world-rank offset of the tile's slab group (see
// result_tile_owner). For proc_h_ == 1 this is the ordinary cyclic
// owner over the whole world.
const ProcessID source = result_tile_owner(i);

const madness::DistributedID key(DistEvalImpl_::id(), i);
return TensorImpl_::world().gop.template recv<value_type>(source, key);
Expand Down
Loading
Loading