Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/dolfinx_adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .assembly import assemble_scalar, error_norm
from .solvers import LinearProblem, NonlinearProblem
from .types import Constant, Function
from .types import Constant, Function, dirichletbc
from .types.function import assign

meta = metadata("dolfinx_adjoint")
Expand All @@ -24,6 +24,7 @@
__all__ = [
"Constant",
"Function",
"dirichletbc",
"LinearProblem",
"NonlinearProblem",
"assemble_scalar",
Expand Down
24 changes: 24 additions & 0 deletions src/dolfinx_adjoint/blocks/dirichletbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import dolfinx
import numpy as np
import numpy.typing as npt
from pyadjoint.block import Block


class DirichletBCBlock(Block):
def __init__(
self,
value: dolfinx.fem.Function | dolfinx.fem.Constant,
dofs: npt.NDArray[np.int32],
V: dolfinx.fem.FunctionSpace | None = None,
ad_block_tag=None,
):
super().__init__(ad_block_tag=ad_block_tag)
self.dofs = dofs
self.V = V
self.add_dependency(value)

def prepare_recompute_component(self, inputs, relevant_outputs):
return inputs[0] if inputs else None

def recompute_component(self, inputs, block_variable, idx, prepared):
return block_variable.saved_output
9 changes: 6 additions & 3 deletions src/dolfinx_adjoint/blocks/function_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,18 @@ def prepare_recompute_component(self, inputs, relevant_outputs):
def recompute_component(self, inputs, block_variable, idx, prepared):
if self.expr is None:
prepared = inputs[0]
output = dolfinx.fem.Function(
block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute"
)

# We should return the exact object instance to maintain C++ memory bindings
# (especially for DirichletBCs), updating it in-place.
output = block_variable.saved_output

try:
if output.function_space == prepared.function_space:
output.x.array[:] = prepared.x.array[:]
except AttributeError:
# Handling float value
output.x.array[:] = prepared

return output

def __str__(self):
Expand Down
20 changes: 9 additions & 11 deletions src/dolfinx_adjoint/blocks/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
self.add_dependency(c, no_duplicates=True)
for c in self._rhs.coefficients(): # type: ignore
self.add_dependency(c, no_duplicates=True)

except AttributeError:
raise NotImplementedError("Blocked systems not implemented yet.")
self._compiled_lhs = dolfinx.fem.form(
Expand All @@ -86,6 +87,13 @@ def __init__(
self._petsc_options = petsc_options if petsc_options is not None else {}
self._petsc_options_prefix = petsc_options_prefix
self._bcs = bcs if bcs is not None else []

# Add dependencies from the boundary conditions
if self._bcs is not None:
for bc in self._bcs:
if hasattr(bc, "block_variable"):
self.add_dependency(bc, no_duplicates=True)

# Solver for recomputing the linear problem
self._forward_solver = dolfinx.fem.petsc.LinearProblem(
a=self._lhs,
Expand Down Expand Up @@ -162,16 +170,6 @@ def prepare_recompute_component(self, inputs, relevant_outputs):
else:
initial_guess = [dolfinx.fem.Function(u.function_space, name=u.name + "_initial_guess") for u in self._u]

# Replace values in the DirichletBC if it is dependent on a control
# NOTE: Currently assume that BCS are control independent.
bcs = self._bcs
# for block_variable in self.get_dependencies():
# c = block_variable.output
# c_rep = block_variable.saved_output

# if isinstance(c, dolfinx.fem.DirichletBC):
# bcs.append(c_rep)

# Replace form coefficients with checkpointed values.
# Loop through the dependencies of the lhs and rhs, check if they are in the respective form
lhs = self._replace_coefficients_in_form(self._lhs)
Expand Down Expand Up @@ -206,7 +204,7 @@ def prepare_recompute_component(self, inputs, relevant_outputs):
self._forward_solver._a = compiled_lhs
self._forward_solver._L = compiled_rhs
self._forward_solver._P = compiled_preconditioner
self._forward_solver.bcs = bcs
self._forward_solver.bcs = self._bcs
self._forward_solver._u = initial_guess

def recompute_component(
Expand Down
3 changes: 2 additions & 1 deletion src/dolfinx_adjoint/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ["Function", "Constant"]
__all__ = ["Function", "Constant", "dirichletbc"]

from .dirichletbc import dirichletbc
from .function import Constant, Function
63 changes: 63 additions & 0 deletions src/dolfinx_adjoint/types/dirichletbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import dolfinx
import numpy as np
import numpy.typing as npt
import pyadjoint
from pyadjoint.overloaded_type import FloatingType

from ..blocks.dirichletbc import DirichletBCBlock
from .function import Function


class DirichletBC(dolfinx.fem.DirichletBC, FloatingType):
"""A class overloading `dolfinx.fem.DirichletBC` to support it being used as a control variable
in the adjoint framework.

Args:
g: The value of the Dirichlet BC.
dofs: An array of degree-of-freedom indices in `V` where the BC should be applied.
**kwargs: Additional keyword arguments to pass to the `pyadjoint.overloaded_type.FloatingType` constructor.

"""

def __init__(self, g: Function, dofs: npt.NDArray[np.int32], **kwargs):
dtype = g.dtype
if np.issubdtype(dtype, np.float32):
bctype = dolfinx.cpp.fem.DirichletBC_float32
elif np.issubdtype(dtype, np.float64):
bctype = dolfinx.cpp.fem.DirichletBC_float64
elif np.issubdtype(dtype, np.complex64):
bctype = dolfinx.cpp.fem.DirichletBC_complex64
elif np.issubdtype(dtype, np.complex128):
bctype = dolfinx.cpp.fem.DirichletBC_complex128
else:
raise NotImplementedError(f"Type {dtype} not supported.")

super().__init__(bctype(g._cpp_object, dofs))

annotate = kwargs.pop("annotate", True)
annotate = annotate and pyadjoint.annotate_tape()

FloatingType.__init__(
self,
g,
dtype=dtype,
block_class=kwargs.pop("block_class", DirichletBCBlock),
_ad_floating_active=False,
_ad_args=kwargs.pop("_ad_args", (g, dofs)),
annotate=annotate,
**kwargs,
)

if annotate:
self._ad_annotate_block()

def _ad_create_checkpoint(self):
return self

def _ad_restore_at_checkpoint(self, checkpoint):
return self


def dirichletbc(value: Function, dofs: npt.NDArray[np.int32], **kwargs) -> DirichletBC:
"""Overloaded DirichletBC constructor that creates an adjoint-aware DirichletBC"""
return DirichletBC(value, dofs, **kwargs)
137 changes: 137 additions & 0 deletions tests/test_dirichlet_bc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from mpi4py import MPI

import dolfinx
import numpy as np
import pyadjoint
import ufl
from pyadjoint.overloaded_type import Weakref

from dolfinx_adjoint import Function, LinearProblem, assemble_scalar, assign, dirichletbc
from dolfinx_adjoint.blocks.dirichletbc import DirichletBCBlock


def test_dirichletbc_recording():
"""Test that creating an overloaded dirichletbc correctly registers a block and dependency on the tape."""
pyadjoint.get_working_tape().clear_tape()
mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10)
V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1))

c = Function(V, name="boundary_value")
c.interpolate(lambda x: x[0])

dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0))
bc = dirichletbc(c, dofs)

tape = pyadjoint.get_working_tape()
blocks = tape.get_blocks()

# The tape should have 1 block: DirichletBCBlock
assert len(blocks) == 1
assert isinstance(blocks[0], DirichletBCBlock)

# The block should have exactly 1 dependency (the function 'c')
assert len(blocks[0].get_dependencies()) == 1
assert blocks[0].get_dependencies()[0].output is c

# The returned BC object should now possess the injected block_variable
assert hasattr(bc, "block_variable")


def test_dirichletbc_no_annotate():
"""Test that setting annotate=False bypasses tape recording entirely."""

pyadjoint.get_working_tape().clear_tape()
mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10)
V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1))

c = Function(V, name="boundary_value")
c.interpolate(lambda x: x[0])

dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0))

# Run with annotation off
bc = dirichletbc(c, dofs, annotate=False)

tape = pyadjoint.get_working_tape()

assert len(tape.get_blocks()) == 0
# FIX: Check the underlying weak reference rather than invoking the property
assert getattr(bc, "_block_variable", Weakref())() is None


def test_dirichletbc_recompute():
"""Test the PyAdjoint internal recompute logic specifically for the DirichletBCBlock."""
pyadjoint.get_working_tape().clear_tape()
mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10)
V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1))

c = Function(V, name="boundary_value")
c.interpolate(lambda x: np.full_like(x[0], 5.0))

dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0))
bc = dirichletbc(c, dofs)

tape = pyadjoint.get_working_tape()
block = tape.get_blocks()[0]

# Simulate an optimizer changing the function value
c.interpolate(lambda x: np.full_like(x[0], 15.0))

# Replay the PyAdjoint mechanics manually
prepared = block.prepare_recompute_component([c], None)
new_bc = block.recompute_component([c], bc.block_variable, 0, prepared)

# Assert that the re-instantiated C++ object captured the updated control value
assert isinstance(new_bc, dolfinx.fem.bcs.DirichletBC)
assert np.isclose(new_bc.g.x.array[0], 15.0)


def test_time_dependent_bc_replay():
pyadjoint.get_working_tape().clear_tape()

mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 8, 8)
V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1))

dt = 0.1
num_steps = 3

m = Function(V, name="control")
m.interpolate(lambda x: np.sin(x[0] * np.pi))

u = ufl.TrialFunction(V)
v = ufl.TestFunction(V)

uh = Function(V, name="state")
assign(0.0, uh)

u_prev = Function(V, name="state_prev")
assign(0.0, u_prev)

F = (u - u_prev) / dt * v * ufl.dx + ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx - m * v * ufl.dx
a, L = ufl.system(F)

bc_func = Function(V, name="bc_func")
mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim)
boundary_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology)
boundary_dofs = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets)

# Use native dolfinx here! PyAdjoint traces the bc_func inside it.
bc = dirichletbc(bc_func, boundary_dofs)

problem = LinearProblem(a, L, bcs=[bc], u=uh)

J = 0.0

for i in range(num_steps):
assign(float(i + 1), bc_func)
problem.solve()
J += assemble_scalar(0.5 * ufl.inner(uh, uh) * ufl.dx)
assign(uh, u_prev)

J_forward = float(J)

control = pyadjoint.Control(m)
Jhat = pyadjoint.ReducedFunctional(J, control)
J_replay = Jhat(m)

assert np.isclose(J_replay, J_forward, atol=1e-10, rtol=1e-10)