From 350f8b23429d61a12ae4e47d86a3d4f2f0c05197 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Thu, 11 Jun 2026 12:34:15 +0200 Subject: [PATCH 1/6] Add beginning of dirichlebc --- src/dolfinx_adjoint/__init__.py | 3 +- src/dolfinx_adjoint/blocks/solvers.py | 15 +- src/dolfinx_adjoint/types/__init__.py | 3 +- src/dolfinx_adjoint/types/dirichletbc.py | 66 +++++++++ tests/test_dirichlet_bc.py | 167 +++++++++++++++++++++++ 5 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 src/dolfinx_adjoint/types/dirichletbc.py create mode 100644 tests/test_dirichlet_bc.py diff --git a/src/dolfinx_adjoint/__init__.py b/src/dolfinx_adjoint/__init__.py index 00e96e7..7a0a688 100644 --- a/src/dolfinx_adjoint/__init__.py +++ b/src/dolfinx_adjoint/__init__.py @@ -7,7 +7,7 @@ from .assembly import assemble_scalar, error_norm from .solvers import LinearProblem, NonlinearProblem -from .types import Constant, Function +from .types import Constant, Function, dirichletbc from .types.function import assign meta = metadata("dolfinx_adjoint") @@ -24,6 +24,7 @@ __all__ = [ "Constant", "Function", + "dirichletbc", "LinearProblem", "NonlinearProblem", "assemble_scalar", diff --git a/src/dolfinx_adjoint/blocks/solvers.py b/src/dolfinx_adjoint/blocks/solvers.py index b5db2e3..ae14c65 100644 --- a/src/dolfinx_adjoint/blocks/solvers.py +++ b/src/dolfinx_adjoint/blocks/solvers.py @@ -64,6 +64,7 @@ def __init__( self.add_dependency(c, no_duplicates=True) for c in self._rhs.coefficients(): # type: ignore self.add_dependency(c, no_duplicates=True) + except AttributeError: raise NotImplementedError("Blocked systems not implemented yet.") self._compiled_lhs = dolfinx.fem.form( @@ -86,6 +87,11 @@ def __init__( self._petsc_options = petsc_options if petsc_options is not None else {} self._petsc_options_prefix = petsc_options_prefix self._bcs = bcs if bcs is not None else [] + + # Add dependencies from the boundary conditions + for bc in self._bcs: + self.add_dependency(bc, no_duplicates=True) + # Solver for recomputing the linear problem self._forward_solver = dolfinx.fem.petsc.LinearProblem( a=self._lhs, @@ -164,13 +170,20 @@ def prepare_recompute_component(self, inputs, relevant_outputs): # Replace values in the DirichletBC if it is dependent on a control # NOTE: Currently assume that BCS are control independent. - bcs = self._bcs + # bcs = self._bcs # for block_variable in self.get_dependencies(): # c = block_variable.output # c_rep = block_variable.saved_output # if isinstance(c, dolfinx.fem.DirichletBC): # bcs.append(c_rep) + bcs = [] + for bc in self._bcs: + if hasattr(bc, "block_variable") and bc.block_variable in self.get_dependencies(): + # Extract the newly minted dolfinx.fem.DirichletBC from the DirichletBCBlock + bcs.append(bc.block_variable.saved_output) + else: + bcs.append(bc) # Replace form coefficients with checkpointed values. # Loop through the dependencies of the lhs and rhs, check if they are in the respective form diff --git a/src/dolfinx_adjoint/types/__init__.py b/src/dolfinx_adjoint/types/__init__.py index c1a57a0..9ada54f 100644 --- a/src/dolfinx_adjoint/types/__init__.py +++ b/src/dolfinx_adjoint/types/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["Function", "Constant"] +__all__ = ["Function", "Constant", "dirichletbc"] from .function import Constant, Function +from .dirichletbc import dirichletbc diff --git a/src/dolfinx_adjoint/types/dirichletbc.py b/src/dolfinx_adjoint/types/dirichletbc.py new file mode 100644 index 0000000..f6c21d7 --- /dev/null +++ b/src/dolfinx_adjoint/types/dirichletbc.py @@ -0,0 +1,66 @@ +import dolfinx +from pyadjoint.block import Block +from pyadjoint.block_variable import BlockVariable +from pyadjoint.tape import annotate_tape, get_working_tape, stop_annotating + + +class DirichletBCBlock(Block): + def __init__(self, value, dofs, V=None, ad_block_tag=None): + super().__init__(ad_block_tag=ad_block_tag) + self.dofs = dofs + self.V = V + + # Add dependency on the underlying overloaded Function or Constant + self.add_dependency(value) + + def __str__(self): + return "dirichletbc" + + def prepare_recompute_component(self, inputs, relevant_outputs): + # Extract the checkpointed `value` from the inputs + return inputs[0] if inputs else None + + def recompute_component(self, inputs, block_variable, idx, prepared): + # Re-instantiate the FEniCSx boundary condition with the rewound tape value + with stop_annotating(): + return dolfinx.fem.dirichletbc(prepared, self.dofs, self.V) + + # Empty stubs required by the PyAdjoint Block interface for passive nodes + def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): + pass + + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): + pass + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): + pass + + def evaluate_hessian_component( + self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None + ): + pass + + +def dirichletbc(value, dofs, V=None, **kwargs): + """Overloaded dolfinx.fem.dirichletbc.""" + annotate = annotate_tape(kwargs) + + with stop_annotating(): + bc = dolfinx.fem.dirichletbc(value, dofs, V) + + if annotate and hasattr(value, "block_variable"): + block = DirichletBCBlock(value, dofs, V, ad_block_tag=kwargs.get("ad_block_tag")) + get_working_tape().add_block(block) + + bv = BlockVariable(bc) + bc.block_variable = bv + + bc._ad_will_add_as_output = lambda: False + bc._ad_will_add_as_dependency = lambda: False + bc._ad_create_checkpoint = lambda: None + bc._ad_restore_at_checkpoint = lambda checkpoint: bc + # -------------------------------------------------------------------- + + block.add_output(bv) + + return bc diff --git a/tests/test_dirichlet_bc.py b/tests/test_dirichlet_bc.py new file mode 100644 index 0000000..aefb258 --- /dev/null +++ b/tests/test_dirichlet_bc.py @@ -0,0 +1,167 @@ +from mpi4py import MPI + +import dolfinx +import numpy as np +import pyadjoint +import ufl + +from dolfinx_adjoint import Function, LinearProblem, assemble_scalar, assign, dirichletbc +from dolfinx_adjoint.types.dirichletbc import DirichletBCBlock + + +def test_dirichletbc_recording(): + """Test that creating an overloaded dirichletbc correctly registers a block and dependency on the tape.""" + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: x[0]) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + bc = dirichletbc(c, dofs) + + tape = pyadjoint.get_working_tape() + blocks = tape.get_blocks() + + # The tape should have 1 block: DirichletBCBlock + assert len(blocks) == 1 + assert isinstance(blocks[0], DirichletBCBlock) + + # The block should have exactly 1 dependency (the function 'c') + assert len(blocks[0].get_dependencies()) == 1 + assert blocks[0].get_dependencies()[0].output is c + + # The returned BC object should now possess the injected block_variable + assert hasattr(bc, "block_variable") + + +def test_dirichletbc_no_annotate(): + """Test that setting annotate=False bypasses tape recording entirely.""" + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: x[0]) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + + # Run with annotation off + bc = dirichletbc(c, dofs, annotate=False) + + tape = pyadjoint.get_working_tape() + + assert len(tape.get_blocks()) == 0 + assert not hasattr(bc, "block_variable") + + +def test_dirichletbc_recompute(): + """Test the PyAdjoint internal recompute logic specifically for the DirichletBCBlock.""" + pyadjoint.get_working_tape().clear_tape() + mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + c = Function(V, name="boundary_value") + c.interpolate(lambda x: np.full_like(x[0], 5.0)) + + dofs = dolfinx.fem.locate_dofs_geometrical(V, lambda x: np.isclose(x[0], 0.0)) + bc = dirichletbc(c, dofs) + + tape = pyadjoint.get_working_tape() + block = tape.get_blocks()[0] + + # Simulate an optimizer changing the function value + c.interpolate(lambda x: np.full_like(x[0], 15.0)) + + # Replay the PyAdjoint mechanics manually + prepared = block.prepare_recompute_component([c], None) + new_bc = block.recompute_component([c], bc.block_variable, 0, prepared) + + # Assert that the re-instantiated C++ object captured the updated control value + assert isinstance(new_bc, dolfinx.fem.bcs.DirichletBC) + assert np.isclose(new_bc.g.x.array[0], 15.0) + + +def test_time_dependent_bc_replay(): + """ + Tests that time-dependent boundary conditions updated via `assign` are + properly tracked by LinearProblemBlock and that the solver does not reuse + polluted hot-state memory (which breaks Dirichlet lifting) during tape replays. + """ + pyadjoint.get_working_tape().clear_tape() + + mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 8, 8) + V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) + + dt = 0.1 + num_steps = 3 + + # Define the control variable (source term) + m = Function(V, name="control") + m.interpolate(lambda x: np.sin(x[0] * np.pi)) + + u = ufl.TrialFunction(V) + v = ufl.TestFunction(V) + + # State variables + uh = Function(V, name="state") + assign(0.0, uh) + u_prev = Function(V, name="state_prev") + + u_prev = Function(V, name="state_prev") + assign(0.0, u_prev) + + # Formulate a simple heat equation: (u - u_prev)/dt - div(grad(u)) = m + F = (u - u_prev) / dt * v * ufl.dx + ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx - m * v * ufl.dx + a, L = ufl.system(F) + + # Create a time-dependent boundary condition function + bc_func = Function(V, name="bc_func") + mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim) + boundary_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology) + boundary_dofs = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets) + bc = dirichletbc(bc_func, boundary_dofs) + + # Initialize the overloaded solver + problem = LinearProblem(a, L, bcs=[bc], u=uh) + + J = 0.0 + + # Forward time-stepping loop + for i in range(num_steps): + # 1. Dynamically assign a new boundary value + # If the DAG is missing dependencies, PyAdjoint won't know this happened! + assign(float(i + 1), bc_func) + + # 2. Solve the PDE + problem.solve() + + # 3. Accumulate objective + J += assemble_scalar(0.5 * ufl.inner(uh, uh) * ufl.dx) + + # 4. Advance time + assign(uh, u_prev) + + # Extract the total forward cost + J_forward = float(J) + + # Create the reduced functional + control = pyadjoint.Control(m) + Jhat = pyadjoint.ReducedFunctional(J, control) + + # Re-evaluate the tape using the EXACT same control parameters + J_replay = Jhat(m) + + # If the solver caches the "hot state" of `uh` from the end of the forward run, + # FEniCSx will use those large values during the Dirichlet lifting step + # of the replay, causing the PDE to explode and J_replay to be drastically wrong. + assert np.isclose(J_replay, J_forward, atol=1e-10, rtol=1e-10), ( + f"Tape replay failed! Forward J = {J_forward:.6e}, Replay J = {J_replay:.6e}. " + "The solver is likely caching corrupted 'hot state' memory or missing BC dependencies." + ) + + # Finally, ensure gradients can be computed cleanly without tape corruption + dJdm = Jhat.derivative() + assert dJdm is not None, "Derivative computation failed!" + assert np.linalg.norm(dJdm.x.array) > 0, "Gradient evaluated to absolute zero!" From 9db8d731008c22b974da61c815e75f4ffdcdc52e Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Thu, 11 Jun 2026 12:55:15 +0200 Subject: [PATCH 2/6] Try using the .g function instead --- src/dolfinx_adjoint/blocks/solvers.py | 12 ++++++-- tests/test_dirichlet_bc.py | 41 +++------------------------ 2 files changed, 14 insertions(+), 39 deletions(-) diff --git a/src/dolfinx_adjoint/blocks/solvers.py b/src/dolfinx_adjoint/blocks/solvers.py index ae14c65..a723a46 100644 --- a/src/dolfinx_adjoint/blocks/solvers.py +++ b/src/dolfinx_adjoint/blocks/solvers.py @@ -89,8 +89,12 @@ def __init__( self._bcs = bcs if bcs is not None else [] # Add dependencies from the boundary conditions - for bc in self._bcs: - self.add_dependency(bc, no_duplicates=True) + # for bc in self._bcs: + # self.add_dependency(bc, no_duplicates=True) + if self._bcs is not None: + for bc in self._bcs: + if hasattr(bc, "g") and hasattr(bc.g, "block_variable"): + self.add_dependency(bc.g, no_duplicates=True) # Solver for recomputing the linear problem self._forward_solver = dolfinx.fem.petsc.LinearProblem( @@ -177,6 +181,10 @@ def prepare_recompute_component(self, inputs, relevant_outputs): # if isinstance(c, dolfinx.fem.DirichletBC): # bcs.append(c_rep) + if self._bcs is not None: + for bc in self._bcs: + if hasattr(bc, "g") and hasattr(bc.g, "block_variable"): + bc.g.x.array[:] = bc.g.block_variable.saved_output.x.array[:] bcs = [] for bc in self._bcs: if hasattr(bc, "block_variable") and bc.block_variable in self.get_dependencies(): diff --git a/tests/test_dirichlet_bc.py b/tests/test_dirichlet_bc.py index aefb258..2227bb7 100644 --- a/tests/test_dirichlet_bc.py +++ b/tests/test_dirichlet_bc.py @@ -84,11 +84,6 @@ def test_dirichletbc_recompute(): def test_time_dependent_bc_replay(): - """ - Tests that time-dependent boundary conditions updated via `assign` are - properly tracked by LinearProblemBlock and that the solver does not reuse - polluted hot-state memory (which breaks Dirichlet lifting) during tape replays. - """ pyadjoint.get_working_tape().clear_tape() mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 8, 8) @@ -97,71 +92,43 @@ def test_time_dependent_bc_replay(): dt = 0.1 num_steps = 3 - # Define the control variable (source term) m = Function(V, name="control") m.interpolate(lambda x: np.sin(x[0] * np.pi)) u = ufl.TrialFunction(V) v = ufl.TestFunction(V) - # State variables uh = Function(V, name="state") assign(0.0, uh) - u_prev = Function(V, name="state_prev") u_prev = Function(V, name="state_prev") assign(0.0, u_prev) - # Formulate a simple heat equation: (u - u_prev)/dt - div(grad(u)) = m F = (u - u_prev) / dt * v * ufl.dx + ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx - m * v * ufl.dx a, L = ufl.system(F) - # Create a time-dependent boundary condition function bc_func = Function(V, name="bc_func") mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim) boundary_facets = dolfinx.mesh.exterior_facet_indices(mesh.topology) boundary_dofs = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets) - bc = dirichletbc(bc_func, boundary_dofs) - # Initialize the overloaded solver + # Use native dolfinx here! PyAdjoint traces the bc_func inside it. + bc = dolfinx.fem.dirichletbc(bc_func, boundary_dofs) + problem = LinearProblem(a, L, bcs=[bc], u=uh) J = 0.0 - # Forward time-stepping loop for i in range(num_steps): - # 1. Dynamically assign a new boundary value - # If the DAG is missing dependencies, PyAdjoint won't know this happened! assign(float(i + 1), bc_func) - - # 2. Solve the PDE problem.solve() - - # 3. Accumulate objective J += assemble_scalar(0.5 * ufl.inner(uh, uh) * ufl.dx) - - # 4. Advance time assign(uh, u_prev) - # Extract the total forward cost J_forward = float(J) - # Create the reduced functional control = pyadjoint.Control(m) Jhat = pyadjoint.ReducedFunctional(J, control) - - # Re-evaluate the tape using the EXACT same control parameters J_replay = Jhat(m) - # If the solver caches the "hot state" of `uh` from the end of the forward run, - # FEniCSx will use those large values during the Dirichlet lifting step - # of the replay, causing the PDE to explode and J_replay to be drastically wrong. - assert np.isclose(J_replay, J_forward, atol=1e-10, rtol=1e-10), ( - f"Tape replay failed! Forward J = {J_forward:.6e}, Replay J = {J_replay:.6e}. " - "The solver is likely caching corrupted 'hot state' memory or missing BC dependencies." - ) - - # Finally, ensure gradients can be computed cleanly without tape corruption - dJdm = Jhat.derivative() - assert dJdm is not None, "Derivative computation failed!" - assert np.linalg.norm(dJdm.x.array) > 0, "Gradient evaluated to absolute zero!" + assert np.isclose(J_replay, J_forward, atol=1e-10, rtol=1e-10) From 424be4ce4920a8ecd9eab469ba2cd8c2fc2a6b5f Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Tue, 16 Jun 2026 15:30:04 +0200 Subject: [PATCH 3/6] New attempt to implement overload for dirichletbc inspired by irksome --- src/dolfinx_adjoint/blocks/dirichletbc.py | 40 ++++ src/dolfinx_adjoint/blocks/solvers.py | 3 +- src/dolfinx_adjoint/types/dirichletbc.py | 249 ++++++++++++++++------ tests/test_dirichlet_bc.py | 4 +- 4 files changed, 228 insertions(+), 68 deletions(-) create mode 100644 src/dolfinx_adjoint/blocks/dirichletbc.py diff --git a/src/dolfinx_adjoint/blocks/dirichletbc.py b/src/dolfinx_adjoint/blocks/dirichletbc.py new file mode 100644 index 0000000..2a42b7e --- /dev/null +++ b/src/dolfinx_adjoint/blocks/dirichletbc.py @@ -0,0 +1,40 @@ +import dolfinx +from pyadjoint.block import Block +from pyadjoint.tape import stop_annotating + + +class DirichletBCBlock(Block): + def __init__(self, value, dofs, V=None, ad_block_tag=None): + super().__init__(ad_block_tag=ad_block_tag) + self.dofs = dofs + self.V = V + + # Add dependency on the underlying overloaded Function or Constant + self.add_dependency(value) + + def __str__(self): + return "dirichletbc" + + def prepare_recompute_component(self, inputs, relevant_outputs): + # Extract the checkpointed `value` from the inputs + return inputs[0] if inputs else None + + def recompute_component(self, inputs, block_variable, idx, prepared): + # Re-instantiate the FEniCSx boundary condition with the rewound tape value + with stop_annotating(): + return dolfinx.fem.dirichletbc(prepared, self.dofs, self.V) + + # Empty stubs required by the PyAdjoint Block interface for passive nodes + def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): + pass + + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): + pass + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): + pass + + def evaluate_hessian_component( + self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None + ): + pass diff --git a/src/dolfinx_adjoint/blocks/solvers.py b/src/dolfinx_adjoint/blocks/solvers.py index a723a46..093dc83 100644 --- a/src/dolfinx_adjoint/blocks/solvers.py +++ b/src/dolfinx_adjoint/blocks/solvers.py @@ -93,8 +93,7 @@ def __init__( # self.add_dependency(bc, no_duplicates=True) if self._bcs is not None: for bc in self._bcs: - if hasattr(bc, "g") and hasattr(bc.g, "block_variable"): - self.add_dependency(bc.g, no_duplicates=True) + self.add_dependency(bc, no_duplicates=True) # Solver for recomputing the linear problem self._forward_solver = dolfinx.fem.petsc.LinearProblem( diff --git a/src/dolfinx_adjoint/types/dirichletbc.py b/src/dolfinx_adjoint/types/dirichletbc.py index f6c21d7..67564ba 100644 --- a/src/dolfinx_adjoint/types/dirichletbc.py +++ b/src/dolfinx_adjoint/types/dirichletbc.py @@ -1,66 +1,187 @@ import dolfinx -from pyadjoint.block import Block -from pyadjoint.block_variable import BlockVariable -from pyadjoint.tape import annotate_tape, get_working_tape, stop_annotating - - -class DirichletBCBlock(Block): - def __init__(self, value, dofs, V=None, ad_block_tag=None): - super().__init__(ad_block_tag=ad_block_tag) - self.dofs = dofs - self.V = V - - # Add dependency on the underlying overloaded Function or Constant - self.add_dependency(value) - - def __str__(self): - return "dirichletbc" - - def prepare_recompute_component(self, inputs, relevant_outputs): - # Extract the checkpointed `value` from the inputs - return inputs[0] if inputs else None - - def recompute_component(self, inputs, block_variable, idx, prepared): - # Re-instantiate the FEniCSx boundary condition with the rewound tape value - with stop_annotating(): - return dolfinx.fem.dirichletbc(prepared, self.dofs, self.V) - - # Empty stubs required by the PyAdjoint Block interface for passive nodes - def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - pass - - def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): - pass - - def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): - pass - - def evaluate_hessian_component( - self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None +import numpy as np +import numpy.typing as npt +import ufl +from pyadjoint.overloaded_type import ( + FloatingType, + create_overloaded_object, +) + +# from pyadjoint.block_variable import BlockVariable +# from pyadjoint.tape import annotate_tape, get_working_tape, stop_annotating + + +def extract_dtype(expr: ufl.core.expr.Expr) -> npt.DTypeLike: + """Extract the dtype from an expression. + + Looks for any constants or coefficients and returning their dtype. + This is necessary for determining which DOLFINx DirichletBC constructor + to use when packing UFL expressions into DOLFINx Expressions for use in + BC reconstruction. + """ + consts = ufl.algorithms.analysis.extract_constants(expr) + for c in consts: + if hasattr(c, "dtype"): + return c.dtype + coeffs = ufl.algorithms.extract_coefficients(expr) + for c in coeffs: + if hasattr(c, "dtype"): + return c.dtype + raise ValueError( + "Could not extract dtype from expression, " + "please ensure that all constants and coefficients have a " + "dtype attribute" + ) + + +class DirichletBC(dolfinx.fem.DirichletBC, FloatingType): + _pack_expression: dolfinx.fem.Expression | None + _ufl_expr: ufl.core.expr.Expr | None # Store original UFL expression + + def __init__( + self, + g: ufl.core.expr.Expr, + dofs: npt.NDArray[np.int32], + V: dolfinx.fem.FunctionSpace, + name: str = "dirichletbc", + **kwargs, ): - pass - - -def dirichletbc(value, dofs, V=None, **kwargs): - """Overloaded dolfinx.fem.dirichletbc.""" - annotate = annotate_tape(kwargs) - - with stop_annotating(): - bc = dolfinx.fem.dirichletbc(value, dofs, V) - - if annotate and hasattr(value, "block_variable"): - block = DirichletBCBlock(value, dofs, V, ad_block_tag=kwargs.get("ad_block_tag")) - get_working_tape().add_block(block) - - bv = BlockVariable(bc) - bc.block_variable = bv - - bc._ad_will_add_as_output = lambda: False - bc._ad_will_add_as_dependency = lambda: False - bc._ad_create_checkpoint = lambda: None - bc._ad_restore_at_checkpoint = lambda checkpoint: bc - # -------------------------------------------------------------------- - - block.add_output(bv) - - return bc + """ + Create an Irksome compatible DirichletBC from an existing DOLFINx bc. + + :param g: The boundary condition expression + :param dofs: An array of degree-of-freedom indices in V + :param V: The space to construct the BC on. + :param name: The name of the boundary condition. + """ + # Attach UFL function space (to be able to reconstruct functions and constants on the same UFL domain) + self.name = name + self._ufl_space = V.ufl_function_space() + + # Store original UFL expression for time-varying BCs + if not isinstance(g, (dolfinx.fem.Function, dolfinx.fem.Constant, int, float, complex)): + self._ufl_expr = g # Save the symbolic expression + else: + self._ufl_expr = None + self._ufl_space = V.ufl_function_space() + + # If reconstructing with a sub space, we need to get the subspace dof indices + # If working with a subspace of a single stage, we need to create the (parent_dof, sub_dof) mapping + if V.component() != []: + V_sub, sub_to_parent = V.collapse() + if len(sub_to_parent) != 1: + msg = "Mixed topology is not supported for reconstructing BCs with UFL expressions" + raise NotImplementedError(msg) + else: + sub_to_parent = sub_to_parent[0] + parent_to_sub = np.full( + (V.dofmap.index_map.size_local + V.dofmap.index_map.num_ghosts) * V.dofmap.index_map_bs, + -1, + dtype=np.int32, + ) + parent_to_sub[sub_to_parent] = np.arange(len(sub_to_parent)) + sub_dofs = parent_to_sub[dofs] + dofs = (dofs, sub_dofs) + + # If we are not reconstructing the BC with a new value, + # we can reuse existing C++ objects + self._pack_expression = None + + # If we are reconstructing the BC with a new value, + # we need to check if the new value is a DOLFINx function or Constant. + # If True, we do not need to do anything for reconstruction. + if isinstance(g, (dolfinx.fem.Function, dolfinx.fem.Constant)): + val = g + self._pack_expression = None + else: + # If not, we need to take the ufl.core.expr.Expr and pack it into a DOLFINx Expression + if V.component() != []: + val = dolfinx.fem.Function(V_sub, name=f"bc_{str(g)}")._cpp_object + else: + val = dolfinx.fem.Function(V, name=f"bc_{str(g)}")._cpp_object + self._pack_expression = dolfinx.fem.Expression(g, V.element.interpolation_points) + + # Get correct C++ implementation based on dtype of expression + dtype = extract_dtype(g) + if np.issubdtype(dtype, np.float32): + bctype = dolfinx.cpp.fem.DirichletBC_float32 + elif np.issubdtype(dtype, np.float64): + bctype = dolfinx.cpp.fem.DirichletBC_float64 + elif np.issubdtype(dtype, np.complex64): + bctype = dolfinx.cpp.fem.DirichletBC_complex64 + elif np.issubdtype(dtype, np.complex128): + bctype = dolfinx.cpp.fem.DirichletBC_complex128 + else: + raise NotImplementedError(f"Type {dtype} not supported.") + + if ( + isinstance( + val, + ( + dolfinx.cpp.fem.Function_complex128, + dolfinx.cpp.fem.Function_complex64, + dolfinx.cpp.fem.Function_float32, + dolfinx.cpp.fem.Function_float64, + ), + ) + and val.function_space == V._cpp_object + ): + new_cpp_object = bctype(val, dofs) + elif isinstance(val, dolfinx.fem.Function): + new_cpp_object = bctype(val._cpp_object, dofs) + else: + # Depending on your FEniCSx version, the C++ constructor might strictly + # expect the C++ FunctionSpace instead of the Python FunctionSpace wrapper. + try: + new_cpp_object = bctype(val, dofs, V._cpp_object) + except TypeError: + new_cpp_object = bctype(val._cpp_object, dofs, V._cpp_object) + + # 4. Initialize the parent dolfinx.fem.DirichletBC wrapper with the newly minted C++ object + super().__init__(new_cpp_object) + + # 5. Store your custom properties + # self._orig_g = val + FloatingType.__init__( + self, + V, + val, + # name=name, + dtype=dtype, + block_class=kwargs.pop("block_class", None), + _ad_floating_active=kwargs.pop("_ad_floating_active", False), + _ad_args=kwargs.pop("_ad_args", None), + output_block_class=kwargs.pop("output_block_class", None), + _ad_output_args=kwargs.pop("_ad_output_args", None), + _ad_outputs=kwargs.pop("_ad_outputs", None), + annotate=kwargs.pop("annotate", True), + **kwargs, + ) + + def _ad_create_checkpoint(self): + checkpoint = create_overloaded_object(self) + checkpoint.name = self.name + "_checkpoint" + return checkpoint + + def _ad_restore_at_checkpoint(self, checkpoint): + return checkpoint + + +def dirichletbc( + value: ufl.core.expr.Expr, + dofs: npt.NDArray[np.int32], + V: dolfinx.fem.FunctionSpace | None = None, + **kwargs, +) -> DirichletBC: + """Overloaded DirichletBC so that we can reconstruct BCs with UFL expressions. + + .. note:: + This class is user-facing. + + :param value: A UFL expression representing the boundary condition. + :param dofs: An array of degree-of-freedom indices in `V` where the BC should be applied. + :param V: The function space on which the BC applies. It can be a subspace of a mixed/blocked space. + """ + if isinstance(value, dolfinx.fem.Function): + V = value.function_space + return DirichletBC(value, dofs, V, **kwargs) diff --git a/tests/test_dirichlet_bc.py b/tests/test_dirichlet_bc.py index 2227bb7..f5420a2 100644 --- a/tests/test_dirichlet_bc.py +++ b/tests/test_dirichlet_bc.py @@ -6,7 +6,7 @@ import ufl from dolfinx_adjoint import Function, LinearProblem, assemble_scalar, assign, dirichletbc -from dolfinx_adjoint.types.dirichletbc import DirichletBCBlock +from dolfinx_adjoint.blocks.dirichletbc import DirichletBCBlock def test_dirichletbc_recording(): @@ -113,7 +113,7 @@ def test_time_dependent_bc_replay(): boundary_dofs = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundary_facets) # Use native dolfinx here! PyAdjoint traces the bc_func inside it. - bc = dolfinx.fem.dirichletbc(bc_func, boundary_dofs) + bc = dirichletbc(bc_func, boundary_dofs) problem = LinearProblem(a, L, bcs=[bc], u=uh) From 4e651acf1f5c6f3f7104ee4166163215f15769c6 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Tue, 16 Jun 2026 16:42:26 +0200 Subject: [PATCH 4/6] Add fix to FunctionAssigner to work with dirichlet bc --- src/dolfinx_adjoint/blocks/dirichletbc.py | 9 +- .../blocks/function_assigner.py | 22 ++- src/dolfinx_adjoint/blocks/solvers.py | 28 +--- src/dolfinx_adjoint/types/dirichletbc.py | 134 +++++++----------- tests/test_dirichlet_bc.py | 5 +- 5 files changed, 78 insertions(+), 120 deletions(-) diff --git a/src/dolfinx_adjoint/blocks/dirichletbc.py b/src/dolfinx_adjoint/blocks/dirichletbc.py index 2a42b7e..88c9e01 100644 --- a/src/dolfinx_adjoint/blocks/dirichletbc.py +++ b/src/dolfinx_adjoint/blocks/dirichletbc.py @@ -1,6 +1,4 @@ -import dolfinx from pyadjoint.block import Block -from pyadjoint.tape import stop_annotating class DirichletBCBlock(Block): @@ -16,13 +14,12 @@ def __str__(self): return "dirichletbc" def prepare_recompute_component(self, inputs, relevant_outputs): - # Extract the checkpointed `value` from the inputs return inputs[0] if inputs else None def recompute_component(self, inputs, block_variable, idx, prepared): - # Re-instantiate the FEniCSx boundary condition with the rewound tape value - with stop_annotating(): - return dolfinx.fem.dirichletbc(prepared, self.dofs, self.V) + # PyAdjoint relies on checkpoints. The dynamic `_cpp_object` property + # in DirichletBC ensures FEniCSx handles the C++ updates internally. + return block_variable.saved_output # Empty stubs required by the PyAdjoint Block interface for passive nodes def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): diff --git a/src/dolfinx_adjoint/blocks/function_assigner.py b/src/dolfinx_adjoint/blocks/function_assigner.py index 7407de3..aa7723e 100644 --- a/src/dolfinx_adjoint/blocks/function_assigner.py +++ b/src/dolfinx_adjoint/blocks/function_assigner.py @@ -175,18 +175,34 @@ def prepare_recompute_component(self, inputs, relevant_outputs): return None return self._replace_with_saved_output() + # def recompute_component(self, inputs, block_variable, idx, prepared): + # if self.expr is None: + # prepared = inputs[0] + # output = dolfinx.fem.Function( + # block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute" + # ) + # try: + # if output.function_space == prepared.function_space: + # output.x.array[:] = prepared.x.array[:] + # except AttributeError: + # # Handling float value + # output.x.array[:] = prepared + # return output def recompute_component(self, inputs, block_variable, idx, prepared): if self.expr is None: prepared = inputs[0] - output = dolfinx.fem.Function( - block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute" - ) + + # We must return the exact same object instance to maintain C++ memory bindings + # (especially for DirichletBCs), updating it in-place. + output = block_variable.saved_output + try: if output.function_space == prepared.function_space: output.x.array[:] = prepared.x.array[:] except AttributeError: # Handling float value output.x.array[:] = prepared + return output def __str__(self): diff --git a/src/dolfinx_adjoint/blocks/solvers.py b/src/dolfinx_adjoint/blocks/solvers.py index 093dc83..a55de85 100644 --- a/src/dolfinx_adjoint/blocks/solvers.py +++ b/src/dolfinx_adjoint/blocks/solvers.py @@ -89,11 +89,10 @@ def __init__( self._bcs = bcs if bcs is not None else [] # Add dependencies from the boundary conditions - # for bc in self._bcs: - # self.add_dependency(bc, no_duplicates=True) if self._bcs is not None: for bc in self._bcs: - self.add_dependency(bc, no_duplicates=True) + if hasattr(bc, "block_variable"): + self.add_dependency(bc, no_duplicates=True) # Solver for recomputing the linear problem self._forward_solver = dolfinx.fem.petsc.LinearProblem( @@ -171,27 +170,6 @@ def prepare_recompute_component(self, inputs, relevant_outputs): else: initial_guess = [dolfinx.fem.Function(u.function_space, name=u.name + "_initial_guess") for u in self._u] - # Replace values in the DirichletBC if it is dependent on a control - # NOTE: Currently assume that BCS are control independent. - # bcs = self._bcs - # for block_variable in self.get_dependencies(): - # c = block_variable.output - # c_rep = block_variable.saved_output - - # if isinstance(c, dolfinx.fem.DirichletBC): - # bcs.append(c_rep) - if self._bcs is not None: - for bc in self._bcs: - if hasattr(bc, "g") and hasattr(bc.g, "block_variable"): - bc.g.x.array[:] = bc.g.block_variable.saved_output.x.array[:] - bcs = [] - for bc in self._bcs: - if hasattr(bc, "block_variable") and bc.block_variable in self.get_dependencies(): - # Extract the newly minted dolfinx.fem.DirichletBC from the DirichletBCBlock - bcs.append(bc.block_variable.saved_output) - else: - bcs.append(bc) - # Replace form coefficients with checkpointed values. # Loop through the dependencies of the lhs and rhs, check if they are in the respective form lhs = self._replace_coefficients_in_form(self._lhs) @@ -226,7 +204,7 @@ def prepare_recompute_component(self, inputs, relevant_outputs): self._forward_solver._a = compiled_lhs self._forward_solver._L = compiled_rhs self._forward_solver._P = compiled_preconditioner - self._forward_solver.bcs = bcs + self._forward_solver.bcs = self._bcs self._forward_solver._u = initial_guess def recompute_component( diff --git a/src/dolfinx_adjoint/types/dirichletbc.py b/src/dolfinx_adjoint/types/dirichletbc.py index 67564ba..ad3f40f 100644 --- a/src/dolfinx_adjoint/types/dirichletbc.py +++ b/src/dolfinx_adjoint/types/dirichletbc.py @@ -1,14 +1,11 @@ import dolfinx import numpy as np import numpy.typing as npt +import pyadjoint import ufl -from pyadjoint.overloaded_type import ( - FloatingType, - create_overloaded_object, -) +from pyadjoint.overloaded_type import FloatingType -# from pyadjoint.block_variable import BlockVariable -# from pyadjoint.tape import annotate_tape, get_working_tape, stop_annotating +from ..blocks.dirichletbc import DirichletBCBlock def extract_dtype(expr: ufl.core.expr.Expr) -> npt.DTypeLike: @@ -35,43 +32,19 @@ def extract_dtype(expr: ufl.core.expr.Expr) -> npt.DTypeLike: class DirichletBC(dolfinx.fem.DirichletBC, FloatingType): - _pack_expression: dolfinx.fem.Expression | None - _ufl_expr: ufl.core.expr.Expr | None # Store original UFL expression - - def __init__( - self, - g: ufl.core.expr.Expr, - dofs: npt.NDArray[np.int32], - V: dolfinx.fem.FunctionSpace, - name: str = "dirichletbc", - **kwargs, - ): - """ - Create an Irksome compatible DirichletBC from an existing DOLFINx bc. - - :param g: The boundary condition expression - :param dofs: An array of degree-of-freedom indices in V - :param V: The space to construct the BC on. - :param name: The name of the boundary condition. - """ - # Attach UFL function space (to be able to reconstruct functions and constants on the same UFL domain) + def __init__(self, g, dofs, V, name="dirichletbc", **kwargs): self.name = name self._ufl_space = V.ufl_function_space() - # Store original UFL expression for time-varying BCs if not isinstance(g, (dolfinx.fem.Function, dolfinx.fem.Constant, int, float, complex)): - self._ufl_expr = g # Save the symbolic expression + self._ufl_expr = g else: self._ufl_expr = None - self._ufl_space = V.ufl_function_space() - # If reconstructing with a sub space, we need to get the subspace dof indices - # If working with a subspace of a single stage, we need to create the (parent_dof, sub_dof) mapping if V.component() != []: V_sub, sub_to_parent = V.collapse() if len(sub_to_parent) != 1: - msg = "Mixed topology is not supported for reconstructing BCs with UFL expressions" - raise NotImplementedError(msg) + raise NotImplementedError("Mixed topology is not supported for reconstructing BCs") else: sub_to_parent = sub_to_parent[0] parent_to_sub = np.full( @@ -83,25 +56,14 @@ def __init__( sub_dofs = parent_to_sub[dofs] dofs = (dofs, sub_dofs) - # If we are not reconstructing the BC with a new value, - # we can reuse existing C++ objects - self._pack_expression = None - - # If we are reconstructing the BC with a new value, - # we need to check if the new value is a DOLFINx function or Constant. - # If True, we do not need to do anything for reconstruction. if isinstance(g, (dolfinx.fem.Function, dolfinx.fem.Constant)): val = g self._pack_expression = None else: - # If not, we need to take the ufl.core.expr.Expr and pack it into a DOLFINx Expression - if V.component() != []: - val = dolfinx.fem.Function(V_sub, name=f"bc_{str(g)}")._cpp_object - else: - val = dolfinx.fem.Function(V, name=f"bc_{str(g)}")._cpp_object - self._pack_expression = dolfinx.fem.Expression(g, V.element.interpolation_points) + val = dolfinx.fem.Function(V_sub if V.component() != [] else V, name=f"bc_{str(g)}") + self._pack_expression = dolfinx.fem.Expression(g, V.element.interpolation_points()) + val.interpolate(self._pack_expression) - # Get correct C++ implementation based on dtype of expression dtype = extract_dtype(g) if np.issubdtype(dtype, np.float32): bctype = dolfinx.cpp.fem.DirichletBC_float32 @@ -114,57 +76,59 @@ def __init__( else: raise NotImplementedError(f"Type {dtype} not supported.") - if ( - isinstance( - val, - ( - dolfinx.cpp.fem.Function_complex128, - dolfinx.cpp.fem.Function_complex64, - dolfinx.cpp.fem.Function_float32, - dolfinx.cpp.fem.Function_float64, - ), - ) - and val.function_space == V._cpp_object - ): - new_cpp_object = bctype(val, dofs) - elif isinstance(val, dolfinx.fem.Function): - new_cpp_object = bctype(val._cpp_object, dofs) - else: - # Depending on your FEniCSx version, the C++ constructor might strictly - # expect the C++ FunctionSpace instead of the Python FunctionSpace wrapper. - try: - new_cpp_object = bctype(val, dofs, V._cpp_object) - except TypeError: - new_cpp_object = bctype(val._cpp_object, dofs, V._cpp_object) + # Save internal references for dynamic C++ object generation + self._g_val = val + self._dofs_array = dofs + self._V_space = V + self._bctype = bctype + + # Initialize FEniCSx wrapper. This will trigger our _cpp_object.setter + super().__init__(self._generate_cpp_object()) - # 4. Initialize the parent dolfinx.fem.DirichletBC wrapper with the newly minted C++ object - super().__init__(new_cpp_object) + annotate = kwargs.pop("annotate", True) + annotate = annotate and pyadjoint.annotate_tape() - # 5. Store your custom properties - # self._orig_g = val FloatingType.__init__( self, V, val, - # name=name, dtype=dtype, - block_class=kwargs.pop("block_class", None), - _ad_floating_active=kwargs.pop("_ad_floating_active", False), - _ad_args=kwargs.pop("_ad_args", None), - output_block_class=kwargs.pop("output_block_class", None), - _ad_output_args=kwargs.pop("_ad_output_args", None), - _ad_outputs=kwargs.pop("_ad_outputs", None), - annotate=kwargs.pop("annotate", True), + block_class=kwargs.pop("block_class", DirichletBCBlock), + _ad_floating_active=False, + _ad_args=kwargs.pop("_ad_args", (val, dofs, V)), + annotate=annotate, **kwargs, ) + if annotate: + self._ad_annotate_block() + + def _generate_cpp_object(self): + """Dynamically construct a C++ BC reflecting the current array memory.""" + val_cpp = self._g_val._cpp_object if hasattr(self._g_val, "_cpp_object") else self._g_val + if isinstance(self._g_val, dolfinx.fem.Function): + return self._bctype(val_cpp, self._dofs_array) + else: + try: + return self._bctype(val_cpp, self._dofs_array, self._V_space._cpp_object) + except TypeError: + return self._bctype(val_cpp, self._dofs_array) + + @property + def _cpp_object(self): + # Solvers internally read this property every time they assemble/set_bcs + return self._generate_cpp_object() + + @_cpp_object.setter + def _cpp_object(self, value): + # Absorb the assignment from dolfinx.fem.DirichletBC.__init__ + self._initial_cpp_object = value + def _ad_create_checkpoint(self): - checkpoint = create_overloaded_object(self) - checkpoint.name = self.name + "_checkpoint" - return checkpoint + return self def _ad_restore_at_checkpoint(self, checkpoint): - return checkpoint + return self def dirichletbc( diff --git a/tests/test_dirichlet_bc.py b/tests/test_dirichlet_bc.py index f5420a2..571553f 100644 --- a/tests/test_dirichlet_bc.py +++ b/tests/test_dirichlet_bc.py @@ -4,6 +4,7 @@ import numpy as np import pyadjoint import ufl +from pyadjoint.overloaded_type import Weakref from dolfinx_adjoint import Function, LinearProblem, assemble_scalar, assign, dirichletbc from dolfinx_adjoint.blocks.dirichletbc import DirichletBCBlock @@ -38,6 +39,7 @@ def test_dirichletbc_recording(): def test_dirichletbc_no_annotate(): """Test that setting annotate=False bypasses tape recording entirely.""" + pyadjoint.get_working_tape().clear_tape() mesh = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, 10) V = dolfinx.fem.functionspace(mesh, ("Lagrange", 1)) @@ -53,7 +55,8 @@ def test_dirichletbc_no_annotate(): tape = pyadjoint.get_working_tape() assert len(tape.get_blocks()) == 0 - assert not hasattr(bc, "block_variable") + # FIX: Check the underlying weak reference rather than invoking the property + assert getattr(bc, "_block_variable", Weakref())() is None def test_dirichletbc_recompute(): From 621b2ee129e55cfa8bdc054533a20d9cec5e956f Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Tue, 16 Jun 2026 17:34:17 +0200 Subject: [PATCH 5/6] Formatting --- src/dolfinx_adjoint/types/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dolfinx_adjoint/types/__init__.py b/src/dolfinx_adjoint/types/__init__.py index 9ada54f..a781b83 100644 --- a/src/dolfinx_adjoint/types/__init__.py +++ b/src/dolfinx_adjoint/types/__init__.py @@ -1,4 +1,4 @@ __all__ = ["Function", "Constant", "dirichletbc"] -from .function import Constant, Function from .dirichletbc import dirichletbc +from .function import Constant, Function From b438d6754970e5ce1c4e71edeacc2f1670ea44b6 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Tue, 16 Jun 2026 18:10:36 +0200 Subject: [PATCH 6/6] Cleanup --- src/dolfinx_adjoint/blocks/dirichletbc.py | 33 ++--- .../blocks/function_assigner.py | 15 +-- src/dolfinx_adjoint/types/dirichletbc.py | 120 +++--------------- 3 files changed, 27 insertions(+), 141 deletions(-) diff --git a/src/dolfinx_adjoint/blocks/dirichletbc.py b/src/dolfinx_adjoint/blocks/dirichletbc.py index 88c9e01..0fd8b48 100644 --- a/src/dolfinx_adjoint/blocks/dirichletbc.py +++ b/src/dolfinx_adjoint/blocks/dirichletbc.py @@ -1,37 +1,24 @@ +import dolfinx +import numpy as np +import numpy.typing as npt from pyadjoint.block import Block class DirichletBCBlock(Block): - def __init__(self, value, dofs, V=None, ad_block_tag=None): + def __init__( + self, + value: dolfinx.fem.Function | dolfinx.fem.Constant, + dofs: npt.NDArray[np.int32], + V: dolfinx.fem.FunctionSpace | None = None, + ad_block_tag=None, + ): super().__init__(ad_block_tag=ad_block_tag) self.dofs = dofs self.V = V - - # Add dependency on the underlying overloaded Function or Constant self.add_dependency(value) - def __str__(self): - return "dirichletbc" - def prepare_recompute_component(self, inputs, relevant_outputs): return inputs[0] if inputs else None def recompute_component(self, inputs, block_variable, idx, prepared): - # PyAdjoint relies on checkpoints. The dynamic `_cpp_object` property - # in DirichletBC ensures FEniCSx handles the C++ updates internally. return block_variable.saved_output - - # Empty stubs required by the PyAdjoint Block interface for passive nodes - def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - pass - - def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): - pass - - def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): - pass - - def evaluate_hessian_component( - self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None - ): - pass diff --git a/src/dolfinx_adjoint/blocks/function_assigner.py b/src/dolfinx_adjoint/blocks/function_assigner.py index aa7723e..fca04b5 100644 --- a/src/dolfinx_adjoint/blocks/function_assigner.py +++ b/src/dolfinx_adjoint/blocks/function_assigner.py @@ -175,24 +175,11 @@ def prepare_recompute_component(self, inputs, relevant_outputs): return None return self._replace_with_saved_output() - # def recompute_component(self, inputs, block_variable, idx, prepared): - # if self.expr is None: - # prepared = inputs[0] - # output = dolfinx.fem.Function( - # block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute" - # ) - # try: - # if output.function_space == prepared.function_space: - # output.x.array[:] = prepared.x.array[:] - # except AttributeError: - # # Handling float value - # output.x.array[:] = prepared - # return output def recompute_component(self, inputs, block_variable, idx, prepared): if self.expr is None: prepared = inputs[0] - # We must return the exact same object instance to maintain C++ memory bindings + # We should return the exact object instance to maintain C++ memory bindings # (especially for DirichletBCs), updating it in-place. output = block_variable.saved_output diff --git a/src/dolfinx_adjoint/types/dirichletbc.py b/src/dolfinx_adjoint/types/dirichletbc.py index ad3f40f..547bf15 100644 --- a/src/dolfinx_adjoint/types/dirichletbc.py +++ b/src/dolfinx_adjoint/types/dirichletbc.py @@ -2,69 +2,25 @@ import numpy as np import numpy.typing as npt import pyadjoint -import ufl from pyadjoint.overloaded_type import FloatingType from ..blocks.dirichletbc import DirichletBCBlock - - -def extract_dtype(expr: ufl.core.expr.Expr) -> npt.DTypeLike: - """Extract the dtype from an expression. - - Looks for any constants or coefficients and returning their dtype. - This is necessary for determining which DOLFINx DirichletBC constructor - to use when packing UFL expressions into DOLFINx Expressions for use in - BC reconstruction. - """ - consts = ufl.algorithms.analysis.extract_constants(expr) - for c in consts: - if hasattr(c, "dtype"): - return c.dtype - coeffs = ufl.algorithms.extract_coefficients(expr) - for c in coeffs: - if hasattr(c, "dtype"): - return c.dtype - raise ValueError( - "Could not extract dtype from expression, " - "please ensure that all constants and coefficients have a " - "dtype attribute" - ) +from .function import Function class DirichletBC(dolfinx.fem.DirichletBC, FloatingType): - def __init__(self, g, dofs, V, name="dirichletbc", **kwargs): - self.name = name - self._ufl_space = V.ufl_function_space() + """A class overloading `dolfinx.fem.DirichletBC` to support it being used as a control variable + in the adjoint framework. - if not isinstance(g, (dolfinx.fem.Function, dolfinx.fem.Constant, int, float, complex)): - self._ufl_expr = g - else: - self._ufl_expr = None - - if V.component() != []: - V_sub, sub_to_parent = V.collapse() - if len(sub_to_parent) != 1: - raise NotImplementedError("Mixed topology is not supported for reconstructing BCs") - else: - sub_to_parent = sub_to_parent[0] - parent_to_sub = np.full( - (V.dofmap.index_map.size_local + V.dofmap.index_map.num_ghosts) * V.dofmap.index_map_bs, - -1, - dtype=np.int32, - ) - parent_to_sub[sub_to_parent] = np.arange(len(sub_to_parent)) - sub_dofs = parent_to_sub[dofs] - dofs = (dofs, sub_dofs) + Args: + g: The value of the Dirichlet BC. + dofs: An array of degree-of-freedom indices in `V` where the BC should be applied. + **kwargs: Additional keyword arguments to pass to the `pyadjoint.overloaded_type.FloatingType` constructor. - if isinstance(g, (dolfinx.fem.Function, dolfinx.fem.Constant)): - val = g - self._pack_expression = None - else: - val = dolfinx.fem.Function(V_sub if V.component() != [] else V, name=f"bc_{str(g)}") - self._pack_expression = dolfinx.fem.Expression(g, V.element.interpolation_points()) - val.interpolate(self._pack_expression) + """ - dtype = extract_dtype(g) + def __init__(self, g: Function, dofs: npt.NDArray[np.int32], **kwargs): + dtype = g.dtype if np.issubdtype(dtype, np.float32): bctype = dolfinx.cpp.fem.DirichletBC_float32 elif np.issubdtype(dtype, np.float64): @@ -76,26 +32,18 @@ def __init__(self, g, dofs, V, name="dirichletbc", **kwargs): else: raise NotImplementedError(f"Type {dtype} not supported.") - # Save internal references for dynamic C++ object generation - self._g_val = val - self._dofs_array = dofs - self._V_space = V - self._bctype = bctype - - # Initialize FEniCSx wrapper. This will trigger our _cpp_object.setter - super().__init__(self._generate_cpp_object()) + super().__init__(bctype(g._cpp_object, dofs)) annotate = kwargs.pop("annotate", True) annotate = annotate and pyadjoint.annotate_tape() FloatingType.__init__( self, - V, - val, + g, dtype=dtype, block_class=kwargs.pop("block_class", DirichletBCBlock), _ad_floating_active=False, - _ad_args=kwargs.pop("_ad_args", (val, dofs, V)), + _ad_args=kwargs.pop("_ad_args", (g, dofs)), annotate=annotate, **kwargs, ) @@ -103,27 +51,6 @@ def __init__(self, g, dofs, V, name="dirichletbc", **kwargs): if annotate: self._ad_annotate_block() - def _generate_cpp_object(self): - """Dynamically construct a C++ BC reflecting the current array memory.""" - val_cpp = self._g_val._cpp_object if hasattr(self._g_val, "_cpp_object") else self._g_val - if isinstance(self._g_val, dolfinx.fem.Function): - return self._bctype(val_cpp, self._dofs_array) - else: - try: - return self._bctype(val_cpp, self._dofs_array, self._V_space._cpp_object) - except TypeError: - return self._bctype(val_cpp, self._dofs_array) - - @property - def _cpp_object(self): - # Solvers internally read this property every time they assemble/set_bcs - return self._generate_cpp_object() - - @_cpp_object.setter - def _cpp_object(self, value): - # Absorb the assignment from dolfinx.fem.DirichletBC.__init__ - self._initial_cpp_object = value - def _ad_create_checkpoint(self): return self @@ -131,21 +58,6 @@ def _ad_restore_at_checkpoint(self, checkpoint): return self -def dirichletbc( - value: ufl.core.expr.Expr, - dofs: npt.NDArray[np.int32], - V: dolfinx.fem.FunctionSpace | None = None, - **kwargs, -) -> DirichletBC: - """Overloaded DirichletBC so that we can reconstruct BCs with UFL expressions. - - .. note:: - This class is user-facing. - - :param value: A UFL expression representing the boundary condition. - :param dofs: An array of degree-of-freedom indices in `V` where the BC should be applied. - :param V: The function space on which the BC applies. It can be a subspace of a mixed/blocked space. - """ - if isinstance(value, dolfinx.fem.Function): - V = value.function_space - return DirichletBC(value, dofs, V, **kwargs) +def dirichletbc(value: Function, dofs: npt.NDArray[np.int32], **kwargs) -> DirichletBC: + """Overloaded DirichletBC constructor that creates an adjoint-aware DirichletBC""" + return DirichletBC(value, dofs, **kwargs)