Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions xrspatial/hydro/stream_link_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,38 +924,50 @@ def _stream_link_mfd_dask(fractions_da, accum_da, threshold):
_mask_bdry = mask_bdry
_threshold = threshold

# Assemble result via da.block
rows = []
for iy in range(n_tile_y):
row = []
for ix in range(n_tile_x):
y_start = sum(chunks_y[:iy])
y_end = y_start + chunks_y[iy]
x_start = sum(chunks_x[:ix])
x_end = x_start + chunks_x[ix]

frac_chunk = np.asarray(
fractions_da[:, y_start:y_end, x_start:x_end].compute(),
dtype=np.float64)
ac_chunk = _to_numpy_f64(
accum_da[y_start:y_end, x_start:x_end].compute())

sm = np.where(ac_chunk >= _threshold, 1, 0).astype(np.int8)
sm = np.where(np.isnan(ac_chunk), 0, sm).astype(np.int8)
sm = np.where(np.isnan(frac_chunk[0]), 0, sm).astype(np.int8)
_, h, w = frac_chunk.shape

seeds = _compute_link_seeds_mfd(
iy, ix, _boundaries, _frac_bdry, _mask_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)

tile_link = _stream_link_mfd_tile_kernel(
frac_chunk, sm, h, w, *seeds,
row_offsets[iy], col_offsets[ix], total_width)
row.append(da.from_array(tile_link, chunks=tile_link.shape))
rows.append(row)

return da.block(rows)
# Lazy assembly: each tile is recomputed on demand from the converged
# boundary state. Driver memory holds only the small boundary/mask
# snapshots, so peak memory scales with chunk size rather than the full
# grid. ``fractions_da`` is 3D (8, H, W) and cannot align with the 2D
# output via map_blocks, so we map over ``accum_da`` (aligned to the
# fractions' spatial tile grid) and slice the matching fractions strip
# inside the closure. ``chunk-location`` gives the tile index directly.
accum_da = accum_da.rechunk((chunks_y, chunks_x))
cum_y = np.zeros(n_tile_y + 1, dtype=np.int64)
np.cumsum(chunks_y, out=cum_y[1:])
cum_x = np.zeros(n_tile_x + 1, dtype=np.int64)
np.cumsum(chunks_x, out=cum_x[1:])

def _tile(ac_block, block_info=None):
    """Compute the stream-link result for a single output tile.

    Parameters
    ----------
    ac_block : array-like
        The block of ``accum_da`` that ``map_blocks`` hands to this call.
    block_info : dict, optional
        Block metadata supplied by dask; ``block_info[0]['chunk-location']``
        is read as the ``(iy, ix)`` index of the tile.

    Returns
    -------
    The value returned by ``_stream_link_mfd_tile_kernel`` for this tile.
    """
    # The enclosing map_blocks call passes ``meta``, so (per the original
    # author's note) dask does not perform a block_info=None dry run and
    # block_info is always populated when we get here.
    # NOTE(review): the fractions strip is fetched with a nested
    # ``.compute()`` from inside a dask task -- confirm this behaves well
    # under every scheduler this code is expected to run on.
    tile_row, tile_col = block_info[0]['chunk-location']
    rows = slice(int(cum_y[tile_row]), int(cum_y[tile_row + 1]))
    cols = slice(int(cum_x[tile_col]), int(cum_x[tile_col + 1]))

    frac_tile = np.asarray(
        fractions_da[:, rows, cols].compute(), dtype=np.float64)
    accum_tile = _to_numpy_f64(ac_block)

    # Stream cells meet the threshold; cells whose accumulation or first
    # fraction band is NaN are then cleared from the mask.
    stream_mask = np.where(accum_tile >= _threshold, 1, 0).astype(np.int8)
    for nan_source in (accum_tile, frac_tile[0]):
        stream_mask = np.where(
            np.isnan(nan_source), 0, stream_mask).astype(np.int8)
    _, height, width = frac_tile.shape

    link_seeds = _compute_link_seeds_mfd(
        tile_row, tile_col, _boundaries, _frac_bdry, _mask_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)

    return _stream_link_mfd_tile_kernel(
        frac_tile, stream_mask, height, width, *link_seeds,
        row_offsets[tile_row], col_offsets[tile_col], total_width)

return da.map_blocks(
_tile, accum_da,
dtype=np.float64, meta=np.array((), dtype=np.float64),
)


def _stream_link_mfd_dask_cupy(fractions_da, accum_da, threshold):
Expand Down
128 changes: 72 additions & 56 deletions xrspatial/hydro/stream_order_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,34 +1314,43 @@ def _stream_order_mfd_dask_strahler(fractions_da, accum_da, threshold):
_mask_bdry = mask_bdry
_threshold = threshold

# Assemble result by re-running each tile with converged seeds
rows = []
for iy in range(n_tile_y):
row = []
for ix in range(n_tile_x):
y_start = sum(chunks_y[:iy])
y_end = y_start + chunks_y[iy]
x_start = sum(chunks_x[:ix])
x_end = x_start + chunks_x[ix]

frac_chunk = np.asarray(
fractions_da[:, y_start:y_end, x_start:x_end].compute(),
dtype=np.float64)
ac_chunk = _to_numpy_f64(
accum_da[y_start:y_end, x_start:x_end].compute())
sm = _make_stream_mask_mfd_np(ac_chunk, frac_chunk, _threshold)
_, h, w = frac_chunk.shape

seeds = _compute_strahler_seeds_mfd(
iy, ix, _bdry_max, _bdry_cnt, _frac_bdry, _mask_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)

tile_order, _, _ = _strahler_mfd_tile_kernel(
frac_chunk, sm, h, w, *seeds)
row.append(da.from_array(tile_order, chunks=tile_order.shape))
rows.append(row)

return da.block(rows)
# Lazy assembly: re-run each tile on demand from the converged seeds.
# Driver memory holds only the small boundary/mask snapshots, so peak
# memory scales with chunk size rather than the full grid. Map over
# ``accum_da`` (aligned to the fractions' spatial tile grid) and slice
# the matching 3D fractions strip inside the closure.
accum_da = accum_da.rechunk((chunks_y, chunks_x))
cum_y = np.zeros(n_tile_y + 1, dtype=np.int64)
np.cumsum(chunks_y, out=cum_y[1:])
cum_x = np.zeros(n_tile_x + 1, dtype=np.int64)
np.cumsum(chunks_x, out=cum_x[1:])

def _tile(ac_block, block_info=None):
    """Compute the Strahler order result for a single output tile.

    Parameters
    ----------
    ac_block : array-like
        The block of ``accum_da`` that ``map_blocks`` hands to this call.
    block_info : dict, optional
        Block metadata supplied by dask; ``block_info[0]['chunk-location']``
        is read as the ``(iy, ix)`` index of the tile.

    Returns
    -------
    The first of the three values returned by
    ``_strahler_mfd_tile_kernel`` for this tile.
    """
    # The enclosing map_blocks call passes ``meta``, so (per the original
    # author's note) dask does not perform a block_info=None dry run and
    # block_info is always populated when we get here.
    # NOTE(review): the fractions strip is fetched with a nested
    # ``.compute()`` from inside a dask task -- confirm this behaves well
    # under every scheduler this code is expected to run on.
    tile_row, tile_col = block_info[0]['chunk-location']
    rows = slice(int(cum_y[tile_row]), int(cum_y[tile_row + 1]))
    cols = slice(int(cum_x[tile_col]), int(cum_x[tile_col + 1]))

    frac_tile = np.asarray(
        fractions_da[:, rows, cols].compute(), dtype=np.float64)
    accum_tile = _to_numpy_f64(ac_block)
    stream_mask = _make_stream_mask_mfd_np(
        accum_tile, frac_tile, _threshold)
    _, height, width = frac_tile.shape

    strahler_seeds = _compute_strahler_seeds_mfd(
        tile_row, tile_col, _bdry_max, _bdry_cnt, _frac_bdry, _mask_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)

    # Only the order raster is needed; the other two outputs are dropped.
    order_tile, _unused_a, _unused_b = _strahler_mfd_tile_kernel(
        frac_tile, stream_mask, height, width, *strahler_seeds)
    return order_tile

return da.map_blocks(
_tile, accum_da,
dtype=np.float64, meta=np.array((), dtype=np.float64),
)


def _stream_order_mfd_dask_shreve(fractions_da, accum_da, threshold):
Expand Down Expand Up @@ -1384,34 +1393,41 @@ def _stream_order_mfd_dask_shreve(fractions_da, accum_da, threshold):
_mask_bdry = mask_bdry
_threshold = threshold

# Assemble result
rows = []
for iy in range(n_tile_y):
row = []
for ix in range(n_tile_x):
y_start = sum(chunks_y[:iy])
y_end = y_start + chunks_y[iy]
x_start = sum(chunks_x[:ix])
x_end = x_start + chunks_x[ix]

frac_chunk = np.asarray(
fractions_da[:, y_start:y_end, x_start:x_end].compute(),
dtype=np.float64)
ac_chunk = _to_numpy_f64(
accum_da[y_start:y_end, x_start:x_end].compute())
sm = _make_stream_mask_mfd_np(ac_chunk, frac_chunk, _threshold)
_, h, w = frac_chunk.shape

seeds = _compute_shreve_seeds_mfd(
iy, ix, _boundaries, _frac_bdry, _mask_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)

tile_order = _shreve_mfd_tile_kernel(
frac_chunk, sm, h, w, *seeds)
row.append(da.from_array(tile_order, chunks=tile_order.shape))
rows.append(row)

return da.block(rows)
# Lazy assembly: re-run each tile on demand from the converged seeds.
# Driver memory holds only the small boundary/mask snapshots, so peak
# memory scales with chunk size rather than the full grid. Map over
# ``accum_da`` (aligned to the fractions' spatial tile grid) and slice
# the matching 3D fractions strip inside the closure.
accum_da = accum_da.rechunk((chunks_y, chunks_x))
cum_y = np.zeros(n_tile_y + 1, dtype=np.int64)
np.cumsum(chunks_y, out=cum_y[1:])
cum_x = np.zeros(n_tile_x + 1, dtype=np.int64)
np.cumsum(chunks_x, out=cum_x[1:])

def _tile(ac_block, block_info=None):
    """Compute the Shreve order result for a single output tile.

    Parameters
    ----------
    ac_block : array-like
        The block of ``accum_da`` that ``map_blocks`` hands to this call.
    block_info : dict, optional
        Block metadata supplied by dask; ``block_info[0]['chunk-location']``
        is read as the ``(iy, ix)`` index of the tile.

    Returns
    -------
    The value returned by ``_shreve_mfd_tile_kernel`` for this tile.
    """
    # The enclosing map_blocks call passes ``meta``, so (per the original
    # author's note) dask does not perform a block_info=None dry run and
    # block_info is always populated when we get here.
    # NOTE(review): the fractions strip is fetched with a nested
    # ``.compute()`` from inside a dask task -- confirm this behaves well
    # under every scheduler this code is expected to run on.
    tile_row, tile_col = block_info[0]['chunk-location']
    rows = slice(int(cum_y[tile_row]), int(cum_y[tile_row + 1]))
    cols = slice(int(cum_x[tile_col]), int(cum_x[tile_col + 1]))

    frac_tile = np.asarray(
        fractions_da[:, rows, cols].compute(), dtype=np.float64)
    accum_tile = _to_numpy_f64(ac_block)
    stream_mask = _make_stream_mask_mfd_np(
        accum_tile, frac_tile, _threshold)
    _, height, width = frac_tile.shape

    shreve_seeds = _compute_shreve_seeds_mfd(
        tile_row, tile_col, _boundaries, _frac_bdry, _mask_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)

    return _shreve_mfd_tile_kernel(
        frac_tile, stream_mask, height, width, *shreve_seeds)

return da.map_blocks(
_tile, accum_da,
dtype=np.float64, meta=np.array((), dtype=np.float64),
)


# =====================================================================
Expand Down
78 changes: 78 additions & 0 deletions xrspatial/hydro/tests/test_stream_link_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,84 @@ def test_dask_matches_numpy():
np.nan_to_num(dask_result.values, nan=-999))


@dask_array_available
def test_dask_accum_chunk_mismatch():
    """Dask result equals the numpy result when accum chunks differ.

    The fractions are chunked 2x2 spatially while flow_accum is a single
    3x3 chunk, so the dask path has to realign flow_accum onto the
    fractions' tile grid before assembling the output.
    """
    import dask.array as da

    flow_fractions = _make_fractions({
        (0, 0): [(1, 1.0)],
        (0, 2): [(3, 1.0)],
        (1, 1): [(2, 1.0)],
        (2, 1): [],
    }, (3, 3))
    flow_accum = np.array([
        [1.0, 0.0, 1.0],
        [0.0, 3.0, 0.0],
        [0.0, 4.0, 0.0],
    ], dtype=np.float64)

    def _nan_as_sentinel(data_array):
        # NaNs never compare equal, so swap them for a fixed marker.
        return np.nan_to_num(data_array.values, nan=-999)

    # Reference answer from the in-memory numpy backend.
    expected = stream_link_mfd(
        xr.DataArray(flow_fractions, dims=['neighbor', 'y', 'x']),
        create_test_raster(flow_accum),
        threshold=1)

    # Same inputs as dask arrays with deliberately mismatched chunking.
    lazy_fractions = xr.DataArray(
        da.from_array(flow_fractions, chunks=(8, 2, 2)),
        dims=['neighbor', 'y', 'x'])
    lazy_accum = xr.DataArray(
        da.from_array(flow_accum, chunks=(3, 3)),
        dims=['y', 'x'])
    actual = stream_link_mfd(lazy_fractions, lazy_accum, threshold=1)

    np.testing.assert_array_equal(
        _nan_as_sentinel(expected), _nan_as_sentinel(actual))


@dask_array_available
def test_dask_assembly_is_lazy(monkeypatch):
    """Output tiles must be produced at compute time, not up front (#2885).

    The tile kernel is wrapped with a call counter; after the public call
    returns, triggering ``compute()`` must cause further kernel calls.
    """
    import importlib
    import dask.array as da
    link_module = importlib.import_module('xrspatial.hydro.stream_link_mfd')

    kernel_calls = 0
    real_kernel = link_module._stream_link_mfd_tile_kernel

    def _counting_kernel(*args, **kwargs):
        nonlocal kernel_calls
        kernel_calls += 1
        return real_kernel(*args, **kwargs)

    monkeypatch.setattr(
        link_module, '_stream_link_mfd_tile_kernel', _counting_kernel)

    flow_fractions = _make_fractions({
        (0, 0): [(1, 1.0)],
        (0, 2): [(3, 1.0)],
        (1, 1): [(2, 1.0)],
        (2, 1): [],
    }, (3, 3))
    flow_accum = np.array([
        [1.0, 0.0, 1.0],
        [0.0, 3.0, 0.0],
        [0.0, 4.0, 0.0],
    ], dtype=np.float64)

    lazy_fractions = xr.DataArray(
        da.from_array(flow_fractions, chunks=(8, 2, 2)),
        dims=['neighbor', 'y', 'x'])
    lazy_accum = xr.DataArray(
        da.from_array(flow_accum, chunks=(2, 2)),
        dims=['y', 'x'])

    lazy_result = stream_link_mfd(lazy_fractions, lazy_accum, threshold=1)
    # Calls made so far belong to the eager convergence sweep; anything
    # beyond this baseline is triggered by materialising the output.
    baseline_calls = kernel_calls
    lazy_result.data.compute()
    assert kernel_calls - baseline_calls > 0


# ====================================================================
# Memory guard tests
# ====================================================================
Expand Down
79 changes: 79 additions & 0 deletions xrspatial/hydro/tests/test_stream_order_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,85 @@ def test_dask_matches_numpy():
np.nan_to_num(dask_result.values, nan=-999))


# Shared 3x3 fixture for the dask tests defined right after it
# (``test_dask_accum_chunk_mismatch`` and ``test_dask_assembly_is_lazy``).
# NOTE(review): keys look like (row, col) cells and each value a list of
# (neighbor index, fraction) pairs -- confirm against ``_make_fractions``.
_LAZY_FRACS = _make_fractions({
(0, 0): [(1, 1.0)],
(0, 1): [],
(0, 2): [(3, 1.0)],
(1, 0): [],
(1, 1): [(2, 1.0)],
(1, 2): [],
(2, 0): [(0, 1.0)],
(2, 1): [],
(2, 2): [],
}, (3, 3))
# Flow-accumulation grid paired with ``_LAZY_FRACS``; the tests that use it
# call ``stream_order_mfd`` with ``threshold=1``.
_LAZY_ACCUM = np.array([
[1.0, 1.0, 1.0],
[1.0, 3.0, 1.0],
[1.0, 5.0, 1.0],
], dtype=np.float64)


@dask_array_available
@pytest.mark.parametrize('method', ['strahler', 'shreve'])
def test_dask_accum_chunk_mismatch(method):
    """Dask result equals the numpy result when accum chunks differ.

    The fractions are chunked 2x2 spatially while flow_accum is a single
    3x3 chunk, so the dask path has to realign flow_accum onto the
    fractions' tile grid before assembling the output.
    """
    import dask.array as da

    def _nan_as_sentinel(data_array):
        # NaNs never compare equal, so swap them for a fixed marker.
        return np.nan_to_num(data_array.values, nan=-999)

    # Reference answer from the in-memory numpy backend.
    expected = stream_order_mfd(
        xr.DataArray(_LAZY_FRACS, dims=['neighbor', 'y', 'x']),
        create_test_raster(_LAZY_ACCUM),
        threshold=1,
        method=method)

    # Same inputs as dask arrays with deliberately mismatched chunking.
    lazy_fractions = xr.DataArray(
        da.from_array(_LAZY_FRACS, chunks=(8, 2, 2)),
        dims=['neighbor', 'y', 'x'])
    lazy_accum = xr.DataArray(
        da.from_array(_LAZY_ACCUM, chunks=(3, 3)),
        dims=['y', 'x'])
    actual = stream_order_mfd(
        lazy_fractions, lazy_accum, threshold=1, method=method)

    np.testing.assert_array_equal(
        _nan_as_sentinel(expected), _nan_as_sentinel(actual))


@dask_array_available
@pytest.mark.parametrize('method,kernel', [
    ('strahler', '_strahler_mfd_tile_kernel'),
    ('shreve', '_shreve_mfd_tile_kernel'),
])
def test_dask_assembly_is_lazy(monkeypatch, method, kernel):
    """Output tiles must be produced at compute time, not up front (#2885).

    The tile kernel named by ``kernel`` is wrapped with a call counter;
    after the public call returns, triggering ``compute()`` must cause
    further kernel calls.
    """
    import importlib
    import dask.array as da
    order_module = importlib.import_module('xrspatial.hydro.stream_order_mfd')

    kernel_calls = 0
    real_kernel = getattr(order_module, kernel)

    def _counting_kernel(*args, **kwargs):
        nonlocal kernel_calls
        kernel_calls += 1
        return real_kernel(*args, **kwargs)

    monkeypatch.setattr(order_module, kernel, _counting_kernel)

    lazy_fractions = xr.DataArray(
        da.from_array(_LAZY_FRACS, chunks=(8, 2, 2)),
        dims=['neighbor', 'y', 'x'])
    lazy_accum = xr.DataArray(
        da.from_array(_LAZY_ACCUM, chunks=(2, 2)),
        dims=['y', 'x'])

    lazy_result = stream_order_mfd(
        lazy_fractions, lazy_accum, threshold=1, method=method)
    # Calls made so far belong to the eager convergence sweep; anything
    # beyond this baseline is triggered by materialising the output.
    baseline_calls = kernel_calls
    lazy_result.data.compute()
    assert kernel_calls - baseline_calls > 0


# ====================================================================
# Memory guard tests
# ====================================================================
Expand Down
Loading