Part III. xbatcher#

Author: Eli Holmes (NOAA)


What is xbatcher?#

xbatcher is a Python library that efficiently generates mini-batches from large xarray datasets. It allows you to stream spatial and temporal patches directly from Dask-backed data into machine learning models without needing to load everything into memory.

When we work with data in a cloud bucket, we want to understand how the data are chunked so that we minimize the I/O load. Ideally, the data are chunked (into files) in such a way that we do not need to access many files. Preferably, we do as little I/O (file reading) as we can while still not running out of memory.

Key Elements of Using xbatcher#

  • Batch Generator to get little xarrays

  • Function to turn the little xarrays into numpy arrays

  • Link to TensorFlow

1. A Batch Generator

You create a generator that yields individual samples (as little xarray Datasets) from the full dataset. It needs to know the size of batches and overlap.

Example:
input_dims = {"time": 20, "lat": 64, "lon": 64}
Each sample will have shape (20, 64, 64) before stacking variables. Unfortunately, I chunked the Zarr file in 100-day chunks, so using 20-day samples in xbatcher means I have to download a lot more data (80 extra days) than I want, but my validation data had 37 days so I am limited by that.

input_overlap controls the amount of overlap between samples along each dimension. It is often set to 0 (as I do here), but including overlap can help reduce edge effects in CNNs.

from xbatcher import BatchGenerator

train_gen = BatchGenerator(
    dataset[input_vars + output_vars],
    input_dims=input_dims,
    input_overlap=input_overlap
)
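To see how input_dims and input_overlap translate into a number of samples, the sketch below counts windows along one dimension. It assumes that consecutive windows start `window - overlap` apart and that an incomplete window at the end is dropped; that is my understanding of how xbatcher slices, not something stated above. `count_windows` is a hypothetical helper, not part of xbatcher, and the sizes are illustrative.

```python
def count_windows(dim_size, window, overlap=0):
    """Count full windows of length `window` along one dimension.

    Assumes windows start `window - overlap` apart and that a trailing
    partial window is dropped (an assumption about xbatcher's slicing).
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    if dim_size < window:
        return 0
    stride = window - overlap
    return (dim_size - window) // stride + 1


# 256 time steps cut into 20-step samples with no overlap
n_time = count_windows(256, 20)
# a 149 x 181 grid cut into 64 x 64 patches with no overlap
n_lat = count_windows(149, 64)
n_lon = count_windows(181, 64)
# the same lon dimension with an overlap of 32 grid cells
n_lon_overlap = count_windows(181, 64, overlap=32)
print(n_time, n_lat, n_lon, n_lon_overlap)
```

Under these assumptions, 64 x 64 patches with no overlap leave the last 21 rows and 53 columns of a 149 x 181 grid uncovered, which is one reason to adjust the patch size or the overlap.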

2. Create a function to give numpy arrays from your batches (little xarrays)

By default, xarray and dask use lazy, chunked arrays (e.g., dask.array or xarray.DataArray). TensorFlow cannot ingest those directly — it needs fully materialized data per training sample.

TensorFlow’s from_generator expects the generator to yield concrete values, typically NumPy arrays (np.ndarray), although other formats that can be converted to tf.Tensor are possible.

Here is our little function that takes a batch (little xarray) and returns a stacked numpy array.

def train_gen_tf_batches():
    for batch in train_gen:
        time_len = batch["y"].sizes["time"]
        for t in range(time_len):
            x = np.stack([
                batch[var].isel(time=t).data if "time" in batch[var].dims else batch[var].data
                for var in input_vars
            ], axis=-1).astype(np.float32)  # (xx, xx, n_features)
            y = batch["y"].isel(time=t).data[..., np.newaxis].astype(np.float32)  # (xx, xx, 1)
            yield x, y
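The stacking logic can be tried without xarray or TensorFlow. The sketch below uses plain NumPy arrays as stand-ins for one batch: two time-varying fields, one static mask without a time dimension, and a response. The dictionary `batch`, the helper `iter_samples`, the toy sizes and the random values are all made up for illustration.

```python
import numpy as np

# Toy stand-ins for one batch, laid out as (time, lat, lon).
rng = np.random.default_rng(0)
n_time, n_lat, n_lon = 4, 6, 5
batch = {
    "sst": rng.normal(size=(n_time, n_lat, n_lon)),
    "so": rng.normal(size=(n_time, n_lat, n_lon)),
    "ocean_mask": np.ones((n_lat, n_lon)),  # static: no time dimension
    "y": rng.normal(size=(n_time, n_lat, n_lon)),
}
input_names = ["sst", "so", "ocean_mask"]


def iter_samples(batch, input_names):
    """Yield one (x, y) pair per time step as float32 NumPy arrays."""
    for t in range(batch["y"].shape[0]):
        # Time-varying fields are sliced at step t; static fields are reused.
        layers = [
            batch[name][t] if batch[name].ndim == 3 else batch[name]
            for name in input_names
        ]
        x = np.stack(layers, axis=-1).astype(np.float32)       # (lat, lon, n_features)
        y = batch["y"][t][..., np.newaxis].astype(np.float32)  # (lat, lon, 1)
        yield x, y


samples = list(iter_samples(batch, input_names))
print(len(samples), samples[0][0].shape, samples[0][1].shape)
```

Each time step becomes one training sample, so a batch with 4 time steps yields 4 (x, y) pairs.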

3. Integration with TensorFlow

tf.data.Dataset.from_generator wraps the generator into a TensorFlow pipeline, so it can:

  • Automatically batch and prefetch data,

  • Stream from disk or cloud lazily (via Dask)

  • Integrate with model.fit()

train_dataset = train_dataset.shuffle(100).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

Load the libraries#

# Colab Users!!!
# Uncomment the pip line below and run if you are in Colab
# Leave in the !. That is part of the cmd
# !pip install zarr gcsfs xbatcher --quiet
# --- Core data handling libraries ---
import xarray as xr       # for working with labeled multi-dimensional arrays
import numpy as np        # for numerical operations on arrays
import dask.array as da   # for lazy, parallel array operations (used in xarray backends)

# --- Plotting ---
import matplotlib.pyplot as plt  # for creating plots

import xbatcher

# --- TensorFlow setup ---
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow log spam (0=all, 3=only errors)

import tensorflow as tf  # main deep learning framework

# --- Keras (part of TensorFlow): building and training neural networks ---
from keras.models import Sequential          # lets us stack layers in a simple linear model
from keras.layers import Conv2D              # 2D convolution layer — finds spatial patterns in image-like data
from keras.layers import BatchNormalization  # stabilizes and speeds up training by normalizing activations
from keras.layers import Dropout             # randomly "drops" neurons during training to reduce overfitting
from keras.callbacks import EarlyStopping    # stops training early if validation loss doesn't improve

See what machine we are on#

# list all the physical devices
physical_devices = tf.config.list_physical_devices()
print("All Physical Devices:", physical_devices)

# list all the available GPUs
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)

# Print information for each available GPU, if any
if gpus:
    for gpu in gpus:
        details = tf.config.experimental.get_device_details(gpu)
        print("GPU Details:", details)
else:
    print("No GPU available")
All Physical Devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU Details: {'compute_capability': (7, 5), 'device_name': 'Tesla T4'}

Load data#

I created the data for Part I in Data_Prep_Part_2. Here I will load the following variables:

  • sst, so, sin_time, cos_time

  • ocean_mask

  • y (CHL)

open_dataset versus open_zarr#

The data in Google Cloud Storage are a chunked Zarr store. Its files are chunks of 100 days of data, and when we request data we fetch these files. To do this efficiently, our xarray dataset needs to know what these chunks are. How we load our data determines whether our xarray dataset “knows” about the chunks, and this makes a big difference to speed and memory usage.

open_dataset will load our data, but the variables won’t be dask arrays. So when we do computations, the dask machinery for chunked processing will not be triggered. Effectively, all the needed data (in this case 2020) will be loaded into memory. You will see that memory usage is high in this example. This is OK if all the training data can fit into memory, and it will be the fastest approach.

# load full dataset 1997 to 2022 from Google bucket
dataset_od = xr.open_dataset(
    "gcs://nmfs_odp_nwfsc/CB/mind_the_chl_gap/cnn_tutorial",
    engine="zarr",
    backend_kwargs={"storage_options": {"token": "anon"}},
    consolidated=True
)
# nothing; no chunking
dataset_od.sel(time="2020")["sst"].chunks

open_zarr will load our data with the chunking information, and our variables will be dask arrays. When we do computations, the dask machinery for chunked processing will be triggered. This creates a lot of CPU overhead as dask figures out how best to get the chunks, so it will be slower; in our example, about 5x slower. But it will be memory efficient: the data needed for a batch are loaded and then released.

# load full dataset 1997 to 2022 from Google bucket
# chunked into 100 day chunks (64, 64 lat/lon)
dataset_oz = xr.open_zarr(
    "gcs://nmfs_odp_nwfsc/CB/mind_the_chl_gap/cnn_tutorial", 
    storage_options={"token": "anon"}, 
    consolidated=True
)
# chunked; some are partial
dataset_oz.sel(time="2020")["sst"].chunks
((89, 100, 100, 77), (64, 64, 21), (64, 64, 53))
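The time chunks reported above, (89, 100, 100, 77), make it easy to estimate how much data a sample pulls in. The sketch below is plain Python and assumes a window of contiguous time positions; it does not describe how dask schedules reads, and the randomly split data used later will not be contiguous. `chunks_touched` and `days_read` are hypothetical helpers.

```python
from bisect import bisect_right
from itertools import accumulate


def chunks_touched(chunk_sizes, start, length):
    """Return indices of the chunks overlapped by positions [start, start + length)."""
    ends = list(accumulate(chunk_sizes))  # exclusive end offset of each chunk
    if length <= 0 or start < 0 or start + length > ends[-1]:
        raise ValueError("window must lie inside the array")
    first = bisect_right(ends, start)
    last = bisect_right(ends, start + length - 1)
    return list(range(first, last + 1))


def days_read(chunk_sizes, start, length):
    """Total size of every chunk that must be fetched for the window."""
    return sum(chunk_sizes[i] for i in chunks_touched(chunk_sizes, start, length))


time_chunks = (89, 100, 100, 77)  # time chunks for 2020, from the output above

inside = chunks_touched(time_chunks, 0, 20)     # fits inside the first chunk
straddle = chunks_touched(time_chunks, 80, 20)  # crosses a chunk boundary
print(inside, days_read(time_chunks, 0, 20))
print(straddle, days_read(time_chunks, 80, 20))
```

A 20-day window that sits inside one chunk still fetches the whole chunk, and a window that straddles a boundary fetches two, which is why sample sizes that line up with the chunking reduce I/O.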
# will use the chunked data because we are going to use that for bigger data
dataset = dataset_oz
dataset
<xarray.Dataset> Size: 5GB
Dimensions:     (time: 9207, lat: 149, lon: 181)
Coordinates:
  * lat         (lat) float32 596B 32.0 31.75 31.5 31.25 ... -4.5 -4.75 -5.0
  * lon         (lon) float32 724B 45.0 45.25 45.5 45.75 ... 89.5 89.75 90.0
  * time        (time) datetime64[ns] 74kB 1997-10-01 1997-10-02 ... 2022-12-31
Data variables:
    cos_time    (time, lat, lon) float32 993MB dask.array<chunksize=(100, 64, 64), meta=np.ndarray>
    ocean_mask  (lat, lon) float32 108kB dask.array<chunksize=(64, 64), meta=np.ndarray>
    sin_time    (time, lat, lon) float32 993MB dask.array<chunksize=(100, 64, 64), meta=np.ndarray>
    so          (time, lat, lon) float32 993MB dask.array<chunksize=(100, 64, 64), meta=np.ndarray>
    sst         (time, lat, lon) float32 993MB dask.array<chunksize=(100, 64, 64), meta=np.ndarray>
    topo        (lat, lon) float32 108kB dask.array<chunksize=(64, 64), meta=np.ndarray>
    y           (time, lat, lon) float32 993MB dask.array<chunksize=(100, 64, 64), meta=np.ndarray>

Set up our train, validation and test datasets for xbatcher#

xbatcher requires that we return xarray Datasets, like dataset above, with our input variables (sst, so, etc) and y ready for sending to TensorFlow.

import numpy as np
import dask.array as da
import xarray as xr

def time_series_split_for_xbatcher(
    data, num_var, cat_var=None, split_ratio=(0.7, 0.2, 0.1), seed=42,
    X_mean=None, X_std=None
):
    """
    Splits time indices randomly into train/val/test for xbatcher.
    Replaces NaNs, and normalizes numerical variables if mean/std are provided.

    Parameters:
        data: xarray.Dataset with 'time' dimension
        num_var: list of numerical variable names
        cat_var: list of categorical variable names (no normalization)
        split_ratio: tuple of train, val, test split ratios
        seed: random seed
        X_mean, X_std: optional normalization stats

    Returns:
        train_ds, val_ds, test_ds: xarray datasets ready for xbatcher
        full_ds: the full input dataset, normalized and NaN-filled
        X_mean, X_std: used for normalization (or None, None if not applied)
    """
    if cat_var is None:
        cat_var = []
    input_var = num_var + cat_var
    output_var = ["y"]

    time_dim = "time"
    if time_dim not in data.dims:
        raise ValueError("Dataset must contain a 'time' dimension.")

    time_len = data.sizes[time_dim]
    rng = np.random.default_rng(seed)
    all_indices = rng.choice(time_len, size=time_len, replace=False)

    train_end = int(split_ratio[0] * time_len)
    val_end = int((split_ratio[0] + split_ratio[1]) * time_len)
    train_idx = np.sort(all_indices[:train_end])
    val_idx = np.sort(all_indices[train_end:val_end])
    test_idx = np.sort(all_indices[val_end:])

    train_data = data.isel(time=train_idx)

    # Compute normalization stats if not provided
    if num_var:
        if X_mean is None or X_std is None:
            stacked = da.stack([train_data[v].data for v in num_var], axis=-1)
            X_mean = da.nanmean(stacked, axis=(0, 1, 2)).compute()
            X_std = da.nanstd(stacked, axis=(0, 1, 2)).compute()
        X_std_safe = da.where(X_std == 0, 1.0, X_std)

    def normalize_and_fill(ds):
        ds_copy = ds.copy()
        for i, var in enumerate(num_var):
            v = (ds[var] - X_mean[i]) / X_std_safe[i]
            ds_copy[var] = xr.DataArray(
                da.nan_to_num(v.data), dims=ds[var].dims, coords=ds[var].coords
            )
        for var in cat_var + output_var:
            ds_copy[var] = xr.DataArray(
                da.nan_to_num(ds[var].data), dims=ds[var].dims, coords=ds[var].coords
            )
        return ds_copy

    train_ds = normalize_and_fill(data.isel(time=train_idx))
    val_ds = normalize_and_fill(data.isel(time=val_idx))
    test_ds = normalize_and_fill(data.isel(time=test_idx))
    full_ds = normalize_and_fill(data)

    return train_ds, val_ds, test_ds, full_ds, X_mean, X_std
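The index split on its own can be checked with a few lines of NumPy. This is a minimal sketch of the same idea, not a copy of the function: it uses `rng.permutation` rather than `rng.choice`, so the selected days need not match, and `split_time_indices` is a hypothetical helper.

```python
import numpy as np


def split_time_indices(time_len, split_ratio=(0.7, 0.2, 0.1), seed=42):
    """Randomly assign time positions to train/val/test, each returned sorted."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(time_len)
    train_end = int(split_ratio[0] * time_len)
    val_end = int((split_ratio[0] + split_ratio[1]) * time_len)
    return (
        np.sort(shuffled[:train_end]),
        np.sort(shuffled[train_end:val_end]),
        np.sort(shuffled[val_end:]),
    )


# One leap year of daily data
train_idx, val_idx, test_idx = split_time_indices(366)
print(len(train_idx), len(val_idx), len(test_idx))

# Every day lands in exactly one split
all_idx = np.concatenate([train_idx, val_idx, test_idx])
print(np.array_equal(np.sort(all_idx), np.arange(366)))
```

For a 366-day year and the default ratios, the split sizes are 256, 73 and 37 days.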

Set up xbatcher#

from xbatcher import BatchGenerator

# Define variables
input_vars = ["sst", "so", "sin_time", "cos_time", "ocean_mask"]
output_vars = ["y"]
num_var = ["sst", "so"]
cat_var = ["sin_time", "cos_time", "ocean_mask"]

# Use the whole field
input_dims = {"time": 30, "lat": 149, "lon": 181}
input_overlap = {"time": 0, "lat": 0, "lon": 0}

# Split the dataset; keep X_mean and X_std so other years can be normalized later
train_ds, val_ds, test_ds, full_ds, X_mean, X_std = time_series_split_for_xbatcher(
    data=dataset.sel(time="2020"),
    num_var=num_var,
    cat_var=cat_var,
)

# Create batch generators
train_gen = BatchGenerator(train_ds[input_vars + output_vars], input_dims=input_dims, input_overlap=input_overlap)
val_gen = BatchGenerator(val_ds[input_vars + output_vars], input_dims=input_dims, input_overlap=input_overlap)
test_gen = BatchGenerator(test_ds[input_vars + output_vars], input_dims=input_dims, input_overlap=input_overlap)
X_mean
array([301.79614 ,  35.002277], dtype=float32)
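The two values above are per-variable means (one for sst, one for so). The sketch below shows the same standardize-then-fill idea on a tiny NumPy array, assuming a (time, lat, lon, n_features) layout; `fit_standardizer` and `standardize_and_fill` are hypothetical helpers and the numbers are made up.

```python
import numpy as np


def fit_standardizer(x):
    """Per-feature mean and std over (time, lat, lon), ignoring NaNs."""
    mean = np.nanmean(x, axis=(0, 1, 2))
    std = np.nanstd(x, axis=(0, 1, 2))
    std_safe = np.where(std == 0, 1.0, std)  # avoid dividing by zero
    return mean, std_safe


def standardize_and_fill(x, mean, std_safe):
    """Z-score each feature, then replace NaNs with 0 (the feature mean)."""
    return np.nan_to_num((x - mean) / std_safe)


# Feature 0 varies and has one missing value; feature 1 is constant.
f0 = np.array([[[300.0, 302.0]], [[np.nan, 304.0]]])  # (time=2, lat=1, lon=2)
f1 = np.full_like(f0, 35.0)
x = np.stack([f0, f1], axis=-1)                        # (2, 1, 2, 2)

mean, std_safe = fit_standardizer(x)
z = standardize_and_fill(x, mean, std_safe)
print(mean, std_safe)
print(z[..., 0])
```

Because the fill happens after standardizing, a missing value becomes 0, which is the mean of the standardized feature.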

Set up generator functions to create numpy arrays#

# Generator: yield individual time slices for batches
input_shape = (input_dims["lat"], input_dims["lon"], len(input_vars))
output_shape = (input_dims["lat"], input_dims["lon"], 1)

def train_gen_tf_batches():
    for batch in train_gen:
        time_len = batch["y"].sizes["time"]
        for t in range(time_len):
            x = np.stack([
                batch[var].isel(time=t).data if "time" in batch[var].dims else batch[var].data
                for var in input_vars
            ], axis=-1).astype(np.float32)  # (xx, xx, n_features)
            y = batch["y"].isel(time=t).data[..., np.newaxis].astype(np.float32)  # (xx, xx, 1)
            yield x, y

def val_gen_tf_batches():
    for batch in val_gen:
        time_len = batch["y"].sizes["time"]
        for t in range(time_len):
            x = np.stack([
                batch[var].isel(time=t).data if "time" in batch[var].dims else batch[var].data
                for var in input_vars
            ], axis=-1).astype(np.float32)
            y = batch["y"].isel(time=t).data[..., np.newaxis].astype(np.float32)
            yield x, y


train_dataset_from_gen = tf.data.Dataset.from_generator(
    train_gen_tf_batches,
    output_signature=(
        tf.TensorSpec(shape=input_shape, dtype=tf.float32),
        tf.TensorSpec(shape=output_shape, dtype=tf.float32),
    )
)

val_dataset_from_gen = tf.data.Dataset.from_generator(
    val_gen_tf_batches,
    output_signature=(
        tf.TensorSpec(shape=input_shape, dtype=tf.float32),
        tf.TensorSpec(shape=output_shape, dtype=tf.float32),
    )
)

Final prep of dataset for TensorFlow#

We are not trying to preserve temporal information since we are predicting chlorophyll from same day SST and salinity. So we shuffle() to make sure that everything is i.i.d. for TensorFlow.

  • shuffle() adds randomness within the training batches

  • Prevents batch-to-batch correlation

  • Improves model convergence and generalization

The largest useful shuffle buffer is the number of training samples, up to about 1000-2000. If the dataset is very large, a buffer that size adds a lot of overhead and is not necessary; at the upper end, use about 10 × the number of batches. In our case, the generator yields one sample per time step, so the total number of training samples ≈ len(train_gen) × input_dims["time"].

print("Number of batches in training set:", len(train_gen))
print("Max shuffle size:", len(train_gen) * input_dims["time"])
Number of batches in training set: 12
Max shuffle size: 240
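To build intuition for what the shuffle buffer size does, here is a pure-Python sketch of buffer-based shuffling: fill a buffer, emit a random element from it, and replace that element with the next incoming item. It mimics the behavior described in the tf.data documentation but is not TensorFlow's implementation; `buffered_shuffle` is a hypothetical helper.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        pick = rng.randrange(buffer_size)
        yield buffer[pick]       # emit a random buffered item
        buffer[pick] = item      # and replace it with the new one
    rng.shuffle(buffer)          # drain what is left in random order
    yield from buffer


days = list(range(240))
small = list(buffered_shuffle(days, buffer_size=8))
full = list(buffered_shuffle(days, buffer_size=240))

# How far ahead of its original position did any sample get emitted?
max_early_small = max(v - k for k, v in enumerate(small))
print(small[:10], max_early_small)
```

With a buffer of 8, no sample can be emitted more than 7 positions ahead of where it started, so the output stays close to the original time order. With a buffer as large as the dataset, the result is a full shuffle.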

Using .repeat(). I had to add this when using a generator, together with specifying the number of training steps per epoch. Otherwise, TensorFlow struggled during the first pass to figure out how much data to use and did not reset the generator (to give a new set of data) properly.

# you might need to tweak this
BATCH_SIZE = 8
SHUFFLE_N = 100
train_dataset = train_dataset_from_gen.shuffle(SHUFFLE_N).batch(BATCH_SIZE).repeat().prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset_from_gen.batch(BATCH_SIZE).repeat().prefetch(tf.data.AUTOTUNE)
%%time
# check that shape is (BATCH_SIZE, 149, 181, 5)
for x, y in train_dataset.take(1):
    print("Train x shape:", x.shape)
    print("Train y shape:", y.shape)
Train x shape: (8, 149, 181, 5)
Train y shape: (8, 149, 181, 1)
CPU times: user 2.22 s, sys: 353 ms, total: 2.57 s
Wall time: 2.4 s

Let’s build the model#

We build a simple 3-layer CNN model. Each layer preserves the (lat, lon) shape and learns filters to extract spatial patterns.

from keras.models import Sequential
from keras.layers import Input, Conv2D, BatchNormalization, Dropout

def create_model_CNN(input_shape):
    """
    Create a simple 3-layer CNN model for gridded ocean data.

    Parameters
    ----------
    input_shape : tuple
        The shape of each sample, e.g., (149, 181, 5)

    Returns
    -------
    model : keras.Model
        CNN model to predict CHL from SST and salinity
    """
    model = Sequential()

    # Input layer defines the input dimensions for the CNN
    model.add(Input(shape=input_shape))

    # Layer 1 — learns fine-scale 3x3 spatial features
    # Let the model learn 64 different patterns (filters) in the data at this layer.
    # activation relu is non-linearity
    model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    # Layer 2 — expands context to 5x5; combines fine features into larger structures
    # Reduce the number of patterns (filters) so we gradually reduce model complexity
    model.add(Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    # Layer 3 — has access to ~7x7 neighborhood; outputs CHL prediction per pixel
    # Combines all the previous layer’s features into a CHL estimate at each pixel
    # 1 response (CHL), hence 1 filter: one prediction per pixel
    # linear since predicting a real continuous variable (log CHL)
    model.add(Conv2D(filters=1, kernel_size=(3, 3), padding='same', activation='linear'))

    return model
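The 3x3, 5x5 and ~7x7 neighborhoods mentioned in the comments follow from stacking stride-1 convolutions, and the size of each convolution layer can be worked out by hand. The sketch below does both in plain Python; `receptive_fields` and `conv2d_params` are hypothetical helpers, the 5 input channels assume the five input variables used in this tutorial, and the BatchNormalization parameters are not counted.

```python
def receptive_fields(kernel_sizes):
    """Receptive field width after each stacked stride-1 convolution."""
    sizes, rf = [], 1
    for k in kernel_sizes:
        rf += k - 1  # each layer widens the view by (kernel - 1) pixels
        sizes.append(rf)
    return sizes


def conv2d_params(kernel_size, in_channels, filters):
    """Weights plus one bias per filter for a standard 2D convolution."""
    return (kernel_size * kernel_size * in_channels + 1) * filters


rf = receptive_fields([3, 3, 3])
params = [
    conv2d_params(3, 5, 64),   # layer 1: 5 input features -> 64 filters
    conv2d_params(3, 64, 32),  # layer 2: 64 -> 32 filters
    conv2d_params(3, 32, 1),   # layer 3: 32 -> 1 output (CHL)
]
print(rf, params, sum(params))
```

Most of the convolution weights sit in the second layer, because it connects 64 input channels to 32 filters.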

Let’s train the model#

What is an Epoch?#

An epoch is one full pass through the entire training dataset. Each sample is one day, and each step processes one batch of BATCH_SIZE days. The number of steps in an epoch is therefore (number of training days) // BATCH_SIZE.
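As a small worked example of the step arithmetic, take the 240 daily samples printed earlier and a batch size of 8. `steps_per_epoch` is a hypothetical helper, not a Keras function; it uses floor division in the same way as TRAIN_STEPS_PER_EPOCH in the training code, and the second call uses a made-up sample count to show a remainder.

```python
def steps_per_epoch(n_samples, batch_size):
    """Full batches in one pass over the data, plus the samples left over."""
    return n_samples // batch_size, n_samples % batch_size


steps, leftover = steps_per_epoch(240, 8)
steps_b, leftover_b = steps_per_epoch(250, 8)
print(steps, leftover)
print(steps_b, leftover_b)
```

When the sample count is not a multiple of the batch size, floor division leaves a few samples out of the step count for that pass.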

What Happens in One Epoch?#

1. Training Phase (train_dataset)#

  • TensorFlow iterates over batches from train_dataset.

  • For each batch:

    • Performs a forward pass (makes predictions),

    • Computes the loss (e.g., Mean Absolute Error),

    • Runs a backward pass (computes gradients),

    • Updates weights using the optimizer (e.g., Adam).

2. Validation Phase (val_dataset)#

  • After all training batches are processed in the epoch:

    • The model is evaluated on val_dataset.

    • This step does not update weights.

  • Metrics like val_loss and val_mae are calculated to monitor performance on unseen data.

Why Use val_dataset?#

  • The validation dataset does not participate in training.

  • It provides a consistent benchmark at the end of each epoch.

  • It’s used to:

    • Track overfitting (if validation loss starts increasing),

    • Support callbacks like EarlyStopping,

    • Save best model checkpoints during training.

model = create_model_CNN(input_shape)
model.compile(optimizer='adam', loss='mae', metrics=['mae'])

# Set up early stopping to prevent overfitting
early_stop = EarlyStopping(
    patience=10,              # Stop if validation loss doesn't improve for 10 epochs
    restore_best_weights=True  # Revert to the model weights from the best epoch
)

TRAIN_STEPS_PER_EPOCH = len(train_gen) * input_dims["time"] // BATCH_SIZE
VAL_STEPS_PER_EPOCH = len(val_gen) * input_dims["time"] // BATCH_SIZE

history = model.fit(
    train_dataset,
    epochs=50,                    # Maximum number of training epochs
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    validation_data=val_dataset, # Use validation data during training
    validation_steps=VAL_STEPS_PER_EPOCH,
    callbacks=[early_stop],      # Stop early if no improvement
)
Epoch 1/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 11s 197ms/step - loss: 1.2570 - mae: 1.2570 - val_loss: 0.8432 - val_mae: 0.8432
Epoch 2/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 133ms/step - loss: 0.8724 - mae: 0.8724 - val_loss: 0.6911 - val_mae: 0.6911
Epoch 3/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 113ms/step - loss: 0.5442 - mae: 0.5442 - val_loss: 0.7178 - val_mae: 0.7178
Epoch 4/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 79ms/step - loss: 0.4325 - mae: 0.4325 - val_loss: 0.6082 - val_mae: 0.6082
Epoch 5/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.3795 - mae: 0.3795 - val_loss: 0.5682 - val_mae: 0.5682
Epoch 6/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 87ms/step - loss: 0.3577 - mae: 0.3577 - val_loss: 0.5128 - val_mae: 0.5128
Epoch 7/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 73ms/step - loss: 0.3387 - mae: 0.3387 - val_loss: 0.4746 - val_mae: 0.4746
Epoch 8/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.3216 - mae: 0.3216 - val_loss: 0.4448 - val_mae: 0.4448
Epoch 9/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 71ms/step - loss: 0.3045 - mae: 0.3045 - val_loss: 0.4459 - val_mae: 0.4459
Epoch 10/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 73ms/step - loss: 0.3065 - mae: 0.3065 - val_loss: 0.3755 - val_mae: 0.3755
Epoch 11/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 90ms/step - loss: 0.2940 - mae: 0.2940 - val_loss: 0.3583 - val_mae: 0.3583
Epoch 12/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 71ms/step - loss: 0.2977 - mae: 0.2977 - val_loss: 0.3209 - val_mae: 0.3209
Epoch 13/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 71ms/step - loss: 0.2881 - mae: 0.2881 - val_loss: 0.3045 - val_mae: 0.3045
Epoch 14/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 78ms/step - loss: 0.2781 - mae: 0.2781 - val_loss: 0.2825 - val_mae: 0.2825
Epoch 15/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 83ms/step - loss: 0.2689 - mae: 0.2689 - val_loss: 0.2641 - val_mae: 0.2641
Epoch 16/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 86ms/step - loss: 0.2675 - mae: 0.2675 - val_loss: 0.2661 - val_mae: 0.2661
Epoch 17/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 80ms/step - loss: 0.2620 - mae: 0.2620 - val_loss: 0.2531 - val_mae: 0.2531
Epoch 18/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 75ms/step - loss: 0.2720 - mae: 0.2720 - val_loss: 0.2343 - val_mae: 0.2343
Epoch 19/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 76ms/step - loss: 0.2743 - mae: 0.2743 - val_loss: 0.2359 - val_mae: 0.2359
Epoch 20/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 73ms/step - loss: 0.2731 - mae: 0.2731 - val_loss: 0.2202 - val_mae: 0.2202
Epoch 21/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 90ms/step - loss: 0.2669 - mae: 0.2669 - val_loss: 0.2119 - val_mae: 0.2119
Epoch 22/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 75ms/step - loss: 0.2654 - mae: 0.2654 - val_loss: 0.2250 - val_mae: 0.2250
Epoch 23/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.2664 - mae: 0.2664 - val_loss: 0.2170 - val_mae: 0.2170
Epoch 24/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.2560 - mae: 0.2560 - val_loss: 0.2241 - val_mae: 0.2241
Epoch 25/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 74ms/step - loss: 0.2589 - mae: 0.2589 - val_loss: 0.2124 - val_mae: 0.2124
Epoch 26/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 84ms/step - loss: 0.2577 - mae: 0.2577 - val_loss: 0.1975 - val_mae: 0.1975
Epoch 27/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.2628 - mae: 0.2628 - val_loss: 0.1939 - val_mae: 0.1939
Epoch 28/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.2536 - mae: 0.2536 - val_loss: 0.2044 - val_mae: 0.2044
Epoch 29/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.2496 - mae: 0.2496 - val_loss: 0.2164 - val_mae: 0.2164
Epoch 30/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 69ms/step - loss: 0.2608 - mae: 0.2608 - val_loss: 0.2106 - val_mae: 0.2106
Epoch 31/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 73ms/step - loss: 0.2609 - mae: 0.2609 - val_loss: 0.2064 - val_mae: 0.2064
Epoch 32/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 71ms/step - loss: 0.2470 - mae: 0.2470 - val_loss: 0.2170 - val_mae: 0.2170
Epoch 33/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 74ms/step - loss: 0.2513 - mae: 0.2513 - val_loss: 0.2007 - val_mae: 0.2007
Epoch 34/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 87ms/step - loss: 0.2447 - mae: 0.2447 - val_loss: 0.2045 - val_mae: 0.2045
Epoch 35/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 73ms/step - loss: 0.2615 - mae: 0.2615 - val_loss: 0.2008 - val_mae: 0.2008
Epoch 36/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 80ms/step - loss: 0.2526 - mae: 0.2526 - val_loss: 0.2057 - val_mae: 0.2057
Epoch 37/50
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 72ms/step - loss: 0.2503 - mae: 0.2503 - val_loss: 0.2065 - val_mae: 0.2065
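Training stopped at epoch 37 rather than 50 because of the early-stopping patience. The sketch below replays the val_loss values from the log above through a simplified patience rule: stop once the loss has failed to improve on its best value for `patience` consecutive epochs. It is my reading of how EarlyStopping behaves with its default settings, not the Keras source, and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), 1-based; stop_epoch is None if never triggered."""
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# val_loss per epoch, copied from the training log above
val_losses = [
    0.8432, 0.6911, 0.7178, 0.6082, 0.5682, 0.5128, 0.4746, 0.4448, 0.4459,
    0.3755, 0.3583, 0.3209, 0.3045, 0.2825, 0.2641, 0.2661, 0.2531, 0.2343,
    0.2359, 0.2202, 0.2119, 0.2250, 0.2170, 0.2241, 0.2124, 0.1975, 0.1939,
    0.2044, 0.2164, 0.2106, 0.2064, 0.2170, 0.2007, 0.2045, 0.2008, 0.2057,
    0.2065,
]

stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=10)
print(stop_epoch, best_epoch, min(val_losses))
```

The lowest validation loss (0.1939) is at epoch 27, and the ten epochs after it never beat that value. With restore_best_weights=True, the model kept after stopping is the one from that best epoch.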

Plot training & validation loss values#

plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.grid(True)
plt.show()
[Figure: Model Loss, showing train and validation loss by epoch]

Make some maps of our predictions#

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Example: date to predict
date_to_predict = np.datetime64("2020-09-02")

# Extract dataset slice for this time
ds_at_time = full_ds.sel(time=date_to_predict)

# Prepare input (shape = [lat, lon, n_features])
input_data = np.stack([
    ds_at_time[var].values for var in input_vars
], axis=-1)  # shape = (lat, lon, n_features)

# Add batch dimension
input_data = input_data[np.newaxis, ...]  # shape = (1, lat, lon, n_features)

# Predict
predicted_output = model.predict(input_data)[0, ..., 0]  # shape = (lat, lon)

# True value from y
true_output = ds_at_time["y"].values  # shape = (lat, lon)

# Mask land (land_mask = ~ocean)
land_mask = ~dataset["ocean_mask"].values.astype(bool)
predicted_output[land_mask] = np.nan
true_output[land_mask] = np.nan

# Plot
vmin = np.nanmin([true_output, predicted_output])
vmax = np.nanmax([true_output, predicted_output])

plt.imshow(true_output, vmin=vmin, vmax=vmax, cmap='viridis')
plt.colorbar()
plt.title(f"True CHL on {date_to_predict}")
plt.show()

plt.imshow(predicted_output, vmin=vmin, vmax=vmax, cmap='viridis')
plt.colorbar()
plt.title(f"Predicted CHL on {date_to_predict}")
plt.show()
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 149ms/step
[Figures: map of true CHL and map of predicted CHL on 2020-09-02]

Let’s look at all the months#

I will create a function to do this.

import numpy as np
import pandas as pd  # needed for pd.to_datetime and pd.Series below
import matplotlib.pyplot as plt

def plot_true_vs_predicted(data, year, model, X_mean, X_std, num_var, cat_var):
    """
    Plot true vs predicted output for the first available day of each month in `data`.

    Parameters:
        data (xarray.Dataset): Contains variables 'y', 'ocean_mask', and coords 'lat', 'lon', 'time'
        year (str): The year to use, in the format "YYYY".
        model (tf.keras.Model): Trained model with a .predict() method
        X_mean (np.ndarray): mean from the model training data (num_var only)
        X_std (np.ndarray): std from the model training data (num_var only)
        num_var (list of str): The variables to be standardized with `X_mean` and `X_std`
        cat_var (list of str): The variables to be included in X but not standardized.
    """
    # Split the dataset to get full_ds
    _, _, _, year_ds, _, _ = time_series_split_for_xbatcher(
        data=data.sel(time=year),
        num_var=num_var,
        cat_var=cat_var,
        X_mean=X_mean,
        X_std=X_std,
    )

    # Get available time points and group by month
    available_dates = pd.to_datetime(year_ds.time.values)
    monthly_dates = (
        pd.Series(available_dates)
        .groupby([available_dates.year, available_dates.month])
        .min()
        .sort_values()
    )
    n_months = len(monthly_dates)

    # lat/lon info
    lat = year_ds.lat.values
    lon = year_ds.lon.values
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]
    flip_lat = lat[0] > lat[-1]
    land_mask = ~year_ds["ocean_mask"].values.astype(bool)

    # Create figure and axes
    fig, axs = plt.subplots(n_months, 2, figsize=(7, 2 * n_months), constrained_layout=True)

    for i, date in enumerate(monthly_dates):
        # Select dataset for this date
        ds_at_time = year_ds.sel(time=np.datetime64(date))

        # Prepare model input: stack input variables into (lat, lon, n_features)
        input_data = np.stack([
            ds_at_time[var].values for var in input_vars
        ], axis=-1)

        # Predict: shape (lat, lon)
        predicted_output = model.predict(input_data[np.newaxis, ...])[0, ..., 0]

        # True output
        true_output = data["y"].sel(time=np.datetime64(date)).values

        # Mask land
        predicted_output[land_mask] = np.nan
        true_output[land_mask] = np.nan

        # Flip latitude if needed
        if flip_lat:
            true_output = np.flipud(true_output)
            predicted_output = np.flipud(predicted_output)

        # Shared color scale
        vmin = np.nanpercentile([true_output, predicted_output], 5)
        vmax = np.nanpercentile([true_output, predicted_output], 95)

        # Compute R²
        from sklearn.metrics import r2_score
        true_flat = true_output.flatten()
        pred_flat = predicted_output.flatten()
        valid_mask = ~np.isnan(true_flat) & ~np.isnan(pred_flat)
        r2 = r2_score(true_flat[valid_mask], pred_flat[valid_mask])

        # Plot true
        axs[i, 0].imshow(true_output, origin='lower', extent=extent,
                         vmin=vmin, vmax=vmax, cmap='viridis',
                         aspect='equal')
        axs[i, 0].set_title(f"{date.strftime('%b')} — True", fontsize=10)
        axs[i, 0].axis('off')

        # Plot predicted with R²
        axs[i, 1].imshow(predicted_output, origin='lower', extent=extent,
                         vmin=vmin, vmax=vmax, cmap='viridis', 
                         aspect='equal')
        axs[i, 1].set_title(f"{date.strftime('%b')} — Pred\n$R^2$ = {r2:.2f}", fontsize=10)
        axs[i, 1].axis('off')

    plt.suptitle(f'CHL: True vs Predicted (log scale) — {year}', fontsize=16)
    plt.show()
plot_true_vs_predicted(dataset, "2020", model, X_mean, X_std, num_var, cat_var)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
[Figure: CHL: True vs Predicted (log scale), 2020; one row per month with the true map on the left and the predicted map with its R² on the right]
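The R² in the panel titles is computed only over pixels where both maps have a value. Here is a minimal NumPy sketch of the usual definition, 1 - SS_res / SS_tot, with the same NaN masking. `masked_r2` is a hypothetical helper; the plotting function itself uses sklearn's r2_score, the toy arrays are made up, and this version does not handle a constant true field.

```python
import numpy as np


def masked_r2(true, pred):
    """R^2 over the positions where neither array is NaN."""
    true = np.asarray(true, dtype=float).ravel()
    pred = np.asarray(pred, dtype=float).ravel()
    valid = ~np.isnan(true) & ~np.isnan(pred)
    t, p = true[valid], pred[valid]
    ss_res = np.sum((t - p) ** 2)         # squared prediction errors
    ss_tot = np.sum((t - t.mean()) ** 2)  # spread of the true values
    return 1.0 - ss_res / ss_tot


# The NaN marks a masked (land) pixel and is ignored.
true = np.array([[1.0, 2.0], [3.0, np.nan]])
pred = np.array([[1.0, 2.0], [2.0, 5.0]])

r2 = masked_r2(true, pred)
r2_perfect = masked_r2(true, true)
print(r2, r2_perfect)
```

An R² of 1 means the predicted map matches the true map exactly over the ocean pixels, and values can go below 0 when the prediction is worse than using the mean of the true map.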

Let’s look at a different year#

plot_true_vs_predicted(dataset, "2015", model, X_mean, X_std, num_var, cat_var)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
[Figure: CHL: True vs Predicted (log scale), 2015; one row per month with the true map on the left and the predicted map with its R² on the right]

Summary#

So far, what we have done in Part III is no different from Part II. But we now have the structure to start training on large datasets, working in batches that our memory can handle. In Part IV, we will put it together and train on multiple years. Hopefully we can improve our out-of-sample (different year) forecasts compared to Part II.