Why does the number of dask tasks increase with an "optimized" chunking compared to a "basic" chunking schema?

I'm trying to understand how different chunking schemas can speed up or slow down my computation using xarray and dask.
I have read the dask and xarray guides, but I may have missed something.
I have two stores with the same content but chunked differently.
Both contain a data variable tasmax and the coordinate variables and metadata needed to open it with xarray.
tasmax shape is <xarray.DataArray 'tasmax' (time: 3660, lat: 256, lon: 512)>
The first store, zarr_init, is a zarr store which I made from netCDF files: 10 .nc files, one per year.
When opening it with xarray I get a chunking schema of chunksize=(366, 256, 512), thus 1 year per chunk, same as the initial netCDF storage.
Each chunk is around 191MB.
The second store, zarr_time_opti, is also a zarr store, but it is not chunked along the time dimension.
When I open it with xarray and inspect tasmax, its chunking schema is chunksize=(3660, 114, 115).
Each chunk is around 191MB as well.
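As a quick sanity check on the sizes quoted above, the sketch below recomputes the bytes per chunk for both layouts. It assumes 4-byte (float32) values, which is the item size consistent with the ~191MB figure but is not stated explicitly here; the helper name chunk_megabytes is purely illustrative.

```python
from math import prod

BYTES_PER_VALUE = 4  # assumption: float32 storage


def chunk_megabytes(chunk_shape, itemsize=BYTES_PER_VALUE):
    """Return the uncompressed in-memory size of one chunk in MB (10**6 bytes)."""
    return prod(chunk_shape) * itemsize / 1e6


size_init = chunk_megabytes((366, 256, 512))
size_time_opti = chunk_megabytes((3660, 114, 115))
print(f"zarr_init chunk:      {size_init:.1f} MB")
print(f"zarr_time_opti chunk: {size_time_opti:.1f} MB")
```

Both full-size chunks come out just under 192 MB under that assumption. Note that the edge chunks of zarr_time_opti are smaller, since 256 and 512 are not multiples of 114 and 115.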
Naively, I would expect spatially independent computations to run much faster and to generate far fewer tasks on zarr_time_opti than on zarr_init.
However, I observe the complete opposite:
When running the same computation based on groupby("time.month"), I get 2370 tasks with zarr_time_opti and only 570 tasks with zarr_init. As the MRE below shows, this has nothing to do with zarr itself, since I can reproduce the issue with only xarray and dask.
So my questions are:
What is the mechanism in xarray or dask that creates that many tasks?
Then, what would be the strategy to find the best chunking schema?
    import xarray as xr


    def simple_climate_index(da):
        import time

        time_start = time.perf_counter()
        # computations: monthly anomalies (each value minus its calendar-month mean)
        res = (da.groupby("time.month") - da.groupby("time.month").mean("time")).compute()
        # summer_days = (da > 25).resample(time="MS").sum().compute()
        time_elapsed = time.perf_counter() - time_start
        print(f"wall time: {time_elapsed} secs")


    def mre_so():
        import distributed
        import pandas as pd
        import numpy as np

        # local cluster: a single worker with 4 threads
        client = distributed.Client(memory_limit="16GB", n_workers=1, threads_per_worker=4)
        tasmax = xr.DataArray(
            data=np.empty((3660, 256, 512), dtype=float),
            dims=["time", "lat", "lon"],
            coords=dict(time=pd.date_range("2042-01-01", periods=3660, freq="D")),
            attrs={"units": "degC"},
        )
        # no chunking along time, chunked along lat and lon
        da_optimized = tasmax.copy(deep=True).chunk(dict(time=-1, lat=114, lon=115))
        simple_climate_index(da_optimized)
        # wall time: ~47 secs - 2370 tasks (observed on client)
        # one chunk per 366 time steps, no chunking along lat and lon
        da_init = tasmax.copy(deep=True).chunk(dict(time=366, lat=-1, lon=-1))
        simple_climate_index(da_init)
        # wall time: ~37 secs - 570 tasks (observed on client)


    if __name__ == "__main__":
        mre_so()
zarr_time_opti is obtained by rechunking zarr_init with rechunker, a library to efficiently rewrite to different chunking schemas.
In reality, I'm doing time series analyses by computing (for example) the 90th daily percentile over 30 years on each pixel and then computing the exceedance rate of tasmax compared to this percentile, again on each pixel.
And in this case, using ~100 years I get around 2000 tasks when time is chunked and around 85000 when time is not chunked.

What is the mechanism in xarray or dask that creates that many tasks?
In the case of da_optimized, you are chunking along both the lat and lon dimensions, whereas in da_init you are chunking only along the time dimension.
When you do a compute, in the beginning, each task will correspond to one chunk.
Sidenotes about your specific example:
da_optimized starts with 15 chunks and da_init with 10, which contributes to the lower overall task count for da_init. To balance them, I've changed the chunking to:
da_optimized = tasmax.copy(deep=True).chunk(dict(time=-1, lat=128, lon=103))
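The chunk counts above follow from a ceiling division of each dimension length by its chunk length. A minimal sketch of that arithmetic, using only the standard library (n_chunks is an illustrative helper, not an xarray or dask function; -1 means a single chunk spanning the whole dimension, as in .chunk()):

```python
from math import ceil, prod

SHAPE = {"time": 3660, "lat": 256, "lon": 512}


def n_chunks(chunks, shape=SHAPE):
    """Count the chunks produced by per-dimension chunk lengths (-1 = whole dimension)."""
    return prod(
        1 if chunks[dim] == -1 else ceil(size / chunks[dim])
        for dim, size in shape.items()
    )


counts = {
    "da_init": n_chunks({"time": 366, "lat": -1, "lon": -1}),
    "da_optimized (original)": n_chunks({"time": -1, "lat": 114, "lon": 115}),
    "da_optimized (balanced)": n_chunks({"time": -1, "lat": 128, "lon": 103}),
}
for name, count in counts.items():
    print(f"{name}: {count} chunks")
```

This gives 10, 15 and 10 chunks respectively, so the modified chunking starts both arrays from the same number of chunks.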
While executing, the following warning is raised: PerformanceWarning: Slicing with an out-of-order index is generating 11 times more chunks. So I've simplified the computation in simple_climate_index to:
res = da.groupby("time.month").mean("time").compute()
The best chunking technique would depend on what operation you're doing.
For a groupby operation, commonly seen in pandas, I can see why da_init has fewer tasks, and is faster. All the lat+lon data is conserved within a chunk for any given timestamp. (Moreover, Dask can optimize the number of chunks based on groups in this case. For example, you're grouping-by month, so even if you start with 100 chunks, you'll end up with 12 groups, which can potentially be stored as one-group-per-chunk, so, 12 chunks in total. I'm not sure if xarray actually does this optimization, I'm just saying it's possible.)
In da_optimized, a groupby will require communication between chunks because the lat+lon data are spread across different chunks, which will result in more tasks, and therefore a performance penalty.
Here are the (task) graph visualizations for both operations:
Then, what would be the strategy to find the best chunking schema?
Since you're doing the groupby() on "time", the task graph would be most efficient if you chunk along the same (time) dimension.
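One way to compare candidate chunkings before committing to a rechunk is to build the computation lazily and count the tasks in the resulting dask graph, without calling compute(). The sketch below does this on a small stand-in array. The shape, the chunk sizes and the count_tasks helper are illustrative assumptions, not taken from the question, and the absolute counts depend on the xarray and dask versions and on whether flox is installed, so treat them as relative indicators only.

```python
import dask.array as dsa
import pandas as pd
import xarray as xr


def count_tasks(obj):
    """Number of tasks in the (unoptimized) dask graph backing a lazy xarray object."""
    return len(obj.__dask_graph__())


# Small stand-in for tasmax: 2 years of daily data on a coarse grid (illustrative sizes).
times = pd.date_range("2042-01-01", periods=730, freq="D")
tasmax = xr.DataArray(
    dsa.zeros((730, 16, 32), chunks=(730, 16, 32)),
    dims=["time", "lat", "lon"],
    coords={"time": times},
    name="tasmax",
)

candidates = {
    "chunked along time": tasmax.chunk({"time": 365, "lat": -1, "lon": -1}),
    "chunked along lat/lon": tasmax.chunk({"time": -1, "lat": 8, "lon": 8}),
}

task_counts = {}
for label, da in candidates.items():
    lazy = da.groupby("time.month").mean("time")  # nothing is computed here
    task_counts[label] = count_tasks(lazy)
    print(f"{label}: {da.data.npartitions} chunks, {task_counts[label]} tasks")
```

The graph size is only a proxy: fewer tasks usually means less scheduler overhead, but wall time also depends on chunk size and available memory, so it is worth timing the most promising candidates on a representative subset of the data.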


