diff --git a/xrspatial/hydro/stream_link_mfd.py b/xrspatial/hydro/stream_link_mfd.py index a32e65eb2..63947d08c 100644 --- a/xrspatial/hydro/stream_link_mfd.py +++ b/xrspatial/hydro/stream_link_mfd.py @@ -924,38 +924,50 @@ def _stream_link_mfd_dask(fractions_da, accum_da, threshold): _mask_bdry = mask_bdry _threshold = threshold - # Assemble result via da.block - rows = [] - for iy in range(n_tile_y): - row = [] - for ix in range(n_tile_x): - y_start = sum(chunks_y[:iy]) - y_end = y_start + chunks_y[iy] - x_start = sum(chunks_x[:ix]) - x_end = x_start + chunks_x[ix] - - frac_chunk = np.asarray( - fractions_da[:, y_start:y_end, x_start:x_end].compute(), - dtype=np.float64) - ac_chunk = _to_numpy_f64( - accum_da[y_start:y_end, x_start:x_end].compute()) - - sm = np.where(ac_chunk >= _threshold, 1, 0).astype(np.int8) - sm = np.where(np.isnan(ac_chunk), 0, sm).astype(np.int8) - sm = np.where(np.isnan(frac_chunk[0]), 0, sm).astype(np.int8) - _, h, w = frac_chunk.shape - - seeds = _compute_link_seeds_mfd( - iy, ix, _boundaries, _frac_bdry, _mask_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x) - - tile_link = _stream_link_mfd_tile_kernel( - frac_chunk, sm, h, w, *seeds, - row_offsets[iy], col_offsets[ix], total_width) - row.append(da.from_array(tile_link, chunks=tile_link.shape)) - rows.append(row) - - return da.block(rows) + # Lazy assembly: each tile is recomputed on demand from the converged + # boundary state. Driver memory holds only the small boundary/mask + # snapshots, so peak memory scales with chunk size rather than the full + # grid. ``fractions_da`` is 3D (8, H, W) and cannot align with the 2D + # output via map_blocks, so we map over ``accum_da`` (aligned to the + # fractions' spatial tile grid) and slice the matching fractions strip + # inside the closure. ``chunk-location`` gives the tile index directly. + accum_da = accum_da.rechunk((chunks_y, chunks_x)) + cum_y = np.zeros(n_tile_y + 1, dtype=np.int64) + np.cumsum(chunks_y, out=cum_y[1:]) + cum_x = np.zeros(n_tile_x + 1, dtype=np.int64) + np.cumsum(chunks_x, out=cum_x[1:]) + + def _tile(ac_block, block_info=None): + # ``meta`` is passed to map_blocks below, so dask never runs a + # block_info=None dry-run; block_info is always populated here. + iy, ix = block_info[0]['chunk-location'] + y_start = int(cum_y[iy]) + y_end = int(cum_y[iy + 1]) + x_start = int(cum_x[ix]) + x_end = int(cum_x[ix + 1]) + + frac_chunk = np.asarray( + fractions_da[:, y_start:y_end, x_start:x_end].compute(), + dtype=np.float64) + ac_chunk = _to_numpy_f64(ac_block) + + sm = np.where(ac_chunk >= _threshold, 1, 0).astype(np.int8) + sm = np.where(np.isnan(ac_chunk), 0, sm).astype(np.int8) + sm = np.where(np.isnan(frac_chunk[0]), 0, sm).astype(np.int8) + _, h, w = frac_chunk.shape + + seeds = _compute_link_seeds_mfd( + iy, ix, _boundaries, _frac_bdry, _mask_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x) + + return _stream_link_mfd_tile_kernel( + frac_chunk, sm, h, w, *seeds, + row_offsets[iy], col_offsets[ix], total_width) + + return da.map_blocks( + _tile, accum_da, + dtype=np.float64, meta=np.array((), dtype=np.float64), + ) def _stream_link_mfd_dask_cupy(fractions_da, accum_da, threshold): diff --git a/xrspatial/hydro/stream_order_mfd.py b/xrspatial/hydro/stream_order_mfd.py index ae4f54ba4..605a84e0f 100644 --- a/xrspatial/hydro/stream_order_mfd.py +++ b/xrspatial/hydro/stream_order_mfd.py @@ -1314,34 +1314,43 @@ def _stream_order_mfd_dask_strahler(fractions_da, accum_da, threshold): _mask_bdry = mask_bdry _threshold = threshold - # Assemble result by re-running each tile with converged seeds - rows = [] - for iy in range(n_tile_y): - row = [] - for ix in range(n_tile_x): - y_start = sum(chunks_y[:iy]) - y_end = y_start + chunks_y[iy] - x_start = sum(chunks_x[:ix]) - x_end = x_start + chunks_x[ix] - - frac_chunk = np.asarray( - fractions_da[:, y_start:y_end, x_start:x_end].compute(), - dtype=np.float64) - ac_chunk = _to_numpy_f64( - accum_da[y_start:y_end, x_start:x_end].compute()) - sm = _make_stream_mask_mfd_np(ac_chunk, frac_chunk, _threshold) - _, h, w = frac_chunk.shape - - seeds = _compute_strahler_seeds_mfd( - iy, ix, _bdry_max, _bdry_cnt, _frac_bdry, _mask_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x) - - tile_order, _, _ = _strahler_mfd_tile_kernel( - frac_chunk, sm, h, w, *seeds) - row.append(da.from_array(tile_order, chunks=tile_order.shape)) - rows.append(row) - - return da.block(rows) + # Lazy assembly: re-run each tile on demand from the converged seeds. + # Driver memory holds only the small boundary/mask snapshots, so peak + # memory scales with chunk size rather than the full grid. Map over + # ``accum_da`` (aligned to the fractions' spatial tile grid) and slice + # the matching 3D fractions strip inside the closure. + accum_da = accum_da.rechunk((chunks_y, chunks_x)) + cum_y = np.zeros(n_tile_y + 1, dtype=np.int64) + np.cumsum(chunks_y, out=cum_y[1:]) + cum_x = np.zeros(n_tile_x + 1, dtype=np.int64) + np.cumsum(chunks_x, out=cum_x[1:]) + + def _tile(ac_block, block_info=None): + # ``meta`` is passed to map_blocks below, so dask never runs a + # block_info=None dry-run; block_info is always populated here. + iy, ix = block_info[0]['chunk-location'] + y_start, y_end = int(cum_y[iy]), int(cum_y[iy + 1]) + x_start, x_end = int(cum_x[ix]), int(cum_x[ix + 1]) + + frac_chunk = np.asarray( + fractions_da[:, y_start:y_end, x_start:x_end].compute(), + dtype=np.float64) + ac_chunk = _to_numpy_f64(ac_block) + sm = _make_stream_mask_mfd_np(ac_chunk, frac_chunk, _threshold) + _, h, w = frac_chunk.shape + + seeds = _compute_strahler_seeds_mfd( + iy, ix, _bdry_max, _bdry_cnt, _frac_bdry, _mask_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x) + + tile_order, _, _ = _strahler_mfd_tile_kernel( + frac_chunk, sm, h, w, *seeds) + return tile_order + + return da.map_blocks( + _tile, accum_da, + dtype=np.float64, meta=np.array((), dtype=np.float64), + ) def _stream_order_mfd_dask_shreve(fractions_da, accum_da, threshold): @@ -1384,34 +1393,41 @@ def _stream_order_mfd_dask_shreve(fractions_da, accum_da, threshold): _mask_bdry = mask_bdry _threshold = threshold - # Assemble result - rows = [] - for iy in range(n_tile_y): - row = [] - for ix in range(n_tile_x): - y_start = sum(chunks_y[:iy]) - y_end = y_start + chunks_y[iy] - x_start = sum(chunks_x[:ix]) - x_end = x_start + chunks_x[ix] - - frac_chunk = np.asarray( - fractions_da[:, y_start:y_end, x_start:x_end].compute(), - dtype=np.float64) - ac_chunk = _to_numpy_f64( - accum_da[y_start:y_end, x_start:x_end].compute()) - sm = _make_stream_mask_mfd_np(ac_chunk, frac_chunk, _threshold) - _, h, w = frac_chunk.shape - - seeds = _compute_shreve_seeds_mfd( - iy, ix, _boundaries, _frac_bdry, _mask_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x) - - tile_order = _shreve_mfd_tile_kernel( - frac_chunk, sm, h, w, *seeds) - row.append(da.from_array(tile_order, chunks=tile_order.shape)) - rows.append(row) - - return da.block(rows) + # Lazy assembly: re-run each tile on demand from the converged seeds. + # Driver memory holds only the small boundary/mask snapshots, so peak + # memory scales with chunk size rather than the full grid. Map over + # ``accum_da`` (aligned to the fractions' spatial tile grid) and slice + # the matching 3D fractions strip inside the closure. + accum_da = accum_da.rechunk((chunks_y, chunks_x)) + cum_y = np.zeros(n_tile_y + 1, dtype=np.int64) + np.cumsum(chunks_y, out=cum_y[1:]) + cum_x = np.zeros(n_tile_x + 1, dtype=np.int64) + np.cumsum(chunks_x, out=cum_x[1:]) + + def _tile(ac_block, block_info=None): + # ``meta`` is passed to map_blocks below, so dask never runs a + # block_info=None dry-run; block_info is always populated here. + iy, ix = block_info[0]['chunk-location'] + y_start, y_end = int(cum_y[iy]), int(cum_y[iy + 1]) + x_start, x_end = int(cum_x[ix]), int(cum_x[ix + 1]) + + frac_chunk = np.asarray( + fractions_da[:, y_start:y_end, x_start:x_end].compute(), + dtype=np.float64) + ac_chunk = _to_numpy_f64(ac_block) + sm = _make_stream_mask_mfd_np(ac_chunk, frac_chunk, _threshold) + _, h, w = frac_chunk.shape + + seeds = _compute_shreve_seeds_mfd( + iy, ix, _boundaries, _frac_bdry, _mask_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x) + + return _shreve_mfd_tile_kernel(frac_chunk, sm, h, w, *seeds) + + return da.map_blocks( + _tile, accum_da, + dtype=np.float64, meta=np.array((), dtype=np.float64), + ) # ===================================================================== diff --git a/xrspatial/hydro/tests/test_stream_link_mfd.py b/xrspatial/hydro/tests/test_stream_link_mfd.py index 933c24a6f..fe109247d 100644 --- a/xrspatial/hydro/tests/test_stream_link_mfd.py +++ b/xrspatial/hydro/tests/test_stream_link_mfd.py @@ -183,6 +183,84 @@ def test_dask_matches_numpy(): np.nan_to_num(dask_result.values, nan=-999)) +@dask_array_available +def test_dask_accum_chunk_mismatch(): + """flow_accum chunked differently from fractions still matches numpy.""" + import dask.array as da + + fracs = _make_fractions({ + (0, 0): [(1, 1.0)], + (0, 2): [(3, 1.0)], + (1, 1): [(2, 1.0)], + (2, 1): [], + }, (3, 3)) + accum = np.array([ + [1.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + [0.0, 4.0, 0.0], + ], dtype=np.float64) + + frac_np = xr.DataArray(fracs, dims=['neighbor', 'y', 'x']) + fa_np = create_test_raster(accum) + np_result = stream_link_mfd(frac_np, fa_np, threshold=1) + + frac_dask = xr.DataArray( + da.from_array(fracs, chunks=(8, 2, 2)), + dims=['neighbor', 'y', 'x']) + # flow_accum chunked 3x3 while fractions are 2x2 -- the lazy assembly + # must realign it onto the fractions' tile grid. + fa_dask = xr.DataArray( + da.from_array(accum, chunks=(3, 3)), + dims=['y', 'x']) + dask_result = stream_link_mfd(frac_dask, fa_dask, threshold=1) + + np.testing.assert_array_equal( + np.nan_to_num(np_result.values, nan=-999), + np.nan_to_num(dask_result.values, nan=-999)) + + +@dask_array_available +def test_dask_assembly_is_lazy(monkeypatch): + """Building the output raster must be deferred to compute time (#2885).""" + import importlib + import dask.array as da + mod = importlib.import_module('xrspatial.hydro.stream_link_mfd') + + counter = {'n': 0} + orig = mod._stream_link_mfd_tile_kernel + + def _spy(*args, **kwargs): + counter['n'] += 1 + return orig(*args, **kwargs) + + monkeypatch.setattr(mod, '_stream_link_mfd_tile_kernel', _spy) + + fracs = _make_fractions({ + (0, 0): [(1, 1.0)], + (0, 2): [(3, 1.0)], + (1, 1): [(2, 1.0)], + (2, 1): [], + }, (3, 3)) + accum = np.array([ + [1.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + [0.0, 4.0, 0.0], + ], dtype=np.float64) + + frac_dask = xr.DataArray( + da.from_array(fracs, chunks=(8, 2, 2)), + dims=['neighbor', 'y', 'x']) + fa_dask = xr.DataArray( + da.from_array(accum, chunks=(2, 2)), + dims=['y', 'x']) + + result = stream_link_mfd(frac_dask, fa_dask, threshold=1) + # The convergence sweep runs eagerly, but assembling the result must not. + calls_after_call = counter['n'] + result.data.compute() + assert counter['n'] - calls_after_call > 0 + + # ==================================================================== # Memory guard tests # ==================================================================== diff --git a/xrspatial/hydro/tests/test_stream_order_mfd.py b/xrspatial/hydro/tests/test_stream_order_mfd.py index ebeddd4ab..ec7e898fd 100644 --- a/xrspatial/hydro/tests/test_stream_order_mfd.py +++ b/xrspatial/hydro/tests/test_stream_order_mfd.py @@ -238,6 +238,85 @@ def test_dask_matches_numpy(): np.nan_to_num(dask_result.values, nan=-999)) +_LAZY_FRACS = _make_fractions({ + (0, 0): [(1, 1.0)], + (0, 1): [], + (0, 2): [(3, 1.0)], + (1, 0): [], + (1, 1): [(2, 1.0)], + (1, 2): [], + (2, 0): [(0, 1.0)], + (2, 1): [], + (2, 2): [], +}, (3, 3)) +_LAZY_ACCUM = np.array([ + [1.0, 1.0, 1.0], + [1.0, 3.0, 1.0], + [1.0, 5.0, 1.0], +], dtype=np.float64) + + +@dask_array_available +@pytest.mark.parametrize('method', ['strahler', 'shreve']) +def test_dask_accum_chunk_mismatch(method): + """flow_accum chunked differently from fractions still matches numpy.""" + import dask.array as da + + frac_da_np = xr.DataArray(_LAZY_FRACS, dims=['neighbor', 'y', 'x']) + fa_da_np = create_test_raster(_LAZY_ACCUM) + np_result = stream_order_mfd(frac_da_np, fa_da_np, threshold=1, + method=method) + + frac_dask = xr.DataArray( + da.from_array(_LAZY_FRACS, chunks=(8, 2, 2)), + dims=['neighbor', 'y', 'x']) + # flow_accum chunked 3x3 while fractions are 2x2 -- the lazy assembly + # must realign it onto the fractions' tile grid. + fa_dask = xr.DataArray( + da.from_array(_LAZY_ACCUM, chunks=(3, 3)), + dims=['y', 'x']) + dask_result = stream_order_mfd(frac_dask, fa_dask, threshold=1, + method=method) + + np.testing.assert_array_equal( + np.nan_to_num(np_result.values, nan=-999), + np.nan_to_num(dask_result.values, nan=-999)) + + +@dask_array_available +@pytest.mark.parametrize('method,kernel', [ + ('strahler', '_strahler_mfd_tile_kernel'), + ('shreve', '_shreve_mfd_tile_kernel'), +]) +def test_dask_assembly_is_lazy(monkeypatch, method, kernel): + """Building the output raster must be deferred to compute time (#2885).""" + import importlib + import dask.array as da + mod = importlib.import_module('xrspatial.hydro.stream_order_mfd') + + counter = {'n': 0} + orig = getattr(mod, kernel) + + def _spy(*args, **kwargs): + counter['n'] += 1 + return orig(*args, **kwargs) + + monkeypatch.setattr(mod, kernel, _spy) + + frac_dask = xr.DataArray( + da.from_array(_LAZY_FRACS, chunks=(8, 2, 2)), + dims=['neighbor', 'y', 'x']) + fa_dask = xr.DataArray( + da.from_array(_LAZY_ACCUM, chunks=(2, 2)), + dims=['y', 'x']) + + result = stream_order_mfd(frac_dask, fa_dask, threshold=1, method=method) + # The convergence sweep runs eagerly, but assembling the result must not. + calls_after_call = counter['n'] + result.data.compute() + assert counter['n'] - calls_after_call > 0 + + # ==================================================================== # Memory guard tests # ====================================================================