def _reconstruct_sum_array(tree: nx.DiGraph, key: str, fixed_nodes: set | None = None) -> None:
    """Reconstruct ancestral states by summing array-valued child attributes.

    Every internal node that is not fixed gets ``tree.nodes[node][key]`` set to
    the element-wise sum of its children's ``key`` attributes. NaN entries are
    ignored in the sum; a position that is NaN in *all* children stays NaN
    rather than becoming ``0``. The tree is modified in place.

    Parameters
    ----------
    tree
        Directed tree whose edges point from parent to child. Every child of
        a reconstructed node must already carry ``key``.
    key
        Node attribute holding the per-node vector of values.
    fixed_nodes
        Nodes whose existing value must be kept. ``None`` means no node is
        fixed. Leaves (out-degree 0) are never overwritten.

    Notes
    -----
    After reconstruction every ``numpy.ndarray`` stored under ``key`` (on any
    node, including leaves and fixed nodes) is converted to a plain list, so
    missing entries are represented as ``nan`` in the resulting lists.
    """
    # Reverse topological order visits children before their parents, so each
    # parent sees the already-reconstructed values of its children.
    for node in reversed(list(nx.topological_sort(tree))):
        is_fixed = fixed_nodes is not None and node in fixed_nodes
        if is_fixed or tree.out_degree(node) == 0:
            continue
        # Coerce to float so that list-valued attributes, ``None`` entries
        # (which become NaN) and integer arrays are all handled uniformly.
        stacked = np.stack(
            [np.asarray(tree.nodes[child][key], dtype=float) for child in tree.successors(node)]
        )
        all_missing = np.isnan(stacked).all(axis=0)
        # nansum reports 0 where every child is NaN; restore NaN there.
        tree.nodes[node][key] = np.where(all_missing, np.nan, np.nansum(stacked, axis=0))
    # Convert numpy arrays back to lists for compatibility
    for node in tree.nodes:
        attrs = tree.nodes[node]
        value = attrs.get(key)
        if isinstance(value, np.ndarray):
            attrs[key] = value.tolist()
_remove_node_attributes(t, keys_added[0]) + if method == "sum": + node_attrs = dict(zip(data.index, data.to_numpy(dtype=float))) + for node in t.nodes: + if node not in node_attrs: + node_attrs[node] = np.full(length, np.nan) + nx.set_node_attributes(t, node_attrs, keys_added[0]) + _reconstruct_sum_array(t, keys_added[0], fixed_nodes) + else: + node_attrs = data.apply(lambda row: list(row), axis=1).to_dict() + for node in t.nodes: + if node not in node_attrs: + node_attrs[node] = [None] * length + nx.set_node_attributes(t, node_attrs, keys_added[0]) + for index in range(length): + _ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index, fixed_nodes) # If column add to tree as scalar else: for key, key_added in zip(keys, keys_added, strict=False):