From b21bc799e1c3c3fa36a8b96cb5b5437dd3a218c6 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Thu, 25 Jun 2026 14:47:24 -0700 Subject: [PATCH 1/3] Support MPI distributed arrays where some processes have no data When working with large arrays distributed across a fixed-size MPI communicator, it is sometimes necessary to handle the case where a sliced / reduced array results in no local data on some processes. These changes accommodate that situation by allowing construction of FlacArray instances where some processes have zero streams. Routines for I/O to HDF5 and Zarr are also updated for the case where some processes have no local data. --- src/flacarray/array.py | 66 +++++++++++---- src/flacarray/compress.py | 35 ++++++++ src/flacarray/decompress.py | 16 ++++ src/flacarray/demo.py | 18 ++-- src/flacarray/hdf5.py | 108 +++++++++++++++++------- src/flacarray/hdf5_load_v1.py | 150 +++++++++++++++++----------------- src/flacarray/io_common.py | 107 +++++++++++++++++------- src/flacarray/mpi.py | 78 +++++++++++------- src/flacarray/tests/hdf5.py | 83 +++++++++++++++++++ src/flacarray/tests/zarr.py | 83 +++++++++++++++++++ src/flacarray/zarr.py | 46 ++++++++--- src/flacarray/zarr_load_v1.py | 31 ++++--- 12 files changed, 607 insertions(+), 214 deletions(-) diff --git a/src/flacarray/array.py b/src/flacarray/array.py index b5864c9..55a57be 100644 --- a/src/flacarray/array.py +++ b/src/flacarray/array.py @@ -6,7 +6,7 @@ import numpy as np -from .compress import array_compress +from .compress import array_compress, array_compress_empty from .decompress import array_decompress_slice from .hdf5 import write_compressed as hdf5_write_compressed from .hdf5 import read_compressed as hdf5_read_compressed @@ -93,6 +93,7 @@ def __init__( if other is not None: # We are copying an existing object, make sure we have an # independent copy. 
+ self._empty = copy.deepcopy(other._empty) self._shape = copy.deepcopy(other._shape) self._global_shape = copy.deepcopy(other._global_shape) self._compressed = copy.deepcopy(other._compressed) @@ -107,6 +108,9 @@ def __init__( else: # This form of constructor is used in the class methods where we # have already created these arrays for use by this instance. + self._empty = shape is None or shape[0] == 0 + if self._empty and mpi_comm is None: + raise RuntimeError("Local data is empty, and MPI is not being used") self._shape = shape self._global_shape = global_shape self._compressed = compressed @@ -125,25 +129,33 @@ def _init_params(self): # stream, this tracks the user intentions about whether to flatten the # leading dimension. We also track the "local shape", with is the same, # but which always keeps the leading dimension. - if len(self._shape) == 1: + if self._empty: + # No local data + self._flatten_single = False + self._local_shape = (0,) + self._global_shape[1:] + self._local_nbytes = 0 + elif len(self._shape) == 1: self._flatten_single = True self._local_shape = (1, self._shape[0]) + self._local_nbytes = self._compressed.nbytes else: self._flatten_single = False self._local_shape = self._shape + self._local_nbytes = self._compressed.nbytes - self._local_nbytes = self._compressed.nbytes ( self._global_nbytes, self._global_proc_nbytes, self._global_stream_starts, ) = global_bytes(self._local_nbytes, self._stream_starts, self._mpi_comm) + self._leading_shape = self._local_shape[:-1] self._global_leading_shape = self._global_shape[:-1] - self._stream_size = self._local_shape[-1] + self._stream_size = self._global_shape[-1] # For reference, record the type string of the original data. 
self._typestr = self._dtype_str(self._dtype) + # Track whether we have 32bit or 64bit data self._is_int64 = self._dtype == np.dtype(np.int64) or self._dtype == np.dtype( np.float64 @@ -351,7 +363,9 @@ def _get_leading_axes(self, full_key): if self._flatten_single: # Our array is a single stream with flattened shape. - keep_slice = [0,] + keep_slice = [ + 0, + ] else: for axis, axkey in enumerate(full_key[:-1]): if not isinstance(axkey, (int, np.integer)): @@ -398,7 +412,7 @@ def _get_sample_axis(self, full_key): if stop - start <= 0: # No samples return (0, 0, (0,)) - return (start, stop, (stop-start,)) + return (start, stop, (stop - start,)) elif isinstance(sample_key, (int, np.integer)): # Just a scalar return (sample_key, sample_key + 1, ()) @@ -609,25 +623,45 @@ def from_array( """ # Get the global shape of the array - global_props = global_array_properties(arr.shape, mpi_comm=mpi_comm) + empty_data = arr is None or arr.shape[0] == 0 + if empty_data and mpi_comm is None: + raise RuntimeError("Local array is None, and MPI is not being used") + + if empty_data: + # No data on this process + global_props = global_array_properties(None, None, mpi_comm=mpi_comm) + else: + global_props = global_array_properties( + arr.shape, arr.dtype, mpi_comm=mpi_comm + ) + global_shape = global_props["shape"] + dtype = global_props["dtype"] mpi_dist = global_props["dist"] # Compress our local piece of the array - compressed, starts, nbytes, offsets, gains = array_compress( - arr, - level=level, - quanta=quanta, - precision=precision, - use_threads=use_threads, - ) + if empty_data: + # No data + arr_shape = (0,) + global_shape[1:] + compressed, starts, nbytes, offsets, gains = array_compress_empty( + global_shape, dtype, quanta, precision + ) + else: + arr_shape = arr.shape + compressed, starts, nbytes, offsets, gains = array_compress( + arr, + level=level, + quanta=quanta, + precision=precision, + use_threads=use_threads, + ) return FlacArray( None, - shape=arr.shape, + 
shape=arr_shape, global_shape=global_shape, compressed=compressed, - dtype=arr.dtype, + dtype=dtype, stream_starts=starts, stream_nbytes=nbytes, stream_offsets=offsets, diff --git a/src/flacarray/compress.py b/src/flacarray/compress.py index 5ba0d64..140aeb6 100644 --- a/src/flacarray/compress.py +++ b/src/flacarray/compress.py @@ -82,3 +82,38 @@ def array_compress(arr, level=5, quanta=None, precision=None, use_threads=False) return (compressed, starts, nbytes, None, None) else: raise ValueError(f"Unsupported data type '{arr.dtype}'") + + +def array_compress_empty(global_shape, dtype, quanta, precision): + """Mock the compressed data parameters for processes with no data. + + When using MPI, some processes may have no local data. This helper function + returns the equivalent parameters for those processes which cannot call + `array_compress()`. + + Args: + global_shape (tuple): The shape of the global data + dtype (np.dtype): The array dtype + quanta (array): The quanta values (only checked for None) + precision (array): The precision (only check for None) + + Returns: + (tuple): The (compressed bytes, stream starts, stream_nbytes, stream offsets, + stream gains) + + """ + leading_shape = global_shape[:-1] + compressed = np.zeros(0, dtype=np.uint8) + if len(leading_shape[1:]) == 0: + starts_shape = (0,) + else: + starts_shape = (0,) + leading_shape[1:] + starts = np.zeros(starts_shape, dtype=np.int64) + nbytes = np.zeros(starts_shape, dtype=np.int64) + if quanta is None and precision is None: + offsets = None + gains = None + else: + offsets = np.zeros(starts_shape, dtype=dtype) + gains = np.zeros(starts_shape, dtype=dtype) + return (compressed, starts, nbytes, offsets, gains) diff --git a/src/flacarray/decompress.py b/src/flacarray/decompress.py index 73c3a74..d0d567c 100644 --- a/src/flacarray/decompress.py +++ b/src/flacarray/decompress.py @@ -73,6 +73,22 @@ def array_decompress_slice( (tuple): The (output array, list of stream indices). 
""" + if len(compressed) == 0: + # This process has no data. Return an empty array with zero + # in the leading shape + empty_shape = stream_starts.shape + (stream_size,) + if stream_offsets is None: + if is_int64: + empty_dtype = np.int64 + else: + empty_dtype = np.int32 + else: + if is_int64: + empty_dtype = np.float64 + else: + empty_dtype = np.float32 + return np.zeros(empty_shape, dtype=empty_dtype), list() + if first_stream_sample is None: first_stream_sample = -1 if last_stream_sample is None: diff --git a/src/flacarray/demo.py b/src/flacarray/demo.py index 4fda9a7..1b61ae2 100644 --- a/src/flacarray/demo.py +++ b/src/flacarray/demo.py @@ -37,8 +37,9 @@ def create_fake_data( rank = comm.rank # Get the global array properties - gprops = global_array_properties(local_shape, comm) + gprops = global_array_properties(local_shape, dtype, comm) shape = gprops["shape"] + dtype = gprops["dtype"] mpi_dist = gprops["dist"] flatshape = np.prod(shape) @@ -99,18 +100,19 @@ def create_fake_data( global_data = comm.bcast(global_data, root=0) # Extract our local piece of the global data - if len(leading_shape) == 0 or (len(leading_shape) == 1 and leading_shape[0] == 1): - data = global_data + local_start = mpi_dist[rank][0] + local_stop = mpi_dist[rank][1] + if local_start == local_stop: + # This process has no data + empty_shape = (0,) + shape[1:] + data = np.zeros(empty_shape, dtype=dtype) else: - local_start = mpi_dist[rank][0] - local_stop = mpi_dist[rank][1] local_slice = [slice(local_start, local_stop, 1)] local_slice.extend([slice(None) for x in shape[1:]]) local_slice = tuple(local_slice) data = global_data[local_slice] - if len(data.shape) == 2 and data.shape[0] == 1: - data = data.reshape((-1)) - + if len(data.shape) == 2 and data.shape[0] == 1: + data = data.reshape((-1)) return data, mpi_dist diff --git a/src/flacarray/hdf5.py b/src/flacarray/hdf5.py index ea27c07..6e42de2 100644 --- a/src/flacarray/hdf5.py +++ b/src/flacarray/hdf5.py @@ -9,12 +9,13 @@ versions. 
""" + import importlib import numpy as np from . import __version__ as flacarray_version -from .compress import array_compress +from .compress import array_compress, array_compress_empty from .hdf5_utils import have_hdf5, hdf5_use_serial, check_dataset_buffer_size from .io_common import receive_write_compressed from .mpi import global_array_properties, global_bytes @@ -246,18 +247,32 @@ def write_compressed( if use_serial: # Use the common writing function - writer = WriterHDF5( - global_stream_starts.reshape(aux_local_shape), - stream_nbytes.reshape(aux_local_shape), - compressed, - stream_offsets, - stream_gains, - dstarts, - dbytes, - dcomp, - dsoff, - dsgain, - ) + if len(compressed) == 0: + writer = WriterHDF5( + global_stream_starts, + stream_nbytes, + compressed, + stream_offsets, + stream_gains, + dstarts, + dbytes, + dcomp, + dsoff, + dsgain, + ) + else: + writer = WriterHDF5( + global_stream_starts.reshape(aux_local_shape), + stream_nbytes.reshape(aux_local_shape), + compressed, + stream_offsets, + stream_gains, + dstarts, + dbytes, + dcomp, + dsoff, + dsgain, + ) receive_write_compressed( writer, global_leading_shape, @@ -288,16 +303,22 @@ def write_compressed( ) + tuple([slice(0, x) for x in aux_local_shape[1:]]) with dstarts.collective: - dstarts.write_direct(global_stream_starts, dslc, hslc) + if len(compressed) > 0: + dstarts.write_direct(global_stream_starts, dslc, hslc) + with dbytes.collective: - dbytes.write_direct(stream_nbytes, dslc, hslc) + if len(compressed) > 0: + dbytes.write_direct(stream_nbytes, dslc, hslc) if stream_offsets is not None: with dsoff.collective: - dsoff.write_direct(stream_offsets, dslc, hslc) + if len(compressed) > 0: + dsoff.write_direct(stream_offsets, dslc, hslc) + if stream_gains is not None: with dsgain.collective: - dsgain.write_direct(stream_gains, dslc, hslc) + if len(compressed) > 0: + dsgain.write_direct(stream_gains, dslc, hslc) dslc = (slice(0, global_process_nbytes[rank]),) hslc = (slice(comp_doff[rank], 
comp_doff[rank] + global_process_nbytes[rank]),) @@ -305,7 +326,8 @@ def write_compressed( "Parallel write of compressed data", dslc, np.uint8, True ) with dcomp.collective: - dcomp.write_direct(compressed, dslc, hslc) + if len(compressed) > 0: + dcomp.write_direct(compressed, dslc, hslc) @function_timer @@ -349,32 +371,58 @@ def write_array( raise RuntimeError("h5py is not importable, cannot write to HDF5") # Get the global shape of the array - global_props = global_array_properties(arr.shape, mpi_comm=mpi_comm) + empty_data = arr is None or arr.shape[0] == 0 + if empty_data and mpi_comm is None: + raise RuntimeError("Local array is None, and MPI is not being used") + + if empty_data: + # No data on this process + arr_shape = None + arr_dtype = None + else: + arr_shape = arr.shape + arr_dtype = arr.dtype + + global_props = global_array_properties(arr_shape, arr_dtype, mpi_comm=mpi_comm) global_shape = global_props["shape"] + dtype = global_props["dtype"] mpi_dist = global_props["dist"] # Get the number of channels - if arr.dtype == np.dtype(np.int64) or arr.dtype == np.dtype(np.float64): + if dtype == np.dtype(np.int64) or dtype == np.dtype(np.float64): n_channels = 2 else: n_channels = 1 + stream_size = global_shape[-1] + global_leading_shape = global_shape[:-1] + # Compress our local piece of the array - compressed, starts, nbytes, offsets, gains = array_compress( - arr, level=level, quanta=quanta, precision=precision, use_threads=use_threads - ) + if empty_data: + compressed, starts, nbytes, offsets, gains = array_compress_empty( + global_shape, dtype, quanta, precision + ) + if len(global_leading_shape[1:]) == 0: + leading_shape = (0,) + else: + leading_shape = (0,) + global_leading_shape[1:] + else: + compressed, starts, nbytes, offsets, gains = array_compress( + arr, + level=level, + quanta=quanta, + precision=precision, + use_threads=use_threads, + ) + if len(arr.shape) == 1: + leading_shape = (1,) + else: + leading_shape = arr.shape[:-1] local_nbytes = 
compressed.nbytes global_nbytes, global_proc_bytes, global_starts = global_bytes( local_nbytes, starts, mpi_comm ) - stream_size = arr.shape[-1] - - if len(arr.shape) == 1: - leading_shape = (1,) - else: - leading_shape = arr.shape[:-1] - global_leading_shape = global_shape[:-1] write_compressed( hgrp, diff --git a/src/flacarray/hdf5_load_v1.py b/src/flacarray/hdf5_load_v1.py index 852300b..7ffd3c4 100644 --- a/src/flacarray/hdf5_load_v1.py +++ b/src/flacarray/hdf5_load_v1.py @@ -16,6 +16,7 @@ read_send_compressed, select_keep_indices, read_compressed_dataset_slice, + initialize_empty_buffers, ) from .utils import function_timer @@ -180,13 +181,15 @@ def read_compressed(hgrp, keep=None, mpi_comm=None, mpi_dist=None): mpi_dist = distribute_and_verify(mpi_comm, global_shape[0], mpi_dist=mpi_dist) # Local data buffers we will load from the file. - local_shape = None - local_starts = None - stream_nbytes = None - compressed = None - stream_offsets = None - stream_gains = None - keep_indices = None + ( + local_shape, + local_starts, + stream_nbytes, + compressed, + stream_offsets, + stream_gains, + keep_indices, + ) = initialize_empty_buffers(global_shape, stream_off_dtype) if use_serial: # Use the common function for reading data and communicating it. @@ -213,57 +216,59 @@ def read_compressed(hgrp, keep=None, mpi_comm=None, mpi_dist=None): # We are using parallel HDF5. All processes have a handle to the dataset # from above, and each process reads its local slice. ds_range = mpi_dist[rank] - leading_shape = (ds_range[1] - ds_range[0],) + global_leading_shape[1:] - local_shape = leading_shape + (stream_size,) - - # The helper datasets all have the same slab definitions - dslc = tuple([slice(0, x) for x in leading_shape]) - hslc = (slice(ds_range[0], ds_range[0] + leading_shape[0]),) + tuple( - [slice(0, x) for x in leading_shape[1:]] - ) - - # If we are using the "keep" array to select streams, slice that - # to cover only data for this process. 
- if keep is None: - proc_keep = None - else: - proc_keep = keep[hslc] - - # Stream starts - raw_starts = np.empty(leading_shape, dtype=dstarts.dtype) - dstarts.read_direct(raw_starts, hslc, dslc) - - # Stream nbytes - raw_nbytes = np.empty(leading_shape, dtype=dstarts.dtype) - dbytes.read_direct(raw_nbytes, hslc, dslc) - - # Offsets and gains for type conversions - raw_offsets = None - if dsoff is not None: - raw_offsets = np.empty(leading_shape, dtype=stream_off_dtype) - dsoff.read_direct(raw_offsets, hslc, dslc) - raw_gains = None - if dsgain is not None: - raw_gains = np.empty(leading_shape, dtype=stream_gain_dtype) - dsgain.read_direct(raw_gains, hslc, dslc) - - # Compressed bytes. Apply our stream selection and load just those - # streams we are keeping for this process. - compressed, local_starts, keep_indices = read_compressed_dataset_slice( - dcomp, proc_keep, raw_starts, raw_nbytes - ) - - # Cut our other arrays to only include the indices selected by the keep mask. - stream_nbytes = select_keep_indices(raw_nbytes, keep_indices) - stream_offsets = select_keep_indices(raw_offsets, keep_indices) - stream_gains = select_keep_indices(raw_gains, keep_indices) - - if local_starts is None: - # This rank has no data after masking - local_shape = None - else: - local_shape = local_starts.shape + (stream_size,) - + if ds_range[1] > ds_range[0]: + # We have some data + leading_shape = (ds_range[1] - ds_range[0],) + global_leading_shape[1:] + local_shape = leading_shape + (stream_size,) + + # The helper datasets all have the same slab definitions + dslc = tuple([slice(0, x) for x in leading_shape]) + hslc = (slice(ds_range[0], ds_range[0] + leading_shape[0]),) + tuple( + [slice(0, x) for x in leading_shape[1:]] + ) + + # If we are using the "keep" array to select streams, slice that + # to cover only data for this process. 
+ if keep is None: + proc_keep = None + else: + proc_keep = keep[hslc] + + # Stream starts + raw_starts = np.empty(leading_shape, dtype=dstarts.dtype) + dstarts.read_direct(raw_starts, hslc, dslc) + + # Stream nbytes + raw_nbytes = np.empty(leading_shape, dtype=dstarts.dtype) + dbytes.read_direct(raw_nbytes, hslc, dslc) + + # Offsets and gains for type conversions + raw_offsets = None + if dsoff is not None: + raw_offsets = np.empty(leading_shape, dtype=stream_off_dtype) + dsoff.read_direct(raw_offsets, hslc, dslc) + raw_gains = None + if dsgain is not None: + raw_gains = np.empty(leading_shape, dtype=stream_gain_dtype) + dsgain.read_direct(raw_gains, hslc, dslc) + + # Compressed bytes. Apply our stream selection and load just those + # streams we are keeping for this process. + compressed, local_starts, keep_indices = read_compressed_dataset_slice( + dcomp, proc_keep, raw_starts, raw_nbytes + ) + + # Cut our other arrays to only include the indices selected by the keep + # mask. + stream_nbytes = select_keep_indices(raw_nbytes, keep_indices) + stream_offsets = select_keep_indices(raw_offsets, keep_indices) + stream_gains = select_keep_indices(raw_gains, keep_indices) + + if local_starts is None: + # This rank has no data after masking + local_shape = None + else: + local_shape = local_starts.shape + (stream_size,) return ( local_shape, global_shape, @@ -356,22 +361,19 @@ def read_array( first_samp = stream_slice.start last_samp = stream_slice.stop - if compressed is None: - arr = None - else: - arr = array_decompress( - compressed, - local_shape[-1], - stream_starts, - stream_nbytes, - stream_offsets=stream_offsets, - stream_gains=stream_gains, - first_stream_sample=first_samp, - last_stream_sample=last_samp, - is_int64=(n_channel == 2), - use_threads=use_threads, - no_flatten=no_flatten, - ) + arr = array_decompress( + compressed, + local_shape[-1], + stream_starts, + stream_nbytes, + stream_offsets=stream_offsets, + stream_gains=stream_gains, + 
first_stream_sample=first_samp, + last_stream_sample=last_samp, + is_int64=(n_channel == 2), + use_threads=use_threads, + no_flatten=no_flatten, + ) if keep_indices: return arr, indices else: diff --git a/src/flacarray/io_common.py b/src/flacarray/io_common.py index e59adc7..4f79b87 100644 --- a/src/flacarray/io_common.py +++ b/src/flacarray/io_common.py @@ -55,7 +55,8 @@ def read_compressed_dataset_slice(dcomp, keep, stream_starts, stream_nbytes): # do multiple reads to fill sections of that buffer. starts, nbytes, indices = keep_select(keep, stream_starts, stream_nbytes) if len(starts) == 0: - return (None, None, None) + # All data cut by keep mask + return (np.zeros(0, dtype=np.uint8), np.zeros(0, dtype=np.int64), None) total_bytes = np.sum(nbytes) rel_starts = np.zeros_like(starts) rel_starts[1:] = np.cumsum(nbytes)[:-1] @@ -259,6 +260,43 @@ def receive_proc_buffers( ) +def initialize_empty_buffers(global_shape, offset_dtype): + """Initialize common arrays and shapes for processes with no data. + + Args: + global_shape (tuple): Global shape of the uncompressed array. + offset_dtype (np.dtype): The dtype of the offset array or None. 
+ + Returns: + (tuple): The empty data and metadata + + """ + leading_shape = global_shape[:-1] + if len(leading_shape[1:]) == 0: + empty_starts_shape = (0,) + else: + empty_starts_shape = (0,) + leading_shape[1:] + local_shape = (0,) + global_shape[1:] + local_starts = np.zeros(empty_starts_shape, dtype=np.int64) + stream_nbytes = np.zeros(empty_starts_shape, dtype=np.int64) + compressed = np.zeros(0, dtype=np.uint8) + stream_offsets = None + stream_gains = None + if offset_dtype is not None: + stream_offsets = np.zeros(empty_starts_shape, dtype=offset_dtype) + stream_gains = np.zeros(empty_starts_shape, dtype=offset_dtype) + keep_indices = None + return ( + local_shape, + local_starts, + stream_nbytes, + compressed, + stream_offsets, + stream_gains, + keep_indices, + ) + + @function_timer def read_send_compressed( reader, global_shape, n_channel, keep=None, mpi_comm=None, mpi_dist=None @@ -297,14 +335,6 @@ def read_send_compressed( global_leading_shape = global_shape[:-1] stream_size = global_shape[-1] - local_shape = None - local_starts = None - stream_nbytes = None - compressed = None - stream_offsets = None - stream_gains = None - keep_indices = None - is_64bit = False if n_channel == 2: is_64bit = True @@ -329,12 +359,26 @@ def read_send_compressed( "Reader offsets / gains are float64, but n_channel != 2" ) + # Initialize to default values for processes with no data. + ( + local_shape, + local_starts, + stream_nbytes, + compressed, + stream_offsets, + stream_gains, + keep_indices, + ) = initialize_empty_buffers(global_shape, reader.stream_off_dtype) + # One process reads and sends. # The rank zero process will read data and send to the other # processes. Keep a handle to the asynchronous send buffers # and delete them after the sends are complete. 
for proc in range(nproc): if rank == 0: + if mpi_dist[proc][0] == mpi_dist[proc][1]: + # No data on this process + continue ( proc_shape, proc_keep, @@ -371,21 +415,25 @@ def read_send_compressed( proc_gains, ) elif proc == rank: - ( - local_shape, - keep_indices, - local_starts, - stream_nbytes, - compressed, - stream_offsets, - stream_gains, - ) = receive_proc_buffers( - comm, - proc, - stream_size, - is_64bit=is_64bit, - offsetgain=offsets_and_gains, - ) + if mpi_dist[proc][0] == mpi_dist[proc][1]: + # No data on this process + continue + else: + ( + local_shape, + keep_indices, + local_starts, + stream_nbytes, + compressed, + stream_offsets, + stream_gains, + ) = receive_proc_buffers( + comm, + proc, + stream_size, + is_64bit=is_64bit, + offsetgain=offsets_and_gains, + ) return ( local_shape, @@ -481,7 +529,11 @@ def receive_write_compressed( # and write it into the global datasets. For each dataset we build # the "slab" (tuple of slices) that we will write from the array # in memory and to the HDF5 dataset. - # + + if global_process_nbytes[proc] == 0: + # This process has no data + continue + # The range of the leading dimension on this process. recv_range = mpi_dist[proc] recv_leading_shape = ( @@ -573,9 +625,8 @@ def receive_write_compressed( del recv elif proc == rank: # We are sending. 
- send_range = mpi_dist[proc] - if send_range[1] - send_range[0] == 0: - # We have no data + if global_process_nbytes[proc] == 0: + # This process has no data continue comm.Send(writer.starts.astype(np.int64), dest=0, tag=tag_starts) comm.Send(writer.nbytes.astype(np.int64), dest=0, tag=tag_nbytes) diff --git a/src/flacarray/mpi.py b/src/flacarray/mpi.py index 2af0547..3e48535 100644 --- a/src/flacarray/mpi.py +++ b/src/flacarray/mpi.py @@ -68,10 +68,6 @@ def distribute_and_verify(mpi_comm, n_elem, mpi_dist=None): raise RuntimeError( "mpi_dist must have contiguous ranges of first, last (exclusive)" ) - if mpi_dist[proc][1] <= mpi_dist[proc][0]: - raise RuntimeError( - f"mpi_dist has no data for process {proc}" - ) # Everything checks out return mpi_dist else: @@ -90,16 +86,19 @@ def distribute_and_verify(mpi_comm, n_elem, mpi_dist=None): return [(x[0], x[-1] + 1) for x in chunks] -def global_array_properties(local_shape, mpi_comm): +def global_array_properties(local_shape, local_dtype, mpi_comm): """Compute various properties of the global data distribution. Given the local data properties on each process and the MPI communicator, - compute various useful quantities for working with global data. + compute various useful quantities for working with global data. If `local_shape` + is None or has a leading zero, then that means that current process has no data + in the global array. This function also verifies that non-leading dimensions match across all processes. Args: local_shape (tuple): The local data shape on each process. + local_dtype (np.dtype): The dtype of the local data. mpi_comm (MPI.Comm): The MPI communicator or None. 
Returns: @@ -111,39 +110,60 @@ def global_array_properties(local_shape, mpi_comm): if len(local_shape) == 1: # Just one stream props["shape"] = (1, local_shape[0]) + props["dtype"] = np.dtype(local_dtype) props["dist"] = [(0, 1)] else: props["shape"] = local_shape + props["dtype"] = np.dtype(local_dtype) props["dist"] = [(0, local_shape[0])] return props + all_dtypes = mpi_comm.gather(local_dtype, root=0) all_shapes = mpi_comm.gather(local_shape, root=0) err = False if mpi_comm.rank == 0: + global_dtype = None + for iproc, proc_dtype in enumerate(all_dtypes): + if global_dtype is None and proc_dtype is not None: + global_dtype = proc_dtype + else: + if proc_dtype is not None and proc_dtype != global_dtype: + msg = f"Process {iproc} has dtype '{proc_dtype}'" + msg += f" instead of '{global_dtype}'" + raise RuntimeError(msg) + props["dtype"] = global_dtype dist = list() - shp = all_shapes[0] - if len(shp) == 1: - lda = 1 - trl = shp - else: - lda = shp[0] - trl = shp[1:] - dist.append((0, lda)) - ldoff = lda - for s in all_shapes[1:]: - if len(s) == 1: - lda += 1 - dist.append((ldoff, ldoff + 1)) - ldoff += 1 - if s != trl: - err = True - break + lda = None + trl = None + ldoff = 0 + for shp in all_shapes: + if shp is None or shp[0] == 0: + # This process has no data + dist.append((ldoff, ldoff)) else: - lda += s[0] - dist.append((ldoff, ldoff + s[0])) - ldoff += s[0] - if s[1:] != trl: - err = True - break + if len(shp) == 1: + if lda is None: + lda = 1 + else: + lda += 1 + dist.append((ldoff, ldoff + 1)) + ldoff += 1 + if trl is None: + trl = shp + elif shp != trl: + err = True + break + else: + if lda is None: + lda = shp[0] + else: + lda += shp[0] + dist.append((ldoff, ldoff + shp[0])) + ldoff += shp[0] + if trl is None: + trl = shp[1:] + elif shp[1:] != trl: + err = True + break props["shape"] = (lda,) + trl props["dist"] = dist err = mpi_comm.bcast(err, root=0) diff --git a/src/flacarray/tests/hdf5.py b/src/flacarray/tests/hdf5.py index 8408908..60ea789 100644 
--- a/src/flacarray/tests/hdf5.py +++ b/src/flacarray/tests/hdf5.py @@ -172,6 +172,89 @@ def test_array_write_read(self): tmpdir.cleanup() del tmpdir + def test_array_write_read_nodata(self): + if not have_hdf5: + print("h5py not available, skipping tests", flush=True) + return + if self.comm is None or self.comm.size < 2: + print("Less than 2 processes, skipping MPI test with empty procs") + return + + rank = self.comm.rank + tmpdir = None + tmppath = None + if rank == 0: + tmpdir = tempfile.TemporaryDirectory() + tmppath = tmpdir.name + if self.comm is not None: + tmppath = self.comm.bcast(tmppath, root=0) + + for local_shape in [(4, 3, 1000), (10000,)]: + shpstr = "x".join([f"{int(x)}" for x in local_shape]) + for dt, dtstr, sigma, quant in [ + (np.dtype(np.int32), "i32", None, None), + (np.dtype(np.int64), "i64", None, None), + (np.dtype(np.float32), "f32", 1.0, 1.0e-7), + (np.dtype(np.float64), "f64", 1.0, 1.0e-15), + ]: + if rank == self.comm.size - 1: + trailing = local_shape[1:] + if len(trailing) == 0: + input_shape = (0,) + else: + input_shape = (0,) + trailing + else: + input_shape = local_shape + input, mpi_dist = create_fake_data( + input_shape, sigma=sigma, dtype=dt, comm=self.comm + ) + flcarr = FlacArray.from_array( + input, quanta=quant, mpi_comm=self.comm, use_threads=True + ) + + filename = os.path.join(tmppath, f"data_{dtstr}_{shpstr}.h5") + with H5File(filename, "w", comm=self.comm) as hf: + flcarr.write_hdf5(hf.handle) + if self.comm is not None: + self.comm.barrier() + with H5File(filename, "r", comm=self.comm) as hf: + check = FlacArray.read_hdf5( + hf.handle, mpi_comm=self.comm, mpi_dist=mpi_dist + ) + + local_fail = check != flcarr + if self.comm is not None: + fail = self.comm.allreduce(local_fail, op=MPI.SUM) + else: + fail = local_fail + + if fail: + print(f"check_{dtstr}_{shpstr}[{rank}] = {check}", flush=True) + print(f"flcarr_{dtstr}_{shpstr}[{rank}] = {flcarr}", flush=True) + print(f"FAIL on {dtstr} FlacArray roundtrip to hdf5", 
flush=True) + self.assertTrue(False) + else: + output = check.to_array(use_threads=True) + if dtstr == "i32" or dtstr == "i64": + local_arr_fail = not np.array_equal(output, input) + else: + local_arr_fail = not np.allclose(output, input, atol=1e-6) + if self.comm is not None: + arr_fail = self.comm.allreduce(local_arr_fail, op=MPI.SUM) + else: + arr_fail = local_arr_fail + if arr_fail: + print(f"output_{dtstr}_{shpstr}[{rank}] = {output}", flush=True) + print(f"input_{dtstr}_{shpstr}[{rank}] = {input}", flush=True) + print(f"FAIL on {dtstr} array roundtrip to hdf5", flush=True) + self.assertTrue(False) + + if self.comm is not None: + self.comm.barrier() + if tmpdir is not None: + tmpdir.cleanup() + del tmpdir + def test_array_keep_dist(self): if not have_hdf5: print("h5py not available, skipping tests", flush=True) diff --git a/src/flacarray/tests/zarr.py b/src/flacarray/tests/zarr.py index 3b2c946..62ead4e 100644 --- a/src/flacarray/tests/zarr.py +++ b/src/flacarray/tests/zarr.py @@ -169,6 +169,89 @@ def test_array_write_read(self): tmpdir.cleanup() del tmpdir + def test_array_write_read_nodata(self): + if not have_zarr: + print("zarr not available, skipping tests", flush=True) + return + if self.comm is None: + rank = 0 + else: + rank = self.comm.rank + + tmpdir = None + tmppath = None + if rank == 0: + tmpdir = tempfile.TemporaryDirectory() + tmppath = tmpdir.name + if self.comm is not None: + tmppath = self.comm.bcast(tmppath, root=0) + + for local_shape in [(4, 3, 1000), (10000,)]: + shpstr = "x".join([f"{int(x)}" for x in local_shape]) + for dt, dtstr, sigma, quant in [ + (np.dtype(np.int32), "i32", None, None), + (np.dtype(np.int64), "i64", None, None), + (np.dtype(np.float32), "f32", 1.0, 1.0e-7), + (np.dtype(np.float64), "f64", 1.0, 1.0e-15), + ]: + if rank == self.comm.size - 1: + trailing = local_shape[1:] + if len(trailing) == 0: + input_shape = (0,) + else: + input_shape = (0,) + trailing + else: + input_shape = local_shape + input, mpi_dist = 
create_fake_data( + input_shape, sigma=sigma, dtype=dt, comm=self.comm + ) + flcarr = FlacArray.from_array( + input, quanta=quant, mpi_comm=self.comm, use_threads=True + ) + + filename = os.path.join(tmppath, f"data_{dtstr}.zarr") + with ZarrGroup(filename, mode="w", comm=self.comm) as zf: + flcarr.write_zarr(zf) + if self.comm is not None: + self.comm.barrier() + with ZarrGroup(filename, mode="r", comm=self.comm) as zf: + check = FlacArray.read_zarr( + zf, mpi_comm=self.comm, mpi_dist=mpi_dist + ) + + local_fail = check != flcarr + if self.comm is not None: + fail = self.comm.allreduce(local_fail, op=MPI.SUM) + else: + fail = local_fail + + if fail: + print(f"check_{dtstr}_{shpstr}[{rank}] = {check}", flush=True) + print(f"flcarr_{dtstr}_{shpstr}[{rank}] = {flcarr}", flush=True) + print(f"FAIL on {dtstr} FlacArray roundtrip to zarr", flush=True) + self.assertTrue(False) + else: + output = check.to_array(use_threads=True) + if dtstr == "i32" or dtstr == "i64": + local_arr_fail = not np.array_equal(output, input) + else: + local_arr_fail = not np.allclose(output, input, atol=1e-6) + if self.comm is not None: + arr_fail = self.comm.allreduce(local_arr_fail, op=MPI.SUM) + else: + arr_fail = local_arr_fail + if arr_fail: + print(f"output_{dtstr}_{shpstr}[{rank}] = {output}", flush=True) + print(f"input_{dtstr}_{shpstr}[{rank}] = {input}", flush=True) + print(f"FAIL on {dtstr} array roundtrip to zarr", flush=True) + self.assertTrue(False) + + if self.comm is not None: + self.comm.barrier() + if tmpdir is not None: + tmpdir.cleanup() + del tmpdir + def test_array_keep_dist(self): if not have_zarr: print("zarr not available, skipping tests", flush=True) diff --git a/src/flacarray/zarr.py b/src/flacarray/zarr.py index 7d564ee..9d39ce6 100644 --- a/src/flacarray/zarr.py +++ b/src/flacarray/zarr.py @@ -21,7 +21,7 @@ have_zarr = False from . 
import __version__ as flacarray_version -from .compress import array_compress +from .compress import array_compress, array_compress_empty from .io_common import receive_write_compressed from .mpi import global_array_properties, global_bytes from .utils import function_timer @@ -345,8 +345,21 @@ def write_array( raise RuntimeError("zarr is not importable, cannot write to zarr.Group") # Get the global shape of the array - global_props = global_array_properties(arr.shape, mpi_comm=mpi_comm) + empty_data = arr is None or arr.shape[0] == 0 + if empty_data and mpi_comm is None: + raise RuntimeError("Local array is None, and MPI is not being used") + + if empty_data: + # No data on this process + arr_shape = None + arr_dtype = None + else: + arr_shape = arr.shape + arr_dtype = arr.dtype + + global_props = global_array_properties(arr_shape, arr_dtype, mpi_comm=mpi_comm) global_shape = global_props["shape"] + dtype = global_props["dtype"] mpi_dist = global_props["dist"] # Get the number of channels @@ -355,22 +368,31 @@ def write_array( else: n_channels = 1 + stream_size = global_shape[-1] + global_leading_shape = global_shape[:-1] + # Compress our local piece of the array - compressed, starts, nbytes, offsets, gains = array_compress( - arr, level=level, quanta=quanta, precision=precision, use_threads=use_threads - ) + if empty_data: + compressed, starts, nbytes, offsets, gains = array_compress_empty( + global_shape, dtype, quanta, precision + ) + if len(global_leading_shape[1:]) == 0: + leading_shape = (0,) + else: + leading_shape = (0,) + global_leading_shape[1:] + else: + compressed, starts, nbytes, offsets, gains = array_compress( + arr, level=level, quanta=quanta, precision=precision, use_threads=use_threads + ) + if len(arr.shape) == 1: + leading_shape = (1,) + else: + leading_shape = arr.shape[:-1] local_nbytes = compressed.nbytes global_nbytes, global_proc_bytes, global_starts = global_bytes( local_nbytes, starts, mpi_comm ) - stream_size = arr.shape[-1] - - if 
len(arr.shape) == 1: - leading_shape = (1,) - else: - leading_shape = arr.shape[:-1] - global_leading_shape = global_shape[:-1] write_compressed( zgrp, diff --git a/src/flacarray/zarr_load_v1.py b/src/flacarray/zarr_load_v1.py index e8ad7d7..53a84be 100644 --- a/src/flacarray/zarr_load_v1.py +++ b/src/flacarray/zarr_load_v1.py @@ -13,7 +13,7 @@ from .decompress import array_decompress from .mpi import distribute_and_verify -from .io_common import read_send_compressed +from .io_common import read_send_compressed, initialize_empty_buffers from .utils import function_timer @@ -282,22 +282,19 @@ def read_array( first_samp = stream_slice.start last_samp = stream_slice.stop - if compressed is None: - arr = None - else: - arr = array_decompress( - compressed, - local_shape[-1], - stream_starts, - stream_nbytes, - stream_offsets=stream_offsets, - stream_gains=stream_gains, - first_stream_sample=first_samp, - last_stream_sample=last_samp, - is_int64=(n_channel == 2), - use_threads=use_threads, - no_flatten=no_flatten, - ) + arr = array_decompress( + compressed, + local_shape[-1], + stream_starts, + stream_nbytes, + stream_offsets=stream_offsets, + stream_gains=stream_gains, + first_stream_sample=first_samp, + last_stream_sample=last_samp, + is_int64=(n_channel == 2), + use_threads=use_threads, + no_flatten=no_flatten, + ) if keep_indices: return arr, indices else: From 1cb480687779903144ad7a4237949a7dad0b9ae8 Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Mon, 29 Jun 2026 12:24:41 -0700 Subject: [PATCH 2/3] Skip tests that require MPI when not running with MPI. 
--- src/flacarray/tests/zarr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/flacarray/tests/zarr.py b/src/flacarray/tests/zarr.py index 62ead4e..b149daf 100644 --- a/src/flacarray/tests/zarr.py +++ b/src/flacarray/tests/zarr.py @@ -173,11 +173,11 @@ def test_array_write_read_nodata(self): if not have_zarr: print("zarr not available, skipping tests", flush=True) return - if self.comm is None: - rank = 0 - else: - rank = self.comm.rank + if self.comm is None or self.comm.size < 2: + print("Less than 2 processes, skipping MPI test with empty procs") + return + rank = self.comm.rank tmpdir = None tmppath = None if rank == 0: From 3cb1df85f28e7b5537a5eb3f4607b42dd74b8aeb Mon Sep 17 00:00:00 2001 From: Theodore Kisner Date: Tue, 30 Jun 2026 07:43:28 -0700 Subject: [PATCH 3/3] Address review feedback --- src/flacarray/array.py | 4 +--- src/flacarray/decompress.py | 2 ++ src/flacarray/hdf5.py | 5 +++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/flacarray/array.py b/src/flacarray/array.py index 55a57be..7b90a74 100644 --- a/src/flacarray/array.py +++ b/src/flacarray/array.py @@ -363,9 +363,7 @@ def _get_leading_axes(self, full_key): if self._flatten_single: # Our array is a single stream with flattened shape. - keep_slice = [ - 0, - ] + keep_slice = [0] else: for axis, axkey in enumerate(full_key[:-1]): if not isinstance(axkey, (int, np.integer)): diff --git a/src/flacarray/decompress.py b/src/flacarray/decompress.py index d0d567c..c3f0550 100644 --- a/src/flacarray/decompress.py +++ b/src/flacarray/decompress.py @@ -78,11 +78,13 @@ def array_decompress_slice( # in the leading shape empty_shape = stream_starts.shape + (stream_size,) if stream_offsets is None: + # This means the output array will contain integer data if is_int64: empty_dtype = np.int64 else: empty_dtype = np.int32 else: + # The output array will contain floating point data. 
if is_int64: empty_dtype = np.float64 else: diff --git a/src/flacarray/hdf5.py b/src/flacarray/hdf5.py index 6e42de2..d01b50f 100644 --- a/src/flacarray/hdf5.py +++ b/src/flacarray/hdf5.py @@ -302,6 +302,11 @@ def write_compressed( ), ) + tuple([slice(0, x) for x in aux_local_shape[1:]]) + # Each call to the HDF5 Dataset.collective context manager creates a + # region of synchronous operations on the dataset. Even processes with + # no data for a given dataset must enter this to avoid a deadlock. + # However, only processes with local data will write. + with dstarts.collective: if len(compressed) > 0: dstarts.write_direct(global_stream_starts, dslc, hslc)