Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 48 additions & 16 deletions src/flacarray/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .compress import array_compress
from .compress import array_compress, array_compress_empty
from .decompress import array_decompress_slice
from .hdf5 import write_compressed as hdf5_write_compressed
from .hdf5 import read_compressed as hdf5_read_compressed
Expand Down Expand Up @@ -93,6 +93,7 @@ def __init__(
if other is not None:
# We are copying an existing object, make sure we have an
# independent copy.
self._empty = copy.deepcopy(other._empty)
self._shape = copy.deepcopy(other._shape)
self._global_shape = copy.deepcopy(other._global_shape)
self._compressed = copy.deepcopy(other._compressed)
Expand All @@ -107,6 +108,9 @@ def __init__(
else:
# This form of constructor is used in the class methods where we
# have already created these arrays for use by this instance.
self._empty = shape is None or shape[0] == 0
if self._empty and mpi_comm is None:
raise RuntimeError("Local data is empty, and MPI is not being used")
self._shape = shape
self._global_shape = global_shape
self._compressed = compressed
Expand All @@ -125,25 +129,33 @@ def _init_params(self):
# stream, this tracks the user intentions about whether to flatten the
# leading dimension. We also track the "local shape", which is the same,
# but which always keeps the leading dimension.
if len(self._shape) == 1:
if self._empty:
# No local data
self._flatten_single = False
self._local_shape = (0,) + self._global_shape[1:]
self._local_nbytes = 0
elif len(self._shape) == 1:
self._flatten_single = True
self._local_shape = (1, self._shape[0])
self._local_nbytes = self._compressed.nbytes
else:
self._flatten_single = False
self._local_shape = self._shape
self._local_nbytes = self._compressed.nbytes

self._local_nbytes = self._compressed.nbytes
(
self._global_nbytes,
self._global_proc_nbytes,
self._global_stream_starts,
) = global_bytes(self._local_nbytes, self._stream_starts, self._mpi_comm)

self._leading_shape = self._local_shape[:-1]
self._global_leading_shape = self._global_shape[:-1]
self._stream_size = self._local_shape[-1]
self._stream_size = self._global_shape[-1]

# For reference, record the type string of the original data.
self._typestr = self._dtype_str(self._dtype)

# Track whether we have 32bit or 64bit data
self._is_int64 = self._dtype == np.dtype(np.int64) or self._dtype == np.dtype(
np.float64
Expand Down Expand Up @@ -351,7 +363,7 @@ def _get_leading_axes(self, full_key):

if self._flatten_single:
# Our array is a single stream with flattened shape.
keep_slice = [0,]
keep_slice = [0]
else:
for axis, axkey in enumerate(full_key[:-1]):
if not isinstance(axkey, (int, np.integer)):
Expand Down Expand Up @@ -398,7 +410,7 @@ def _get_sample_axis(self, full_key):
if stop - start <= 0:
# No samples
return (0, 0, (0,))
return (start, stop, (stop-start,))
return (start, stop, (stop - start,))
elif isinstance(sample_key, (int, np.integer)):
# Just a scalar
return (sample_key, sample_key + 1, ())
Expand Down Expand Up @@ -609,25 +621,45 @@ def from_array(

"""
# Get the global shape of the array
global_props = global_array_properties(arr.shape, mpi_comm=mpi_comm)
empty_data = arr is None or arr.shape[0] == 0
if empty_data and mpi_comm is None:
raise RuntimeError("Local array is None, and MPI is not being used")

if empty_data:
# No data on this process
global_props = global_array_properties(None, None, mpi_comm=mpi_comm)
else:
global_props = global_array_properties(
arr.shape, arr.dtype, mpi_comm=mpi_comm
)

global_shape = global_props["shape"]
dtype = global_props["dtype"]
mpi_dist = global_props["dist"]

# Compress our local piece of the array
compressed, starts, nbytes, offsets, gains = array_compress(
arr,
level=level,
quanta=quanta,
precision=precision,
use_threads=use_threads,
)
if empty_data:
# No data
arr_shape = (0,) + global_shape[1:]
compressed, starts, nbytes, offsets, gains = array_compress_empty(
global_shape, dtype, quanta, precision
)
else:
arr_shape = arr.shape
compressed, starts, nbytes, offsets, gains = array_compress(
arr,
level=level,
quanta=quanta,
precision=precision,
use_threads=use_threads,
)

return FlacArray(
None,
shape=arr.shape,
shape=arr_shape,
global_shape=global_shape,
compressed=compressed,
dtype=arr.dtype,
dtype=dtype,
stream_starts=starts,
stream_nbytes=nbytes,
stream_offsets=offsets,
Expand Down
35 changes: 35 additions & 0 deletions src/flacarray/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,38 @@ def array_compress(arr, level=5, quanta=None, precision=None, use_threads=False)
return (compressed, starts, nbytes, None, None)
else:
raise ValueError(f"Unsupported data type '{arr.dtype}'")


def array_compress_empty(global_shape, dtype, quanta, precision):
    """Mock the compressed data parameters for processes with no data.

    When using MPI, some processes may have no local data. This helper function
    returns the equivalent parameters for those processes which cannot call
    `array_compress()`.

    Args:
        global_shape (tuple): The shape of the global data. Any sequence of
            integers (tuple, list, ...) is accepted.
        dtype (np.dtype): The array dtype, used for the offsets and gains.
        quanta (array): The quanta values (only checked for None)
        precision (array): The precision (only checked for None)

    Returns:
        (tuple): The (compressed bytes, stream starts, stream_nbytes, stream offsets,
            stream gains)

    """
    # The local piece has zero extent along the first axis, keeps any remaining
    # leading dimensions of the global data, and drops the final (sample) axis.
    # When there are no remaining leading dimensions this is simply (0,).  The
    # explicit tuple conversion lets list-like shapes be concatenated too.
    starts_shape = (0,) + tuple(global_shape[1:-1])

    # No streams means no compressed bytes at all.
    compressed = np.zeros(0, dtype=np.uint8)
    starts = np.zeros(starts_shape, dtype=np.int64)
    nbytes = np.zeros(starts_shape, dtype=np.int64)

    # NOTE(review): offsets / gains are None whenever neither quanta nor
    # precision is given, regardless of dtype -- confirm this matches what
    # array_compress() returns on the processes that do have data.
    if quanta is None and precision is None:
        offsets = None
        gains = None
    else:
        offsets = np.zeros(starts_shape, dtype=dtype)
        gains = np.zeros(starts_shape, dtype=dtype)
    return (compressed, starts, nbytes, offsets, gains)
18 changes: 18 additions & 0 deletions src/flacarray/decompress.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ def array_decompress_slice(
(tuple): The (output array, list of stream indices).

"""
if len(compressed) == 0:
# This process has no data. Return an empty array with zero
# in the leading shape
empty_shape = stream_starts.shape + (stream_size,)
if stream_offsets is None:
Comment thread
tskisner marked this conversation as resolved.
# This means the output array will contain integer data
if is_int64:
empty_dtype = np.int64
else:
empty_dtype = np.int32
else:
# The output array will contain floating point data.
if is_int64:
empty_dtype = np.float64
else:
empty_dtype = np.float32
return np.zeros(empty_shape, dtype=empty_dtype), list()

if first_stream_sample is None:
first_stream_sample = -1
if last_stream_sample is None:
Expand Down
18 changes: 10 additions & 8 deletions src/flacarray/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def create_fake_data(
rank = comm.rank

# Get the global array properties
gprops = global_array_properties(local_shape, comm)
gprops = global_array_properties(local_shape, dtype, comm)
shape = gprops["shape"]
dtype = gprops["dtype"]
mpi_dist = gprops["dist"]

flatshape = np.prod(shape)
Expand Down Expand Up @@ -99,18 +100,19 @@ def create_fake_data(
global_data = comm.bcast(global_data, root=0)

# Extract our local piece of the global data
if len(leading_shape) == 0 or (len(leading_shape) == 1 and leading_shape[0] == 1):
data = global_data
local_start = mpi_dist[rank][0]
local_stop = mpi_dist[rank][1]
if local_start == local_stop:
# This process has no data
empty_shape = (0,) + shape[1:]
data = np.zeros(empty_shape, dtype=dtype)
else:
local_start = mpi_dist[rank][0]
local_stop = mpi_dist[rank][1]
local_slice = [slice(local_start, local_stop, 1)]
local_slice.extend([slice(None) for x in shape[1:]])
local_slice = tuple(local_slice)
data = global_data[local_slice]
if len(data.shape) == 2 and data.shape[0] == 1:
data = data.reshape((-1))

if len(data.shape) == 2 and data.shape[0] == 1:
data = data.reshape((-1))
return data, mpi_dist


Expand Down
Loading