Part IV. Training on multiple years (bigger than memory)#

Author: Eli Holmes (NOAA)

Colab Badge JupyterHub Badge Download Badge

Now we will put it all together and train on multiple years.

TO DO: Get rid of days with lots of NaN in CHL (> 832).

Load the libraries#

# Uncomment this line and run if you are in Colab; leave in the !. That is part of the cmd
# !pip install zarr gcsfs xbatcher --quiet
# --- Core data handling libraries ---
import xarray as xr       # for working with labeled multi-dimensional arrays
import numpy as np        # for numerical operations on arrays
import dask.array as da   # for lazy, parallel array operations (used in xarray backends)

# --- Plotting ---
import matplotlib.pyplot as plt  # for creating plots

import xbatcher

# --- TensorFlow setup ---
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow log spam (0=all, 3=only errors)

import tensorflow as tf  # main deep learning framework

# --- Keras (part of TensorFlow): building and training neural networks ---
from keras.models import Sequential          # lets us stack layers in a simple linear model
from keras.layers import Conv2D              # 2D convolution layer — finds spatial patterns in image-like data
from keras.layers import BatchNormalization  # stabilizes and speeds up training by normalizing activations
from keras.layers import Dropout             # randomly "drops" neurons during training to reduce overfitting
from keras.callbacks import EarlyStopping    # stops training early if validation loss doesn't improve
2025-06-21 02:26:35.210701: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-21 02:26:35.228670: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-21 02:26:35.234116: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

See what machine we are on#

# Enumerate every device TensorFlow can see (CPUs and accelerators).
physical_devices = tf.config.list_physical_devices()
print("All Physical Devices:", physical_devices)

# Narrow the listing down to GPUs only.
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)

# Print the details TensorFlow reports for each GPU, or say that there is none.
if not gpus:
    print("No GPU available")
for gpu in gpus:
    details = tf.config.experimental.get_device_details(gpu)
    print("GPU Details:", details)
All Physical Devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU Details: {'compute_capability': (7, 5), 'device_name': 'Tesla T4'}

Load data#

We definitely need to use open_zarr() so that our data is identified as dask arrays.

# Open the full 1997-2022 dataset straight from the Google Cloud bucket.
# The store is chunked into 100-day chunks (64 x 64 in lat/lon); open_zarr keeps
# the variables as dask arrays, so nothing is read until it is needed.
dataset = xr.open_zarr(
    store="gcs://nmfs_odp_nwfsc/CB/mind_the_chl_gap/cnn_tutorial",
    consolidated=True,
    storage_options=dict(token="anon"),  # "anon" token: no credentials are supplied
)

Set up our train, validation and test datasets for xbatcher#

xbatcher requires that we return xarray Datasets, like dataset above, with our input variables (sst, so, etc.) and y ready for sending to TensorFlow. The function below computes the normalization statistics (X_mean, X_std) from the training data and returns them for use later (e.g. for normalizing the validation and test data).

import numpy as np
import dask.array as da
import xarray as xr

def time_series_split_for_xbatcher(
    data, num_var, cat_var=None, split_ratio=(0.7, 0.2, 0.1), seed=42,
    X_mean=None, X_std=None, sample_frac=None
):
    """
    Split a Dataset along 'time' into train/val/test subsets ready for xbatcher.

    The time indices are shuffled with a seeded generator, cut according to
    `split_ratio`, and each subset is put back in ascending index order.
    When `sample_frac` is given, every subset is randomly thinned *before* the
    normalization statistics are computed, so the statistics describe the
    training data that is actually used. Numerical variables are standardized
    with X_mean/X_std (computed from the training subset when either one is
    not supplied). NaNs in the numerical, categorical and 'y' variables are
    then replaced with `nan_to_num`.

    Parameters:
        data: xarray.Dataset with a 'time' dimension and a 'y' variable
        num_var: list of numerical variable names (standardized)
        cat_var: list of categorical variable names (no normalization)
        split_ratio: tuple of train, val, test split ratios
        seed: random seed used for both the split and the downsampling
        X_mean, X_std: optional normalization stats, one entry per `num_var`
        sample_frac: optional float in (0, 1] to randomly downsample each split

    Returns:
        train_ds, val_ds, test_ds, full_ds: datasets ready for xbatcher
            (full_ds is all of `data`, normalized and NaN-filled)
        X_mean, X_std: the statistics used for normalization (the computed
            ones, or the values passed in)

    Raises:
        ValueError: if `data` has no 'time' dimension, if `sample_frac` is
            outside (0, 1], or if statistics must be computed but the training
            split is empty.
    """
    cat_var = [] if cat_var is None else list(cat_var)
    num_var = list(num_var)
    output_var = ["y"]

    time_dim = "time"
    if time_dim not in data.dims:
        raise ValueError("Dataset must contain a 'time' dimension.")
    if sample_frac is not None and not 0 < sample_frac <= 1:
        raise ValueError("sample_frac must be in the interval (0, 1].")

    # Shuffle all time indices once, then cut the permutation at the split points.
    time_len = data.sizes[time_dim]
    rng = np.random.default_rng(seed)
    all_indices = rng.choice(time_len, size=time_len, replace=False)

    train_end = int(split_ratio[0] * time_len)
    val_end = int((split_ratio[0] + split_ratio[1]) * time_len)
    train_idx = np.sort(all_indices[:train_end])
    val_idx = np.sort(all_indices[train_end:val_end])
    test_idx = np.sort(all_indices[val_end:])

    def downsample_idx(idx):
        # Keep a random fraction of the indices (at least one), in ascending order.
        if sample_frac is not None and len(idx) > 0:
            n_sample = max(1, int(sample_frac * len(idx)))
            return np.sort(rng.choice(idx, size=n_sample, replace=False))
        return idx

    train_idx = downsample_idx(train_idx)
    val_idx = downsample_idx(val_idx)
    test_idx = downsample_idx(test_idx)

    train_data = data.isel(time=train_idx)

    std_safe = None
    if num_var:
        # Compute normalization stats from the training subset if not provided.
        if X_mean is None or X_std is None:
            if len(train_idx) == 0:
                # nanmean/nanstd over nothing would give NaN stats and silently
                # turn every numerical input into 0 after the NaN fill.
                raise ValueError(
                    "Cannot compute X_mean/X_std from an empty training split; "
                    "pass X_mean and X_std or use a non-zero train ratio."
                )
            stacked = da.stack([train_data[v].data for v in num_var], axis=-1)
            # Reduce over every axis except the trailing variable axis.
            reduce_axes = tuple(range(stacked.ndim - 1))
            X_mean = da.nanmean(stacked, axis=reduce_axes).compute()
            X_std = da.nanstd(stacked, axis=reduce_axes).compute()
        # The stats are concrete values here, so guard against division by zero
        # with NumPy; ones_like keeps the dtype of the supplied std.
        std_arr = np.asarray(X_std)
        std_safe = np.where(std_arr == 0, np.ones_like(std_arr), std_arr)

    def normalize_and_fill(ds):
        # Standardize the numerical variables, then replace NaNs everywhere.
        ds_copy = ds.copy()
        for i, var in enumerate(num_var):
            v = (ds[var] - X_mean[i]) / std_safe[i]
            ds_copy[var] = xr.DataArray(
                da.nan_to_num(v.data), dims=ds[var].dims, coords=ds[var].coords
            )
        for var in cat_var + output_var:
            ds_copy[var] = xr.DataArray(
                da.nan_to_num(ds[var].data), dims=ds[var].dims, coords=ds[var].coords
            )
        return ds_copy

    train_ds = normalize_and_fill(train_data)
    val_ds = normalize_and_fill(data.isel(time=val_idx))
    test_ds = normalize_and_fill(data.isel(time=test_idx))
    full_ds = normalize_and_fill(data)

    return train_ds, val_ds, test_ds, full_ds, X_mean, X_std

Set up train, validation and test years#

I will use different years for each.

# Predictors grouped by treatment: numerical ones are standardized,
# the remaining ones are passed through unchanged.
num_var = ["sst", "so"]
cat_var = ["sin_time", "cos_time", "ocean_mask"]
# Model inputs in channel order: sst, so, sin_time, cos_time, ocean_mask
input_vars = num_var + cat_var
output_vars = ["y"]

# Keep only the time steps that fall in the training years.
train_yrs = [2015, 2020]
train_years = dataset.sel(
    time=dataset.time.dt.year.isin(train_yrs)
)

# Send every selected day to the training split, keep a random half of them,
# and let the function compute the normalization statistics from that half.
train_ds, _, _, _, X_mean, X_std = time_series_split_for_xbatcher(
    data=train_years,
    num_var=num_var,
    cat_var=cat_var,
    split_ratio=(1.0, 0.0, 0.0),
    sample_frac=0.5,
)
X_mean
array([301.79614 ,  35.002277], dtype=float32)
# Validation uses years that are different from the training years, so the
# validation loss measures how well the model transfers to unseen years.
val_yrs = [2014, 2019]
val_years = dataset.sel(time=dataset.time.dt.year.isin(val_yrs))
# Every selected day goes to the validation split (a random quarter is kept).
# The training X_mean/X_std are passed in so the validation data is normalized
# exactly like the training data; they are returned unchanged.
_, val_ds, _, _, X_mean, X_std = time_series_split_for_xbatcher(
    data=val_years,
    split_ratio=(0.0, 1.0, 0.0),
    num_var=num_var,
    cat_var=cat_var,
    sample_frac=0.25,
    X_mean=X_mean, X_std=X_std
)

Set up xbatcher#

from xbatcher import BatchGenerator

# Use the whole field: each batch covers 149 x 181 grid points and 30 time steps.
input_dims = dict(time=30, lat=149, lon=181)
# Neighbouring batches do not overlap along any dimension.
input_overlap = dict.fromkeys(input_dims, 0)

# One batch generator per split, restricted to the model inputs plus the target.
train_gen, val_gen = (
    BatchGenerator(
        split_ds[input_vars + output_vars],
        input_dims=input_dims,
        input_overlap=input_overlap,
    )
    for split_ds in (train_ds, val_ds)
)

Set up generator functions to create numpy arrays#

# Per-sample tensor shapes: one full (lat, lon) field per time step.
input_shape = (input_dims["lat"], input_dims["lon"], len(input_vars))
output_shape = (input_dims["lat"], input_dims["lon"], 1)


def _iter_time_slices(batch_gen):
    """Yield one (x, y) float32 pair per time step of every batch in `batch_gen`.

    x stacks the `input_vars` along a trailing feature axis; variables without
    a 'time' dimension are reused as-is for every time step. y is the 'y'
    field with a trailing channel axis of length 1.
    """
    for batch in batch_gen:
        for t in range(batch["y"].sizes["time"]):
            features = []
            for var in input_vars:
                field = batch[var]
                if "time" in field.dims:
                    field = field.isel(time=t)
                # .values always hands back a NumPy array, even if the batch
                # were still lazily backed, so np.stack works on plain arrays.
                features.append(field.values)
            x = np.stack(features, axis=-1).astype(np.float32)  # (lat, lon, n_features)
            y = batch["y"].isel(time=t).values[..., np.newaxis].astype(np.float32)  # (lat, lon, 1)
            yield x, y


def train_gen_tf_batches():
    """Yield individual (x, y) time slices from the training batch generator."""
    yield from _iter_time_slices(train_gen)


def val_gen_tf_batches():
    """Yield individual (x, y) time slices from the validation batch generator."""
    yield from _iter_time_slices(val_gen)


# Wrap both Python generators as tf.data Datasets. Each element is an (x, y)
# pair of float32 tensors with the fixed per-sample shapes defined above.
train_dataset_from_gen, val_dataset_from_gen = (
    tf.data.Dataset.from_generator(
        sample_source,
        output_signature=(
            tf.TensorSpec(shape=input_shape, dtype=tf.float32),
            tf.TensorSpec(shape=output_shape, dtype=tf.float32),
        ),
    )
    for sample_source in (train_gen_tf_batches, val_gen_tf_batches)
)

Final prep of dataset for TensorFlow#

We are not trying to preserve temporal information since we are predicting chlorophyll from same day SST and salinity. So we shuffle() to make sure that everything is i.i.d. for TensorFlow.

  • shuffle() adds randomness within the training batches

  • Prevents batch-to-batch correlation

  • Improves model convergence and generalization

The largest useful shuffle buffer is the total number of training samples, up to about 1000-2000. If the dataset is very large, a buffer that big adds a lot of overhead and is not necessary; at the upper end, use about 10 × the number of batches. In our case, the generator yields one sample per time step, so the total number of training samples ≈ len(train_gen) × input_dims["time"].

# One sample per time step, so the sample count is batches x time steps per batch.
print(f"Number of batches in training set: {len(train_gen)}")
print(f"Max shuffle size: {len(train_gen) * input_dims['time']}")
Number of batches in training set: 12
Max shuffle size: 360

Using .repeat(). I had to add this when using a generator, together with specifying the number of training steps per epoch. Otherwise, TensorFlow was struggling during the first pass to figure out how much data to use and wasn’t resetting the generator (to give a new set of data) properly.

# you might need to tweak this
BATCH_SIZE = 8    # samples (time steps) per training step
SHUFFLE_N = 100   # size of the shuffle buffer

# Training: shuffle individual samples, group them into batches, loop forever
# (the number of steps per epoch is given to fit()), and prefetch.
train_dataset = (
    train_dataset_from_gen
    .shuffle(SHUFFLE_N)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)
# Validation: same pipeline without shuffling.
val_dataset = (
    val_dataset_from_gen
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)
%%time
# Pull a single batch: x should be (BATCH_SIZE, 149, 181, 5)
# and y should be (BATCH_SIZE, 149, 181, 1)
for x, y in train_dataset.take(1):
    print(f"Train x shape: {x.shape}")
    print(f"Train y shape: {y.shape}")
Train x shape: (8, 149, 181, 5)
Train y shape: (8, 149, 181, 1)
CPU times: user 3.44 s, sys: 672 ms, total: 4.11 s
Wall time: 6.31 s

Let’s build the model#

We build a simple 3-layer CNN model. Each layer preserves the (lat, lon) shape and learns filters to extract spatial patterns.

from keras.models import Sequential
from keras.layers import Input, Conv2D, BatchNormalization, Dropout

def create_model_CNN(input_shape):
    """
    Build a small 3-layer convolutional network for gridded ocean data.

    All convolutions are 3x3 with 'same' padding, so the spatial size of the
    input is preserved and the output has a single channel per pixel.

    Parameters
    ----------
    input_shape : tuple
        The shape of each sample, e.g., (149, 181, 2)

    Returns
    -------
    model : keras.Model
        CNN model to predict CHL from SST and salinity
    """
    # Hidden layers 1 and 2: ReLU convolutions, each followed by batch
    # normalization and 20% dropout. The first learns 64 filters on 3x3
    # neighbourhoods; the second narrows to 32 filters and, stacked on the
    # first, sees a 5x5 neighbourhood.
    hidden_layers = []
    for n_filters in (64, 32):
        hidden_layers += [
            Conv2D(filters=n_filters, kernel_size=(3, 3), padding='same', activation='relu'),
            BatchNormalization(),
            Dropout(0.2),
        ]

    # Output layer (~7x7 neighbourhood overall): a single filter gives one CHL
    # value per pixel; linear activation because the target is a continuous
    # variable (log CHL).
    output_layer = Conv2D(filters=1, kernel_size=(3, 3), padding='same', activation='linear')

    # The Input layer fixes the sample dimensions the network expects.
    return Sequential([Input(shape=input_shape), *hidden_layers, output_layer])

Let’s train the model#

Because we are working in batches that are chunks of our training set, we only load a small piece into memory, process it, release the memory and move on to the next chunk. This is slow, but it ensures that we can work with larger-than-memory data.

Each batch holds 8 × (149 × 181 × n_features + 149 × 181 × 1) × 4 bytes (float32). With our 5 input features that is about 5 MB per batch; the size grows linearly with n_features.

We could do bigger (we have more memory) but this shows the proof of concept.

# Build the network for our per-sample input shape. Mean absolute error is
# both the training loss and the reported metric.
model = create_model_CNN(input_shape)
model.compile(loss='mae', optimizer='adam', metrics=['mae'])

# Early stopping to prevent overfitting: stop once the validation loss has not
# improved for 10 epochs and restore the weights of the best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# The datasets repeat forever, so Keras must be told how many batches make up
# one pass: (xbatcher batches x time steps per batch) // BATCH_SIZE.
TRAIN_STEPS_PER_EPOCH, VAL_STEPS_PER_EPOCH = (
    len(batch_gen) * input_dims["time"] // BATCH_SIZE
    for batch_gen in (train_gen, val_gen)
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,          # evaluated at the end of every epoch
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    validation_steps=VAL_STEPS_PER_EPOCH,
    epochs=50,                            # upper bound; early stopping may end sooner
    callbacks=[early_stop],
)
Epoch 1/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 21s 348ms/step - loss: 1.1818 - mae: 1.1818 - val_loss: 0.8262 - val_mae: 0.8262
Epoch 2/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 16s 352ms/step - loss: 0.6575 - mae: 0.6575 - val_loss: 0.6876 - val_mae: 0.6876
Epoch 3/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 16s 361ms/step - loss: 0.3969 - mae: 0.3969 - val_loss: 0.6155 - val_mae: 0.6155
Epoch 4/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 338ms/step - loss: 0.3402 - mae: 0.3402 - val_loss: 0.5467 - val_mae: 0.5467
Epoch 5/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 11s 254ms/step - loss: 0.3211 - mae: 0.3211 - val_loss: 0.4450 - val_mae: 0.4450
Epoch 6/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 289ms/step - loss: 0.3033 - mae: 0.3033 - val_loss: 0.3884 - val_mae: 0.3884
Epoch 7/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 232ms/step - loss: 0.2892 - mae: 0.2892 - val_loss: 0.3564 - val_mae: 0.3564
Epoch 8/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 232ms/step - loss: 0.2771 - mae: 0.2771 - val_loss: 0.3057 - val_mae: 0.3057
Epoch 9/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 233ms/step - loss: 0.2695 - mae: 0.2695 - val_loss: 0.2768 - val_mae: 0.2768
Epoch 10/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 207ms/step - loss: 0.2636 - mae: 0.2636 - val_loss: 0.2634 - val_mae: 0.2634
Epoch 11/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 234ms/step - loss: 0.2631 - mae: 0.2631 - val_loss: 0.2325 - val_mae: 0.2325
Epoch 12/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 227ms/step - loss: 0.2597 - mae: 0.2597 - val_loss: 0.2200 - val_mae: 0.2200
Epoch 13/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 230ms/step - loss: 0.2524 - mae: 0.2524 - val_loss: 0.2109 - val_mae: 0.2109
Epoch 14/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 234ms/step - loss: 0.2475 - mae: 0.2475 - val_loss: 0.2151 - val_mae: 0.2151
Epoch 15/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 218ms/step - loss: 0.2454 - mae: 0.2454 - val_loss: 0.2022 - val_mae: 0.2022
Epoch 16/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 234ms/step - loss: 0.2395 - mae: 0.2395 - val_loss: 0.2020 - val_mae: 0.2020
Epoch 17/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 238ms/step - loss: 0.2360 - mae: 0.2360 - val_loss: 0.2106 - val_mae: 0.2106
Epoch 18/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 208ms/step - loss: 0.2439 - mae: 0.2439 - val_loss: 0.2219 - val_mae: 0.2219
Epoch 19/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 224ms/step - loss: 0.2451 - mae: 0.2451 - val_loss: 0.2152 - val_mae: 0.2152
Epoch 20/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 224ms/step - loss: 0.2397 - mae: 0.2397 - val_loss: 0.2036 - val_mae: 0.2036
Epoch 21/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 235ms/step - loss: 0.2398 - mae: 0.2398 - val_loss: 0.2078 - val_mae: 0.2078
Epoch 22/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 207ms/step - loss: 0.2484 - mae: 0.2484 - val_loss: 0.2060 - val_mae: 0.2060
Epoch 23/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 231ms/step - loss: 0.2378 - mae: 0.2378 - val_loss: 0.2074 - val_mae: 0.2074
Epoch 24/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 227ms/step - loss: 0.2311 - mae: 0.2311 - val_loss: 0.2094 - val_mae: 0.2094
Epoch 25/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 228ms/step - loss: 0.2322 - mae: 0.2322 - val_loss: 0.2039 - val_mae: 0.2039
Epoch 26/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 210ms/step - loss: 0.2391 - mae: 0.2391 - val_loss: 0.2016 - val_mae: 0.2016
Epoch 27/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 237ms/step - loss: 0.2369 - mae: 0.2369 - val_loss: 0.2013 - val_mae: 0.2013
Epoch 28/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 234ms/step - loss: 0.2272 - mae: 0.2272 - val_loss: 0.1952 - val_mae: 0.1952
Epoch 29/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 225ms/step - loss: 0.2326 - mae: 0.2326 - val_loss: 0.2030 - val_mae: 0.2030
Epoch 30/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 208ms/step - loss: 0.2278 - mae: 0.2278 - val_loss: 0.1981 - val_mae: 0.1981
Epoch 31/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 226ms/step - loss: 0.2233 - mae: 0.2233 - val_loss: 0.2055 - val_mae: 0.2055
Epoch 32/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 236ms/step - loss: 0.2325 - mae: 0.2325 - val_loss: 0.2021 - val_mae: 0.2021
Epoch 33/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 219ms/step - loss: 0.2266 - mae: 0.2266 - val_loss: 0.1996 - val_mae: 0.1996
Epoch 34/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 233ms/step - loss: 0.2317 - mae: 0.2317 - val_loss: 0.1962 - val_mae: 0.1962
Epoch 35/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 228ms/step - loss: 0.2351 - mae: 0.2351 - val_loss: 0.2037 - val_mae: 0.2037
Epoch 36/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 228ms/step - loss: 0.2371 - mae: 0.2371 - val_loss: 0.1977 - val_mae: 0.1977
Epoch 37/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 15s 228ms/step - loss: 0.2247 - mae: 0.2247 - val_loss: 0.1974 - val_mae: 0.1974
Epoch 38/50
45/45 ━━━━━━━━━━━━━━━━━━━━ 14s 217ms/step - loss: 0.2233 - mae: 0.2233 - val_loss: 0.1988 - val_mae: 0.1988

Plot training & validation loss values#

# Training and validation loss per epoch, drawn on a single set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(history.history['loss'], label='Train Loss')
ax.plot(history.history['val_loss'], label='Validation Loss')
ax.set_title('Model Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend(loc='upper right')
ax.grid(True)
plt.show()
../_images/044c54e9591883232859468dbfbea1b446f37420e75bcdc71f6b92dd954beb7f.png

Plot all months#

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_true_vs_predicted(data, year, model, X_mean, X_std, num_var, cat_var):
    """
    Plot true vs predicted output for the first available day of each month of `year`.

    One row is drawn per month: the true field on the left and the model
    prediction on the right, sharing a colour scale (5th-95th percentile) and
    with land pixels blanked out. The R² over pixels that are valid in both
    fields is shown in the title of the prediction panel.

    Parameters:
        data (xarray.Dataset): Contains variables 'y', 'ocean_mask', the predictors,
            and coords 'lat', 'lon', 'time'. It is not modified.
        year (string in format XXXX): The year to use.
        model (tf.keras.Model): Trained model with a .predict() method
        X_mean (np.ndarray): mean from the model training data (num_var only)
        X_std (np.ndarray): std from the model training data
        num_var (list of str): The variables to be standardized with `X_mean` and `X_std`
        cat_var (list of str): The variables to be included in X, y but not standardized.

    Note:
        The order of the model input channels is taken from the module-level
        `input_vars`, which must match the order used for training.
    """
    # Function-scope import, done once instead of on every loop iteration.
    from sklearn.metrics import r2_score

    # Normalize and NaN-fill the requested year with the training statistics;
    # only the "full" dataset returned by the split function is needed.
    _, _, _, year_ds, _, _ = time_series_split_for_xbatcher(
        data=data.sel(time=year),
        num_var=num_var,
        cat_var=cat_var,
        X_mean=X_mean,
        X_std=X_std,
    )

    # First available date of every (year, month) group, in chronological order.
    available_dates = pd.to_datetime(year_ds.time.values)
    monthly_dates = (
        pd.Series(available_dates)
        .groupby([available_dates.year, available_dates.month])
        .min()
        .sort_values()
    )
    n_months = len(monthly_dates)

    # lat/lon info
    lat = year_ds.lat.values
    lon = year_ds.lon.values
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]
    flip_lat = lat[0] > lat[-1]  # rows run north -> south; flip for origin='lower'

    # A mask without a time dimension is identical for every date, so build it
    # once; a time-varying mask is looked up per date inside the loop.
    mask_has_time = "time" in year_ds["ocean_mask"].dims
    static_land_mask = None
    if not mask_has_time:
        static_land_mask = ~year_ds["ocean_mask"].values.astype(bool)

    # squeeze=False keeps axs two-dimensional even when only one month is available.
    fig, axs = plt.subplots(
        n_months, 2, figsize=(7, 2 * n_months), constrained_layout=True, squeeze=False
    )

    for i, date in enumerate(monthly_dates):
        timestamp = np.datetime64(date)
        ds_at_time = year_ds.sel(time=timestamp)

        # Prepare model input: stack input variables into (lat, lon, n_features)
        input_data = np.stack([
            ds_at_time[var].values for var in input_vars
        ], axis=-1)

        # Predict on a batch of one and drop the batch/channel axes: (lat, lon)
        predicted_output = model.predict(input_data[np.newaxis, ...])[0, ..., 0]

        # True output from the un-normalized data (NaN where CHL is missing).
        # Copy it: .values can be a view of the caller's array and it is
        # masked in place below.
        true_output = data["y"].sel(time=timestamp).values.copy()

        # Mask land
        if mask_has_time:
            land_mask = ~ds_at_time["ocean_mask"].values.astype(bool)
        else:
            land_mask = static_land_mask
        predicted_output[land_mask] = np.nan
        true_output[land_mask] = np.nan

        # Flip latitude if needed
        if flip_lat:
            true_output = np.flipud(true_output)
            predicted_output = np.flipud(predicted_output)

        # Shared color scale
        vmin = np.nanpercentile([true_output, predicted_output], 5)
        vmax = np.nanpercentile([true_output, predicted_output], 95)

        # R² over pixels valid in both fields; NaN when there are none
        # (r2_score raises on empty input).
        valid_mask = ~np.isnan(true_output) & ~np.isnan(predicted_output)
        if valid_mask.any():
            r2 = r2_score(true_output[valid_mask], predicted_output[valid_mask])
        else:
            r2 = np.nan

        # Plot true
        axs[i, 0].imshow(true_output, origin='lower', extent=extent,
                         vmin=vmin, vmax=vmax, cmap='viridis',
                         aspect='equal')
        axs[i, 0].set_title(f"{date.strftime('%b')} — True", fontsize=10)
        axs[i, 0].axis('off')

        # Plot predicted with R²
        axs[i, 1].imshow(predicted_output, origin='lower', extent=extent,
                         vmin=vmin, vmax=vmax, cmap='viridis',
                         aspect='equal')
        axs[i, 1].set_title(f"{date.strftime('%b')} — Pred\n$R^2$ = {r2:.2f}", fontsize=10)
        axs[i, 1].axis('off')

    plt.suptitle(f'CHL: True vs Predicted (log scale) — {year}', fontsize=16)
    plt.show()
# Compare truth and prediction for 2010, a year that was not used for training.
plot_true_vs_predicted(
    data=dataset,
    year="2010",
    model=model,
    X_mean=X_mean,
    X_std=X_std,
    num_var=num_var,
    cat_var=cat_var,
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 298ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
../_images/ed209238f08ea2f8d32b80785ce2432e8614d38bf93b43369fdb29c820a01a92.png

Comparing fits with metrics#

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_absolute_error
import calendar

def plot_metric_by_month(data, years, model, X_mean, X_std, num_var, cat_var,
                         training_year=None, metric='r2'):
    """
    Plot a selected evaluation metric (R², RMSE, MAE, or Bias) by month for each year.

    For every year the first available day of each month is predicted and
    scored against the true 'y' field over pixels that are ocean and not NaN
    in either field. One line per year is drawn.

    Parameters:
        data (xarray.Dataset): Contains 'y', 'ocean_mask', predictors, and coordinates
        years (list of str): Years to evaluate, e.g., ['2018', '2019', '2020']
        model (tf.keras.Model): Trained model with .predict() method
        X_mean, X_std (np.ndarray): Normalization stats for num_var
        num_var, cat_var (list of str): Variable names
        training_year (str or list/tuple/set of str, optional): Year(s) used for
            training; these are labelled "(train)" and drawn dashed
        metric (str): One of ['r2', 'rmse', 'mae', 'bias']

    Raises:
        ValueError: if `metric` is not one of the supported names.

    Note:
        The order of the model input channels is taken from the module-level
        `input_vars`, which must match the order used for training.
    """
    # Score functions take (y_true, y_pred); bias is mean(pred - true).
    scorers = {
        'r2': r2_score,
        'rmse': lambda y_true, y_pred: np.sqrt(np.mean((y_true - y_pred) ** 2)),
        'mae': mean_absolute_error,
        'bias': lambda y_true, y_pred: np.mean(y_pred - y_true),
    }
    # Explicit check instead of assert, which is stripped under `python -O`.
    if metric not in scorers:
        raise ValueError("Invalid metric. Choose from 'r2', 'rmse', 'mae', 'bias'.")
    score_fn = scorers[metric]

    # Accept a single year or a collection of years to highlight.
    if training_year is None:
        highlighted = set()
    elif isinstance(training_year, (list, tuple, set, frozenset)):
        highlighted = set(training_year)
    else:
        highlighted = {training_year}

    metric_by_year_month = {}

    for year in years:
        data_year = data.sel(time=year)
        dates = pd.to_datetime(data_year.time.values)

        # First available date of every (year, month) group, chronologically.
        monthly_dates = (
            pd.Series(dates)
            .groupby([dates.year, dates.month])
            .min()
            .sort_values()
        )

        # Normalized, NaN-filled copy of this year used as model input.
        _, _, _, year_ds, _, _ = time_series_split_for_xbatcher(
            data=data_year,
            num_var=num_var,
            cat_var=cat_var,
            X_mean=X_mean,
            X_std=X_std,
        )

        # A land mask without a time dimension is the same for every date:
        # load it once per year instead of once per month.
        mask_var = data_year["ocean_mask"]
        mask_has_time = "time" in mask_var.dims
        static_land = None if mask_has_time else (mask_var.values == 0.0)

        metric_scores = []
        for date in monthly_dates:
            true_output = data_year['y'].sel(time=date).values
            ds_at_time = year_ds.sel(time=np.datetime64(date))
            pred_input = np.stack([
                ds_at_time[var].values for var in input_vars
            ], axis=-1)

            # Batch of one in, (lat, lon) field out.
            pred_output = model.predict(pred_input[np.newaxis, ...], verbose=0)[0][:, :, 0]
            if mask_has_time:
                land = mask_var.sel(time=date).values == 0.0
            else:
                land = static_land
            pred_output[land] = np.nan

            # Boolean indexing already returns flat 1-D arrays.
            mask = ~np.isnan(true_output) & ~np.isnan(pred_output)
            y_true = true_output[mask]
            y_pred = pred_output[mask]

            # No valid pixel -> NaN instead of an exception from the scorer.
            score = score_fn(y_true, y_pred) if mask.any() else np.nan
            metric_scores.append(score)

        metric_by_year_month[year] = (monthly_dates.dt.month.values, metric_scores)

    # Plotting
    plt.figure(figsize=(10, 5))
    for year, (months, scores) in metric_by_year_month.items():
        is_train = year in highlighted
        label = f"{year} (train)" if is_train else year
        style = "--" if is_train else "-"
        plt.plot(months, scores, style, marker='o', label=label)

    plt.xlabel("Month")
    plt.ylabel({
        'r2': "$R^2$",
        'rmse': "RMSE",
        'mae': "MAE",
        'bias': "Bias"
    }[metric])
    plt.title(f"Monthly {metric.upper()} by Year")
    plt.xticks(np.arange(1, 13), calendar.month_abbr[1:13])
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
%%time
# Monthly R² (the default metric) for five years. Only 2020 is flagged as a
# training year here, although 2015 was also in train_yrs.
plot_metric_by_month(
    data=dataset,
    years=['2000', '2005', '2010', '2015', '2020'],
    model=model,
    X_mean=X_mean,
    X_std=X_std,
    num_var=num_var,
    cat_var=cat_var,
    training_year="2020",
)
../_images/96dc77ccaa8b5443db59e25aedac2b5fb7be4d30c5a07a6044c5e4c07ac899f6.png
%%time
# Same years as above, scored with the mean bias (prediction minus truth).
plot_metric_by_month(
    data=dataset,
    years=['2000', '2005', '2010', '2015', '2020'],
    model=model,
    X_mean=X_mean,
    X_std=X_std,
    num_var=num_var,
    cat_var=cat_var,
    training_year="2020",
    metric="bias",
)
../_images/1bd0626271681513e552d0ac3f685789309bfea914f342d977d7d37d121215c2.png
CPU times: user 17.2 s, sys: 2.22 s, total: 19.4 s
Wall time: 38.8 s

Summary#

That concludes the series on 2D CNNs for predicting chlorophyll in a region.

# Inspect which time stamps ended up in the (downsampled) training set.
train_ds["time"].values
array(['2015-01-01T00:00:00.000000000', '2015-01-03T00:00:00.000000000',
       '2015-01-05T00:00:00.000000000', '2015-01-07T00:00:00.000000000',
       '2015-01-08T00:00:00.000000000', '2015-01-09T00:00:00.000000000',
       '2015-01-10T00:00:00.000000000', '2015-01-11T00:00:00.000000000',
       '2015-01-12T00:00:00.000000000', '2015-01-14T00:00:00.000000000',
       '2015-01-15T00:00:00.000000000', '2015-01-16T00:00:00.000000000',
       '2015-01-17T00:00:00.000000000', '2015-01-18T00:00:00.000000000',
       '2015-01-19T00:00:00.000000000', '2015-01-21T00:00:00.000000000',
       '2015-01-23T00:00:00.000000000', '2015-01-28T00:00:00.000000000',
       '2015-01-30T00:00:00.000000000', '2015-02-01T00:00:00.000000000',
       '2015-02-02T00:00:00.000000000', '2015-02-07T00:00:00.000000000',
       '2015-02-08T00:00:00.000000000', '2015-02-09T00:00:00.000000000',
       '2015-02-10T00:00:00.000000000', '2015-02-11T00:00:00.000000000',
       '2015-02-13T00:00:00.000000000', '2015-02-15T00:00:00.000000000',
       '2015-02-19T00:00:00.000000000', '2015-02-24T00:00:00.000000000',
       '2015-02-25T00:00:00.000000000', '2015-02-26T00:00:00.000000000',
       '2015-02-28T00:00:00.000000000', '2015-03-01T00:00:00.000000000',
       '2015-03-05T00:00:00.000000000', '2015-03-06T00:00:00.000000000',
       '2015-03-07T00:00:00.000000000', '2015-03-08T00:00:00.000000000',
       '2015-03-10T00:00:00.000000000', '2015-03-13T00:00:00.000000000',
       '2015-03-14T00:00:00.000000000', '2015-03-16T00:00:00.000000000',
       '2015-03-17T00:00:00.000000000', '2015-03-20T00:00:00.000000000',
       '2015-03-21T00:00:00.000000000', '2015-03-24T00:00:00.000000000',
       '2015-03-25T00:00:00.000000000', '2015-03-29T00:00:00.000000000',
       '2015-03-30T00:00:00.000000000', '2015-03-31T00:00:00.000000000',
       '2015-04-01T00:00:00.000000000', '2015-04-04T00:00:00.000000000',
       '2015-04-06T00:00:00.000000000', '2015-04-07T00:00:00.000000000',
       '2015-04-09T00:00:00.000000000', '2015-04-10T00:00:00.000000000',
       '2015-04-14T00:00:00.000000000', '2015-04-15T00:00:00.000000000',
       '2015-04-16T00:00:00.000000000', '2015-04-17T00:00:00.000000000',
       '2015-04-19T00:00:00.000000000', '2015-04-20T00:00:00.000000000',
       '2015-04-22T00:00:00.000000000', '2015-04-23T00:00:00.000000000',
       '2015-04-26T00:00:00.000000000', '2015-04-27T00:00:00.000000000',
       '2015-04-29T00:00:00.000000000', '2015-05-02T00:00:00.000000000',
       '2015-05-09T00:00:00.000000000', '2015-05-10T00:00:00.000000000',
       '2015-05-11T00:00:00.000000000', '2015-05-13T00:00:00.000000000',
       '2015-05-14T00:00:00.000000000', '2015-05-16T00:00:00.000000000',
       '2015-05-17T00:00:00.000000000', '2015-05-18T00:00:00.000000000',
       '2015-05-19T00:00:00.000000000', '2015-05-20T00:00:00.000000000',
       '2015-05-22T00:00:00.000000000', '2015-05-24T00:00:00.000000000',
       '2015-05-25T00:00:00.000000000', '2015-05-30T00:00:00.000000000',
       '2015-06-01T00:00:00.000000000', '2015-06-02T00:00:00.000000000',
       '2015-06-03T00:00:00.000000000', '2015-06-04T00:00:00.000000000',
       '2015-06-06T00:00:00.000000000', '2015-06-10T00:00:00.000000000',
       '2015-06-11T00:00:00.000000000', '2015-06-12T00:00:00.000000000',
       '2015-06-14T00:00:00.000000000', '2015-06-18T00:00:00.000000000',
       '2015-06-19T00:00:00.000000000', '2015-06-20T00:00:00.000000000',
       '2015-06-22T00:00:00.000000000', '2015-06-23T00:00:00.000000000',
       '2015-06-25T00:00:00.000000000', '2015-07-02T00:00:00.000000000',
       '2015-07-04T00:00:00.000000000', '2015-07-07T00:00:00.000000000',
       '2015-07-09T00:00:00.000000000', '2015-07-11T00:00:00.000000000',
       '2015-07-14T00:00:00.000000000', '2015-07-16T00:00:00.000000000',
       '2015-07-18T00:00:00.000000000', '2015-07-21T00:00:00.000000000',
       '2015-07-22T00:00:00.000000000', '2015-07-24T00:00:00.000000000',
       '2015-07-27T00:00:00.000000000', '2015-07-30T00:00:00.000000000',
       '2015-07-31T00:00:00.000000000', '2015-08-01T00:00:00.000000000',
       '2015-08-03T00:00:00.000000000', '2015-08-04T00:00:00.000000000',
       '2015-08-06T00:00:00.000000000', '2015-08-08T00:00:00.000000000',
       '2015-08-11T00:00:00.000000000', '2015-08-13T00:00:00.000000000',
       '2015-08-15T00:00:00.000000000', '2015-08-16T00:00:00.000000000',
       '2015-08-17T00:00:00.000000000', '2015-08-19T00:00:00.000000000',
       '2015-08-22T00:00:00.000000000', '2015-08-25T00:00:00.000000000',
       '2015-08-26T00:00:00.000000000', '2015-08-29T00:00:00.000000000',
       '2015-08-31T00:00:00.000000000', '2015-09-02T00:00:00.000000000',
       '2015-09-03T00:00:00.000000000', '2015-09-04T00:00:00.000000000',
       '2015-09-10T00:00:00.000000000', '2015-09-11T00:00:00.000000000',
       '2015-09-13T00:00:00.000000000', '2015-09-14T00:00:00.000000000',
       '2015-09-16T00:00:00.000000000', '2015-09-18T00:00:00.000000000',
       '2015-09-20T00:00:00.000000000', '2015-09-21T00:00:00.000000000',
       '2015-09-22T00:00:00.000000000', '2015-09-23T00:00:00.000000000',
       '2015-09-24T00:00:00.000000000', '2015-09-25T00:00:00.000000000',
       '2015-10-01T00:00:00.000000000', '2015-10-04T00:00:00.000000000',
       '2015-10-06T00:00:00.000000000', '2015-10-10T00:00:00.000000000',
       '2015-10-12T00:00:00.000000000', '2015-10-14T00:00:00.000000000',
       '2015-10-17T00:00:00.000000000', '2015-10-18T00:00:00.000000000',
       '2015-10-19T00:00:00.000000000', '2015-10-21T00:00:00.000000000',
       '2015-10-23T00:00:00.000000000', '2015-10-27T00:00:00.000000000',
       '2015-10-29T00:00:00.000000000', '2015-11-01T00:00:00.000000000',
       '2015-11-02T00:00:00.000000000', '2015-11-03T00:00:00.000000000',
       '2015-11-04T00:00:00.000000000', '2015-11-08T00:00:00.000000000',
       '2015-11-09T00:00:00.000000000', '2015-11-11T00:00:00.000000000',
       '2015-11-12T00:00:00.000000000', '2015-11-13T00:00:00.000000000',
       '2015-11-19T00:00:00.000000000', '2015-11-20T00:00:00.000000000',
       '2015-11-24T00:00:00.000000000', '2015-11-25T00:00:00.000000000',
       '2015-11-28T00:00:00.000000000', '2015-11-30T00:00:00.000000000',
       '2015-12-02T00:00:00.000000000', '2015-12-03T00:00:00.000000000',
       '2015-12-04T00:00:00.000000000', '2015-12-05T00:00:00.000000000',
       '2015-12-06T00:00:00.000000000', '2015-12-08T00:00:00.000000000',
       '2015-12-10T00:00:00.000000000', '2015-12-13T00:00:00.000000000',
       '2015-12-14T00:00:00.000000000', '2015-12-18T00:00:00.000000000',
       '2015-12-23T00:00:00.000000000', '2015-12-25T00:00:00.000000000',
       '2015-12-27T00:00:00.000000000', '2015-12-28T00:00:00.000000000',
       '2020-01-01T00:00:00.000000000', '2020-01-03T00:00:00.000000000',
       '2020-01-04T00:00:00.000000000', '2020-01-08T00:00:00.000000000',
       '2020-01-11T00:00:00.000000000', '2020-01-13T00:00:00.000000000',
       '2020-01-15T00:00:00.000000000', '2020-01-17T00:00:00.000000000',
       '2020-01-21T00:00:00.000000000', '2020-01-22T00:00:00.000000000',
       '2020-01-24T00:00:00.000000000', '2020-01-26T00:00:00.000000000',
       '2020-01-27T00:00:00.000000000', '2020-01-28T00:00:00.000000000',
       '2020-01-29T00:00:00.000000000', '2020-01-30T00:00:00.000000000',
       '2020-02-03T00:00:00.000000000', '2020-02-05T00:00:00.000000000',
       '2020-02-06T00:00:00.000000000', '2020-02-07T00:00:00.000000000',
       '2020-02-08T00:00:00.000000000', '2020-02-09T00:00:00.000000000',
       '2020-02-10T00:00:00.000000000', '2020-02-12T00:00:00.000000000',
       '2020-02-14T00:00:00.000000000', '2020-02-15T00:00:00.000000000',
       '2020-02-19T00:00:00.000000000', '2020-02-20T00:00:00.000000000',
       '2020-02-21T00:00:00.000000000', '2020-02-23T00:00:00.000000000',
       '2020-02-25T00:00:00.000000000', '2020-03-02T00:00:00.000000000',
       '2020-03-06T00:00:00.000000000', '2020-03-07T00:00:00.000000000',
       '2020-03-08T00:00:00.000000000', '2020-03-15T00:00:00.000000000',
       '2020-03-16T00:00:00.000000000', '2020-03-22T00:00:00.000000000',
       '2020-03-24T00:00:00.000000000', '2020-03-28T00:00:00.000000000',
       '2020-03-30T00:00:00.000000000', '2020-03-31T00:00:00.000000000',
       '2020-04-02T00:00:00.000000000', '2020-04-04T00:00:00.000000000',
       '2020-04-05T00:00:00.000000000', '2020-04-06T00:00:00.000000000',
       '2020-04-07T00:00:00.000000000', '2020-04-08T00:00:00.000000000',
       '2020-04-09T00:00:00.000000000', '2020-04-10T00:00:00.000000000',
       '2020-04-13T00:00:00.000000000', '2020-04-14T00:00:00.000000000',
       '2020-04-16T00:00:00.000000000', '2020-04-19T00:00:00.000000000',
       '2020-04-20T00:00:00.000000000', '2020-04-22T00:00:00.000000000',
       '2020-04-24T00:00:00.000000000', '2020-04-26T00:00:00.000000000',
       '2020-04-27T00:00:00.000000000', '2020-04-29T00:00:00.000000000',
       '2020-04-30T00:00:00.000000000', '2020-05-03T00:00:00.000000000',
       '2020-05-07T00:00:00.000000000', '2020-05-09T00:00:00.000000000',
       '2020-05-11T00:00:00.000000000', '2020-05-12T00:00:00.000000000',
       '2020-05-15T00:00:00.000000000', '2020-05-17T00:00:00.000000000',
       '2020-05-18T00:00:00.000000000', '2020-05-19T00:00:00.000000000',
       '2020-05-20T00:00:00.000000000', '2020-05-25T00:00:00.000000000',
       '2020-05-27T00:00:00.000000000', '2020-06-03T00:00:00.000000000',
       '2020-06-05T00:00:00.000000000', '2020-06-06T00:00:00.000000000',
       '2020-06-07T00:00:00.000000000', '2020-06-08T00:00:00.000000000',
       '2020-06-09T00:00:00.000000000', '2020-06-10T00:00:00.000000000',
       '2020-06-11T00:00:00.000000000', '2020-06-12T00:00:00.000000000',
       '2020-06-13T00:00:00.000000000', '2020-06-16T00:00:00.000000000',
       '2020-06-17T00:00:00.000000000', '2020-06-19T00:00:00.000000000',
       '2020-06-20T00:00:00.000000000', '2020-06-22T00:00:00.000000000',
       '2020-06-24T00:00:00.000000000', '2020-06-26T00:00:00.000000000',
       '2020-06-27T00:00:00.000000000', '2020-06-28T00:00:00.000000000',
       '2020-06-30T00:00:00.000000000', '2020-07-02T00:00:00.000000000',
       '2020-07-05T00:00:00.000000000', '2020-07-06T00:00:00.000000000',
       '2020-07-07T00:00:00.000000000', '2020-07-10T00:00:00.000000000',
       '2020-07-12T00:00:00.000000000', '2020-07-14T00:00:00.000000000',
       '2020-07-16T00:00:00.000000000', '2020-07-18T00:00:00.000000000',
       '2020-07-19T00:00:00.000000000', '2020-07-20T00:00:00.000000000',
       '2020-07-22T00:00:00.000000000', '2020-07-23T00:00:00.000000000',
       '2020-07-26T00:00:00.000000000', '2020-07-30T00:00:00.000000000',
       '2020-07-31T00:00:00.000000000', '2020-08-02T00:00:00.000000000',
       '2020-08-05T00:00:00.000000000', '2020-08-06T00:00:00.000000000',
       '2020-08-07T00:00:00.000000000', '2020-08-11T00:00:00.000000000',
       '2020-08-12T00:00:00.000000000', '2020-08-13T00:00:00.000000000',
       '2020-08-14T00:00:00.000000000', '2020-08-21T00:00:00.000000000',
       '2020-08-22T00:00:00.000000000', '2020-08-24T00:00:00.000000000',
       '2020-08-25T00:00:00.000000000', '2020-08-26T00:00:00.000000000',
       '2020-08-28T00:00:00.000000000', '2020-08-29T00:00:00.000000000',
       '2020-08-30T00:00:00.000000000', '2020-08-31T00:00:00.000000000',
       '2020-09-01T00:00:00.000000000', '2020-09-03T00:00:00.000000000',
       '2020-09-06T00:00:00.000000000', '2020-09-08T00:00:00.000000000',
       '2020-09-09T00:00:00.000000000', '2020-09-10T00:00:00.000000000',
       '2020-09-12T00:00:00.000000000', '2020-09-13T00:00:00.000000000',
       '2020-09-20T00:00:00.000000000', '2020-09-26T00:00:00.000000000',
       '2020-09-29T00:00:00.000000000', '2020-09-30T00:00:00.000000000',
       '2020-10-01T00:00:00.000000000', '2020-10-02T00:00:00.000000000',
       '2020-10-03T00:00:00.000000000', '2020-10-04T00:00:00.000000000',
       '2020-10-05T00:00:00.000000000', '2020-10-06T00:00:00.000000000',
       '2020-10-07T00:00:00.000000000', '2020-10-14T00:00:00.000000000',
       '2020-10-16T00:00:00.000000000', '2020-10-19T00:00:00.000000000',
       '2020-10-20T00:00:00.000000000', '2020-10-21T00:00:00.000000000',
       '2020-10-22T00:00:00.000000000', '2020-10-23T00:00:00.000000000',
       '2020-10-27T00:00:00.000000000', '2020-10-28T00:00:00.000000000',
       '2020-11-01T00:00:00.000000000', '2020-11-02T00:00:00.000000000',
       '2020-11-04T00:00:00.000000000', '2020-11-06T00:00:00.000000000',
       '2020-11-09T00:00:00.000000000', '2020-11-12T00:00:00.000000000',
       '2020-11-13T00:00:00.000000000', '2020-11-15T00:00:00.000000000',
       '2020-11-16T00:00:00.000000000', '2020-11-17T00:00:00.000000000',
       '2020-11-24T00:00:00.000000000', '2020-11-26T00:00:00.000000000',
       '2020-11-28T00:00:00.000000000', '2020-11-30T00:00:00.000000000',
       '2020-12-03T00:00:00.000000000', '2020-12-04T00:00:00.000000000',
       '2020-12-06T00:00:00.000000000', '2020-12-08T00:00:00.000000000',
       '2020-12-11T00:00:00.000000000', '2020-12-12T00:00:00.000000000',
       '2020-12-13T00:00:00.000000000', '2020-12-15T00:00:00.000000000',
       '2020-12-18T00:00:00.000000000', '2020-12-26T00:00:00.000000000',
       '2020-12-28T00:00:00.000000000', '2020-12-30T00:00:00.000000000',
       '2020-12-31T00:00:00.000000000'], dtype='datetime64[ns]')