Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
variants = []
for mapper in self._generate(cgroup, exclude):
# Clusters -> AliasList
found = collect(mapper.extracted, meta.ispace, self.opt_minstorage)
found = collect(mapper.extracted, meta, self.opt_minstorage)
exprs, aliases = self._choose(found, cgroup, mapper)

# AliasList -> Schedule
schedule = lower_aliases(aliases, meta, self.opt_maxpar)

variants.append(Variant(schedule, exprs))
variants.append(make_variant(schedule, exprs, mapper))

if not variants:
return []
Expand Down Expand Up @@ -282,8 +282,6 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None):

class CireInvariants(CireTransformerLegacy, Queue):

_q_guards_in_key = True

def __init__(self, sregistry, options, platform):
super().__init__(sregistry, options, platform)

Expand Down Expand Up @@ -511,7 +509,7 @@ def _cbk_search2(self, expr, rank):
}


def collect(extracted, ispace, minstorage):
def collect(extracted, meta, minstorage):
"""
Find groups of aliasing expressions.

Expand Down Expand Up @@ -575,11 +573,11 @@ def collect(extracted, ispace, minstorage):

group.append(u)
unseen.remove(u)
group = Group(group, ispace=ispace)
group = Group(group, ispace=meta.ispace)

k = group.dimensions_translated if minstorage else group.dimensions

k = frozenset(d for d in k if not d.is_NonlinearDerived)

mapper.setdefault(k, []).append(group)

aliases = AliasList()
Expand Down Expand Up @@ -657,8 +655,9 @@ def collect(extracted, ispace, minstorage):

# Compute the alias score
na = g.naliases
nr = nredundants(ispace, pivot)
nr = nredundants(meta.ispace, pivot)
score = estimate_cost(pivot, True)*((na - 1) + nr)

aliases.add(pivot, aliaseds, list(mapper), distances, score)

return aliases
Expand Down Expand Up @@ -728,8 +727,9 @@ def lower_aliases(aliases, meta, maxpar):
m = i.dim.symbolic_min - i.dim.parent.symbolic_min
else:
m = 0
d = dmapper[i.dim] = IncrDimension(f"{i.dim.name}s", i.dim, m,
dd.symbolic_size, 1, dd.step)
d = dmapper[i.dim] = IncrDimension(
f"{i.dim.name}s", i.dim, m, dd.symbolic_size, 1, dd.step
)
sub_iterators[i.dim] = d
else:
d = i.dim
Expand All @@ -745,6 +745,11 @@ def lower_aliases(aliases, meta, maxpar):
# The alias write-to space
writeto = IterationSpace(IntervalGroup(writeto), sub_iterators)

# Avoid scalar aliases in the presence of guards, since hoisting them
# might cause scope issues (see `test_dse.py::TestAliases::test_split_cond`)
if not writeto and meta.guards:
continue

# The alias iteration space
ispace = IterationSpace(IntervalGroup(intervals, meta.ispace.relations),
meta.ispace.sub_iterators,
Expand All @@ -764,6 +769,34 @@ def lower_aliases(aliases, meta, maxpar):
return Schedule(*processed, dmapper=dmapper, is_frame=aliases.is_frame)


def make_variant(schedule, exprs, mapper):
    """
    Build a Variant out of a Schedule and its associated expressions.

    Aliases that did not make it into `schedule` are undone: wherever
    `exprs` still refer to them, the original sub-expression recorded in
    `mapper` is substituted back in before the Variant is created.
    """
    # Everything that is still aliased by some entry of the schedule
    kept = flatten(sa.aliaseds for sa in schedule)

    # Maps each discarded entry back to the sub-expression it stood for
    reinstate = {}
    for original, extracted in mapper.items():
        if extracted in kept:
            # Still scheduled, nothing to reinstate
            continue

        if not isinstance(extracted, dict):
            reinstate[extracted] = original
            continue

        # `extracted` is itself a mapper, e.g.
        # `mapper = {u[t0, x+3, y+3] + u[t0, x+3, y+4]:
        #            {u[t0, x+3, y+4]: None, u[t0, x+3, y+3]: dummy0}}`
        # Only act when exactly one of its values is not retained; the
        # unpacking raises ValueError for any other count
        try:
            (dropped,) = [i for i in extracted.values() if i not in kept]
        except ValueError:
            continue
        reinstate[dropped] = original

    return Variant(schedule, [uxreplace(e, reinstate) for e in exprs])


def optimize_schedule_rotations(schedule, sregistry):
"""
Transform the schedule such that the tensor temporaries "rotate" along
Expand Down Expand Up @@ -1493,7 +1526,7 @@ def nredundants(ispace, expr):
redundant if it defines an iteration space for `expr` while not appearing
among its free symbols. Note that the converse isn't generally true: there
could be a Dimension that does not appear in the free symbols while defining
a non-redundant iteration space (e.g., a BlockDimension).
a non-redundant iteration space (e.g., a BlockDimension or a reduction).
"""
iterated = {i.dim for i in ispace}
used = {i for i in expr.free_symbols if i.is_Dimension}
Expand Down
22 changes: 13 additions & 9 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from devito.types import Array, StencilDimension, Symbol
from devito.types.basic import Scalar
from devito.types.dimension import AffineIndexAccessFunction, Thickness
from devito.types.misc import Temp


class TestIndexAccessFunction:
Expand Down Expand Up @@ -2130,9 +2131,10 @@ def test_topofusion_w_subdims_conddims(self):
assert exprs[0].write is h

exprs = FindNodes(Expression).visit(bns['x2_blk0'])
assert len(exprs) == 2
assert exprs[0].write is fsave
assert exprs[1].write is gsave
assert len(exprs) == 3
assert isinstance(exprs[0].expr.lhs, Temp)
assert exprs[1].write is fsave
assert exprs[2].write is gsave

def test_topofusion_w_subdims_conddims_v2(self):
"""
Expand Down Expand Up @@ -2163,9 +2165,10 @@ def test_topofusion_w_subdims_conddims_v2(self):
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0'})
assert len(FindNodes(Expression).visit(bns['x0_blk0'])) == 3
exprs = FindNodes(Expression).visit(bns['x1_blk0'])
assert len(exprs) == 2
assert exprs[0].write is fsave
assert exprs[1].write is gsave
assert len(exprs) == 3
assert isinstance(exprs[0].expr.lhs, Temp)
assert exprs[1].write is fsave
assert exprs[2].write is gsave

def test_topofusion_w_subdims_conddims_v3(self):
"""
Expand Down Expand Up @@ -2200,9 +2203,10 @@ def test_topofusion_w_subdims_conddims_v3(self):
assert exprs[1].write is g

exprs = FindNodes(Expression).visit(bns['x2_blk0'])
assert len(exprs) == 2
assert exprs[0].write is fsave
assert exprs[1].write is gsave
assert len(exprs) == 3
assert isinstance(exprs[0].expr.lhs, Temp)
assert exprs[1].write is fsave
assert exprs[2].write is gsave

# Additional nest due to anti-dependence
exprs = FindNodes(Expression).visit(bns['x1_blk0'])
Expand Down
38 changes: 17 additions & 21 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Conditional, DummyEq, Expression, FindNodes, FindSymbols, Iteration,
ParallelIteration, retrieve_iteration_tree
)
from devito.passes.clusters.aliases import collect
from devito.passes.clusters.aliases import AliasKey, collect
from devito.passes.clusters.factorization import collect_nested
from devito.passes.iet.parpragma import VExpanded
from devito.symbolics import ( # noqa
Expand Down Expand Up @@ -423,8 +423,9 @@ def test_collection(self, exprs, expected):

extracted = {i.rhs: i.lhs for i in exprs}
ispace = exprs[0].ispace
meta = AliasKey(ispace, None, None, None, None)

aliases = collect(extracted, ispace, False)
aliases = collect(extracted, meta, False)
aliases.filter(lambda a: a.score > 0)

assert len(aliases) == len(expected)
Expand Down Expand Up @@ -2553,15 +2554,15 @@ def test_invariants_with_conditional(self):

op = Operator(eqn, opt='advanced')

assert_structure(op, ['t', 't,fd', 't,fd,x,y'], 't,fd,x,y')
assert_structure(op, ['t', 't,fd,x,y'], 't,fd,x,y')
# Make sure it compiles
_ = op.cfunction

# Check hoisting for time invariant
eqn = Eq(u, u - (cos(time_sub * factor * f) * sin(g) * uf))

op = Operator(eqn, opt='advanced')
assert_structure(op, ['x,y', 't', 't,fd', 't,fd,x,y'], 'x,y,t,fd,x,y')
assert_structure(op, ['x,y', 't', 't,fd,x,y'], 'x,y,t,fd,x,y')
# Make sure it compiles
_ = op.cfunction

Expand Down Expand Up @@ -2705,10 +2706,9 @@ def test_split_cond(self):

cond = FindNodes(Conditional).visit(op)
assert len(cond) == 3
# Each guard should have its own alias for cos(time)
assert 'float r0 = cos(time);' in str(body0(op))
# No aliases in this case due to guards
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
assert len(scalars) == 2
assert len(scalars) == 0

def test_split_cond_multi_alias(self):
grid = Grid((11, 11))
Expand All @@ -2728,11 +2728,9 @@ def test_split_cond_multi_alias(self):

cond = FindNodes(Conditional).visit(op)
assert len(cond) == 3
# Each guard should have its own aliases for cos(time) and sin(time)
assert 'const float r0 = sin(time) + cos(time)' in str(body0(op))
assert 'const float r1 = cos(time);' in str(body0(op))
# No aliases in this case due to guards
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
assert len(scalars) == 3
assert len(scalars) == 0

def test_multi_cond_no_split(self):
grid = Grid((11, 11))
Expand All @@ -2758,7 +2756,7 @@ def test_multi_cond_no_split(self):
)

scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
assert len(scalars) == 3
assert len(scalars) == 0

def test_alias_with_conditional(self):
grid = Grid((11, 11))
Expand All @@ -2779,9 +2777,9 @@ def test_alias_with_conditional(self):
cond = FindNodes(Conditional).visit(op)
assert len(cond) == 3

# Each guard should have its own alias for cos(time/ctf)
# No aliases in this case due to guards
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
assert len(scalars) == 2
assert len(scalars) == 0

def test_scalar_alias_interp(self):
grid = Grid(shape=(11, 11))
Expand Down Expand Up @@ -2825,9 +2823,9 @@ def test_scalar_with_cond_access(self):
cond = FindNodes(Conditional).visit(op)
assert len(cond) == 3

# # Each guard should have its own alias for cos/sin(f1[time-2])
# The guards prevent some aliases from being hoisted out
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
assert len(scalars) == 3
assert len(scalars) == 0

assert_structure(
op,
Expand Down Expand Up @@ -2855,21 +2853,19 @@ def test_scalar_with_cond_tinvariant(self):

cond = FindNodes(Conditional).visit(op)
assert len(cond) == 1
# One for each 1/dt 1/dt**2
# One for 1/dt, while 1/dt**2 ain't hoisted out due to the guard
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
assert len(scalars) == 2
assert len(scalars) == 1

assert_structure(
op,
['t,x,y', 't', 't,x,y'],
'txyxy'
)

# Both aliases should be hoisted outside the time loop
# The 1/dt alias should be hoisted outside the time loop
assert str(body0(op).body[0]) == 'const float r0 = 1.0F/dt;'
assert not body0(op).body[0].ispace
assert str(body0(op).body[1]) == 'const float r1 = 1.0F/(dt*dt);'
assert not body0(op).body[1].ispace


class TestIsoAcoustic:
Expand Down
Loading