TNO Intern

Commit eac27672 authored by Florian Knappers's avatar Florian Knappers
Browse files

add dask unit tests

parent 20e3c6dd
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ from pathlib import Path

RESOURCES_FOLDER = Path(__file__).parent / "resources"
THERMOGIS_JAR_PATH = RESOURCES_FOLDER / "thermogis-1.7.0-pythermogis.jar"

JVM17_PATH = RESOURCES_FOLDER / "JVM17"

os_name = platform.system()

tests/test_dask.py

0 → 100644
+100 −0
Original line number Diff line number Diff line
import pytest
import numpy as np
import xarray as xr

from pythermogis.dask import auto_chunk_dataset


@pytest.fixture
def small_dataarray():
    """A 10×10 DataArray holding 100 elements in total."""
    values = np.arange(100).reshape(10, 10)
    return xr.DataArray(values, dims=["x", "y"])


@pytest.fixture
def large_dataarray():
    """A 100×100×100 DataArray holding 1,000,000 elements in total."""
    shape = (100, 100, 100)
    return xr.DataArray(np.random.rand(*shape), dims=["x", "y", "z"])


@pytest.fixture
def small_dataset():
    """A Dataset with two random variables on a shared 20×30 grid."""
    variables = {
        "temp": (["x", "y"], np.random.rand(20, 30)),
        "pressure": (["x", "y"], np.random.rand(20, 30)),
    }
    return xr.Dataset(variables)


@pytest.fixture
def large_dataset():
    """A Dataset with two random variables on a shared 50×50×50 grid."""
    variables = {
        "a": (["x", "y", "z"], np.random.rand(50, 50, 50)),
        "b": (["x", "y", "z"], np.random.rand(50, 50, 50)),
    }
    return xr.Dataset(variables)


def _max_chunk_size(chunked: xr.DataArray | xr.Dataset) -> int:
    """Return the number of elements in the largest chunk of *chunked*.

    Parameters
    ----------
    chunked:
        A dask-backed DataArray or Dataset.  For a Dataset the chunks of
        the first data variable are inspected.

    Raises
    ------
    ValueError
        If *chunked* is not dask-backed (``.chunks`` is ``None``), which
        would otherwise surface as a cryptic ``TypeError`` from ``max``.
    """
    if isinstance(chunked, xr.Dataset):
        # Datasets report chunks per variable; use the first one.
        var = next(iter(chunked.data_vars))
        chunks = chunked[var].chunks
    else:
        chunks = chunked.chunks
    if chunks is None:
        raise ValueError("input is not chunked (no dask backing)")
    # Multiply the per-dimension maxima: an upper bound equal to the
    # element count of the largest chunk.
    return int(np.prod([max(c) for c in chunks]))


def test_no_chunking_when_total_le_target_dataarray(small_dataarray):
    """A DataArray below the budget is returned as the very same object."""
    out = auto_chunk_dataset(small_dataarray, target_chunk_size=200)
    assert out is small_dataarray


def test_no_chunking_when_total_eq_target_dataarray(small_dataarray):
    """A DataArray exactly at the budget is returned unchanged."""
    out = auto_chunk_dataset(small_dataarray, target_chunk_size=100)
    assert out is small_dataarray


def test_no_chunking_when_total_le_target_dataset(small_dataset):
    """A Dataset below the budget is returned as the very same object."""
    out = auto_chunk_dataset(small_dataset, target_chunk_size=10_000)
    assert out is small_dataset


def test_chunk_size_within_target_dataarray(large_dataarray):
    """Every chunk of a large DataArray must fit the element budget."""
    budget = 10_000
    chunked = auto_chunk_dataset(large_dataarray, target_chunk_size=budget)
    assert _max_chunk_size(chunked) <= budget


def test_chunk_size_within_target_dataset(large_dataset):
    """Every chunk of a large Dataset must fit the element budget."""
    budget = 5_000
    chunked = auto_chunk_dataset(large_dataset, target_chunk_size=budget)
    assert _max_chunk_size(chunked) <= budget


def test_chunk_size_reasonable_lower_bound(large_dataarray):
    """Chunks should not be reduced far below the budget (> target/8)."""
    budget = 10_000
    chunked = auto_chunk_dataset(large_dataarray, target_chunk_size=budget)
    # Greedy halving may undershoot, but never by more than a factor of 8.
    assert _max_chunk_size(chunked) >= budget // 8


# Edge cases
def test_very_small_target_does_not_crash(large_dataarray):
    """A budget of one element degenerates to single-element chunks."""
    chunked = auto_chunk_dataset(large_dataarray, target_chunk_size=1)
    assert _max_chunk_size(chunked) == 1


def test_very_small_target_dataset(large_dataset):
    """A budget of one element also works for Datasets."""
    chunked = auto_chunk_dataset(large_dataset, target_chunk_size=1)
    assert _max_chunk_size(chunked) == 1


def test_1d_dataarray():
    """One-dimensional input is chunked along its only axis."""
    series = xr.DataArray(np.arange(1000), dims=["x"])
    chunked = auto_chunk_dataset(series, target_chunk_size=100)
    assert _max_chunk_size(chunked) <= 100
 No newline at end of file