Gap-filling with ConvLSTM

Authors: Yifei Hang (UW Varanasi intern 2024), Jiarui Yu (UW Varanasi intern 2023)

This notebook shows how to fit a basic ConvLSTM model to fill gaps in chlorophyll-a (CHL) data. Although you can run this tutorial on a CPU, it will be much faster on a GPU. We used the image quay.io/pangeo/ml-notebook:2024.08.18 to run the notebook.

import xarray as xr
import numpy as np

import dask.array as da

import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ConvLSTM2D, BatchNormalization, Conv2D, Dropout
# list all the physical devices
physical_devices = tf.config.list_physical_devices()
print("All Physical Devices:", physical_devices)

# list all the available GPUs
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)

# Print details for each available GPU, if any exist
if gpus:
    for gpu in gpus:
        details = tf.config.experimental.get_device_details(gpu)
        print("GPU Details:", details)
else:
    print("No GPU available")
All Physical Devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU Details: {'device_name': 'NVIDIA GeForce RTX 4070 Laptop GPU', 'compute_capability': (8, 9)}
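If the GPU is shared with other processes, you may want TensorFlow to allocate memory on demand rather than reserving it all up front. A minimal, optional sketch (it has to run before the GPU is first used for computation):

# Optional: let TensorFlow grow GPU memory allocation as needed
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)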
zarr_ds = xr.open_zarr(store="./参考资料/2023_INTERN_MATERIAL/INDIAN_OCEAN_025GRID_DAILY.zarr", consolidated=True)  # open the daily 0.25-degree Indian Ocean dataset

zarr_sliced = zarr_ds.sel(lat=slice(35, -5), lon=slice(45, 90))  # subset the latitude/longitude region of interest

all_nan_CHL = np.isnan(zarr_sliced.CHL).all(dim=["lon", "lat"]).compute()  # flag time steps where the entire CHL field is NaN

zarr_CHL = zarr_sliced.sel(time=~all_nan_CHL)  # keep only time steps with at least some CHL data

zarr_CHL = zarr_CHL.sortby('time')
zarr_CHL = zarr_CHL.sel(time=slice('2020-01-01', '2020-12-31'))
zarr_CHL
<xarray.Dataset> Size: 711MB
Dimensions:          (time: 366, lat: 149, lon: 181)
Coordinates:
  * lat              (lat) float32 596B 32.0 31.75 31.5 ... -4.5 -4.75 -5.0
  * lon              (lon) float32 724B 45.0 45.25 45.5 ... 89.5 89.75 90.0
  * time             (time) datetime64[ns] 3kB 2020-01-01 ... 2020-12-31
Data variables: (12/19)
    CHL              (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    CHL_uncertainty  (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    adt              (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    air_temp         (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    curr_dir         (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    curr_speed       (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    ...               ...
    ug_curr          (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    v_curr           (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    v_wind           (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    vg_curr          (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    wind_dir         (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
    wind_speed       (time, lat, lon) float32 39MB dask.array<chunksize=(25, 149, 181), meta=np.ndarray>
Attributes: (12/17)
    creator_email:              minhphan@uw.edu
    creator_name:               Minh Phan
    creator_type:               person
    date_created:               2023-07-19
    geospatial_lat_max:         32.0
    geospatial_lat_min:         -12.0
    ...                         ...
    geospatial_lon_units:       degrees_east
    source:                     Earth & Space Research (ESR), Copernicus Clim...
    summary:                    Daily mean of 0.25 x 0.25 degrees gridded dat...
    time_coverage_end:          2022-12-31T23:59:59
    time_coverage_start:        1979-01-01T00:00:00
    title:                      Climate Data for Coastal Upwelling Machine Le...
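Before modeling, it helps to see how gappy the CHL field actually is. A quick sketch (the NaN fraction below also counts permanent land pixels, not only cloud gaps):

# Fraction of NaN CHL pixels per day over the selected region
nan_frac = np.isnan(zarr_CHL.CHL).mean(dim=["lat", "lon"]).compute()
nan_frac.plot()
plt.ylabel("Fraction of missing CHL pixels")
plt.show()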
p = zarr_CHL.sel(time='2020-09-02').CHL.plot(y='lat', x='lon')
../_images/f560746a9162ef6b140bf33d77ab31e367fb3f6446c3c8e7ade0df0a18f0b052.png
np.log(zarr_CHL.sel(time='2020-12-02').CHL).plot(y='lat', x='lon')
<matplotlib.collections.QuadMesh at 0x23ca9d2fa90>
../_images/4242534cf755e93367e069f77c2d81295ecee90785330193f34a6987ac1061e6.png
def log_label(data, label):
    # log-transform the label variable (the notebook works with log-CHL throughout)
    data_logged = data.copy()
    data_logged[label] = np.log(data[label])
    return data_logged

# Add more preprocessing later
def preprocess_data(data, features, label):
    # log-transform CHL
    data_logged = log_label(data, label)

    # collect all features and the label in one list
    sel_data_list = []
    for var in (features + [label]):
        sel_var_data = data_logged[var]
        # replace NaNs (clouds, land, missing swaths) with 0.0 so the network gets finite inputs
        sel_var_data = da.where(da.isnan(sel_var_data), 0.0, sel_var_data)
        sel_data_list.append(sel_var_data)

    # stack into a single dask array of shape (var, time, lat, lon)
    sel_data_da = da.array(sel_data_list)
    return sel_data_da


# Build sliding windows: X[i] holds days i .. i+window_size-1 for all variables,
# and y[i] is the (logged) CHL field on day i+window_size
def data_prep(data, window_size=5):
    X_orig = data
    y_orig = data[-1]  # last variable is the label (CHL)

    # (var, time, lat, lon) -> (time, lat, lon, var)
    X_orig = np.moveaxis(X_orig, 0, -1)
    length = X_orig.shape[0]
    X, y = [], []

    for i in range(length - window_size):
        X.append(X_orig[i: i + window_size])
        y.append(y_orig[i + window_size])

    X, y = da.array(X), da.array(y)

    return X, y


def time_series_split(X, y, split_ratio):
    total_length = X.shape[0]
    
    train_end = int(total_length * split_ratio[0])
    val_end = int(total_length * (split_ratio[0] + split_ratio[1]))

    X_train, y_train = X[:train_end], y[:train_end]
    X_val, y_val = X[train_end: val_end], y[train_end: val_end]
    X_test, y_test = X[val_end:], y[val_end:]

    return (X_train, y_train,
            X_val, y_val,
            X_test, y_test)
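The "Add more preprocessing later" note above could, for example, include per-variable standardization, since sea surface temperature (hundreds of Kelvin) and log-CHL live on very different scales. A sketch of a helper that could be slotted into preprocess_data before the NaN-filling step (it is not wired in below); in a fuller version the statistics would come from the training period only to avoid leakage:

# Sketch: standardize one variable, ignoring NaNs (xarray reductions skip NaNs by default)
def standardize_var(var_data):
    return (var_data - var_data.mean()) / var_data.std()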
    
features = ['sst', 'so']  # Features: sea surface temperature (K) and sea water salinity (PSU). [Excluding topography/bathymetry (m)]
label = 'CHL'  # Label: chlorophyll-a concentration (mg/m**3). [Not using CHL_uncertainty for now]
model_data = preprocess_data(zarr_CHL, features, label)

window_size = 5
X, y = data_prep(model_data, window_size=window_size)

split_ratio = [.7, .2, .1]
X_train, y_train, X_val, y_val, X_test, y_test = time_series_split(X, y, split_ratio)
y_train.shape
(252, 149, 181)
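A quick sanity check on the split (optional): with 366 days and a window of 5 there are 361 samples, and the 70/20/10 split should give roughly 252/72/37 of them.

# Optional: print the shape of each split before building the tf.data pipelines
for name, arr in [("X_train", X_train), ("y_train", y_train),
                  ("X_val", X_val), ("y_val", y_val),
                  ("X_test", X_test), ("y_test", y_test)]:
    print(name, arr.shape)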
def create_model_ConvLSTM(input_shape=(5, 149, 181, 3)):
    model = Sequential()
    
    model.add(ConvLSTM2D(filters=32, kernel_size=(3, 3),
                         input_shape=input_shape,
                         padding='same', return_sequences=False))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    model.add(Conv2D(filters=64, 
                     kernel_size=(3, 3),  
                     padding='same',
                     activation='relu'
                     ))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    model.add(Conv2D(filters=32, 
                     kernel_size=(3, 3), 
                     padding='same',
                     activation='relu'
                     ))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    model.add(Conv2D(filters=1, 
                     kernel_size=(3, 3), 
                     padding='same', 
                     activation='linear'
                     ))
    
    return model

model = create_model_ConvLSTM()
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv_lstm2d (ConvLSTM2D)    (None, 149, 181, 32)      40448     
                                                                 
 batch_normalization (BatchN  (None, 149, 181, 32)     128       
 ormalization)                                                   
                                                                 
 dropout (Dropout)           (None, 149, 181, 32)      0         
                                                                 
 conv2d (Conv2D)             (None, 149, 181, 64)      18496     
                                                                 
 batch_normalization_1 (Batc  (None, 149, 181, 64)     256       
 hNormalization)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 149, 181, 64)      0         
                                                                 
 conv2d_1 (Conv2D)           (None, 149, 181, 32)      18464     
                                                                 
 batch_normalization_2 (Batc  (None, 149, 181, 32)     128       
 hNormalization)                                                 
                                                                 
 dropout_2 (Dropout)         (None, 149, 181, 32)      0         
                                                                 
 conv2d_2 (Conv2D)           (None, 149, 181, 1)       289       
                                                                 
=================================================================
Total params: 78,209
Trainable params: 77,953
Non-trainable params: 256
_________________________________________________________________
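Note that the splits are still lazy dask arrays at this point; from_tensor_slices below materializes them as in-memory tensors. If you prefer to make that step explicit (and easier to profile), a sketch using hypothetical *_np names (the cells below keep using the original dask arrays):

# Materialize the training and validation splits explicitly (illustrative only)
X_train_np, y_train_np = X_train.compute(), y_train.compute()
X_val_np, y_val_np = X_val.compute(), y_val.compute()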
model.compile(optimizer='adam', loss='mae', metrics=['mae'])

early_stop = EarlyStopping(patience=10, restore_best_weights=True)

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(8)

val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.batch(8)

history = model.fit(train_dataset, epochs=50, validation_data=val_dataset, callbacks=[early_stop])
Epoch 1/50
32/32 [==============================] - 13s 260ms/step - loss: 1.0611 - mae: 1.0611 - val_loss: 0.7293 - val_mae: 0.7293
Epoch 2/50
32/32 [==============================] - 8s 236ms/step - loss: 0.5555 - mae: 0.5555 - val_loss: 0.6904 - val_mae: 0.6904
Epoch 3/50
32/32 [==============================] - 8s 237ms/step - loss: 0.3589 - mae: 0.3589 - val_loss: 0.5267 - val_mae: 0.5267
Epoch 4/50
32/32 [==============================] - 8s 235ms/step - loss: 0.3039 - mae: 0.3039 - val_loss: 0.4278 - val_mae: 0.4278
Epoch 5/50
32/32 [==============================] - 8s 244ms/step - loss: 0.2711 - mae: 0.2711 - val_loss: 0.3880 - val_mae: 0.3880
Epoch 6/50
32/32 [==============================] - 8s 236ms/step - loss: 0.2492 - mae: 0.2492 - val_loss: 0.3104 - val_mae: 0.3104
Epoch 7/50
32/32 [==============================] - 8s 239ms/step - loss: 0.2358 - mae: 0.2358 - val_loss: 0.2944 - val_mae: 0.2944
Epoch 8/50
32/32 [==============================] - 8s 235ms/step - loss: 0.2261 - mae: 0.2261 - val_loss: 0.2752 - val_mae: 0.2752
Epoch 9/50
32/32 [==============================] - 8s 234ms/step - loss: 0.2110 - mae: 0.2110 - val_loss: 0.2692 - val_mae: 0.2692
Epoch 10/50
32/32 [==============================] - 8s 235ms/step - loss: 0.2070 - mae: 0.2070 - val_loss: 0.2399 - val_mae: 0.2399
Epoch 11/50
32/32 [==============================] - 8s 235ms/step - loss: 0.2002 - mae: 0.2002 - val_loss: 0.2560 - val_mae: 0.2560
Epoch 12/50
32/32 [==============================] - 8s 240ms/step - loss: 0.2038 - mae: 0.2038 - val_loss: 0.2250 - val_mae: 0.2250
Epoch 13/50
32/32 [==============================] - 8s 237ms/step - loss: 0.2039 - mae: 0.2039 - val_loss: 0.4885 - val_mae: 0.4885
Epoch 14/50
32/32 [==============================] - 8s 242ms/step - loss: 0.2034 - mae: 0.2034 - val_loss: 0.1741 - val_mae: 0.1741
Epoch 15/50
32/32 [==============================] - 8s 240ms/step - loss: 0.1953 - mae: 0.1953 - val_loss: 0.2007 - val_mae: 0.2007
Epoch 16/50
32/32 [==============================] - 8s 246ms/step - loss: 0.1901 - mae: 0.1901 - val_loss: 0.1622 - val_mae: 0.1622
Epoch 17/50
32/32 [==============================] - 8s 240ms/step - loss: 0.1837 - mae: 0.1837 - val_loss: 0.1780 - val_mae: 0.1780
Epoch 18/50
32/32 [==============================] - 8s 241ms/step - loss: 0.1822 - mae: 0.1822 - val_loss: 0.3413 - val_mae: 0.3413
Epoch 19/50
32/32 [==============================] - 8s 242ms/step - loss: 0.1771 - mae: 0.1771 - val_loss: 0.1499 - val_mae: 0.1499
Epoch 20/50
32/32 [==============================] - 8s 243ms/step - loss: 0.1877 - mae: 0.1877 - val_loss: 0.4606 - val_mae: 0.4606
Epoch 21/50
32/32 [==============================] - 8s 242ms/step - loss: 0.1908 - mae: 0.1908 - val_loss: 0.1608 - val_mae: 0.1608
Epoch 22/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1743 - mae: 0.1743 - val_loss: 0.2386 - val_mae: 0.2386
Epoch 23/50
32/32 [==============================] - 8s 238ms/step - loss: 0.1712 - mae: 0.1712 - val_loss: 0.1379 - val_mae: 0.1379
Epoch 24/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1735 - mae: 0.1735 - val_loss: 0.2289 - val_mae: 0.2289
Epoch 25/50
32/32 [==============================] - 8s 235ms/step - loss: 0.1718 - mae: 0.1718 - val_loss: 0.1371 - val_mae: 0.1371
Epoch 26/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1714 - mae: 0.1714 - val_loss: 0.1574 - val_mae: 0.1574
Epoch 27/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1760 - mae: 0.1760 - val_loss: 0.3918 - val_mae: 0.3918
Epoch 28/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1724 - mae: 0.1724 - val_loss: 0.2902 - val_mae: 0.2902
Epoch 29/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1796 - mae: 0.1796 - val_loss: 0.1230 - val_mae: 0.1230
Epoch 30/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1686 - mae: 0.1686 - val_loss: 0.2143 - val_mae: 0.2143
Epoch 31/50
32/32 [==============================] - 8s 238ms/step - loss: 0.1635 - mae: 0.1635 - val_loss: 0.1600 - val_mae: 0.1600
Epoch 32/50
32/32 [==============================] - 8s 234ms/step - loss: 0.1619 - mae: 0.1619 - val_loss: 0.1123 - val_mae: 0.1123
Epoch 33/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1618 - mae: 0.1618 - val_loss: 0.1696 - val_mae: 0.1696
Epoch 34/50
32/32 [==============================] - 8s 238ms/step - loss: 0.1656 - mae: 0.1656 - val_loss: 0.2791 - val_mae: 0.2791
Epoch 35/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1713 - mae: 0.1713 - val_loss: 0.1272 - val_mae: 0.1272
Epoch 36/50
32/32 [==============================] - 8s 241ms/step - loss: 0.1715 - mae: 0.1715 - val_loss: 0.1695 - val_mae: 0.1695
Epoch 37/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1664 - mae: 0.1664 - val_loss: 0.2333 - val_mae: 0.2333
Epoch 38/50
32/32 [==============================] - 8s 237ms/step - loss: 0.1717 - mae: 0.1717 - val_loss: 0.2303 - val_mae: 0.2303
Epoch 39/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1606 - mae: 0.1606 - val_loss: 0.1642 - val_mae: 0.1642
Epoch 40/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1570 - mae: 0.1570 - val_loss: 0.1952 - val_mae: 0.1952
Epoch 41/50
32/32 [==============================] - 8s 236ms/step - loss: 0.1646 - mae: 0.1646 - val_loss: 0.1192 - val_mae: 0.1192
Epoch 42/50
32/32 [==============================] - 8s 238ms/step - loss: 0.1621 - mae: 0.1621 - val_loss: 0.1814 - val_mae: 0.1814
# Plot training & validation loss values
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.grid(True)
plt.show()

# Plot training & validation MAE values
plt.figure(figsize=(10, 6))
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Validation MAE')
plt.title('Model Mean Absolute Error (MAE)')
plt.xlabel('Epoch')
plt.ylabel('MAE')
plt.legend(loc='upper right')
plt.grid(True)
plt.show()
../_images/ab2131dca0b7f65bf4de0d71f37d5fc24a9b5e19fd45c492c6d814ba70dcf2e7.png ../_images/7476b555b283bcef0d23e2e28cf10ffa6df23e620118f7d8023baaab09882aa7.png
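Because EarlyStopping was created with restore_best_weights=True, the model now holds the weights from the best validation epoch. A small sketch to report which epoch that was:

# Report the epoch with the lowest validation loss (1-indexed)
best_epoch = int(np.argmin(history.history['val_loss'])) + 1
print(f"Best epoch: {best_epoch}, best val MAE: {min(history.history['val_loss']):.4f}")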
# Prepare test dataset
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(4)

# Evaluate the model on the test dataset
test_loss, test_mae = model.evaluate(test_dataset)
print(f"Test Loss: {test_loss}")
print(f"Test MAE: {test_mae}")
10/10 [==============================] - 1s 56ms/step - loss: 0.1045 - mae: 0.1045
Test Loss: 0.10454108566045761
Test MAE: 0.1045411005616188
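If you want to reuse the trained network later, for example to fill gaps in other years, you can save it. A minimal sketch with a hypothetical filename:

# Save the trained model (filename is just an example)
model.save("convlstm_chl_gapfill.keras")
# Reload later with:
# model = tf.keras.models.load_model("convlstm_chl_gapfill.keras")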
X
              Array                    Chunk
Bytes         557.09 MiB               105.35 kiB
Shape         (361, 5, 149, 181, 3)    (1, 1, 149, 181, 1)
Dask graph    5415 chunks in 751 graph layers
Data type     float32 numpy.ndarray
date_to_predict = '2020-12-02'
true_output = np.log(zarr_CHL.sel(time=date_to_predict).CHL)

# days since the start of the 2020 slice (the 2020 data is daily with no missing time steps)
date_index = (np.datetime64(date_to_predict) - np.datetime64('2020-01-01')).item().days
input = X[date_index - window_size]  # window of the window_size days preceding date_to_predict
input = np.array(input)
predicted_output = model.predict(input[np.newaxis, ...])[0]

predicted_output = predicted_output[:, :, 0]  # drop the trailing channel dimension

# blank out land pixels in the prediction using a precomputed land mask (local file)
land_mask = np.load(r"E:\24SU Varanasi Intern\参考资料\2023_Intern_Material\land_mask_nc.npy")
predicted_output[land_mask] = np.nan

# true_output_2 = (model_data[-1])[date_index]

vmax = np.nanmax((true_output, predicted_output))
vmin = np.nanmin((true_output, predicted_output))

plt.imshow(true_output, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.title(f'True CHL on {date_to_predict}')
plt.show()

plt.imshow(predicted_output, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.title(f'Predicted CHL on {date_to_predict}')
plt.show()
1/1 [==============================] - 0s 10ms/step
../_images/ad0dc6d6c9dc03d9565c2c29aa41c15a47f8506dde6a3e6da27575d425022898.png ../_images/79b69bc20b481f28f273ad97760e1fe47371cf5f9db898aa15447d0a79fcc45d.png
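The prediction above covers the whole frame, but for gap-filling we only need it where the observation is missing. A sketch of that final step, keeping the observed log-CHL where it exists and inserting the model prediction only in the gaps (land stays NaN because the prediction was masked above):

# Fill gaps: keep observed log-CHL where available, use the prediction elsewhere
filled_output = np.where(np.isnan(true_output), predicted_output, np.array(true_output))

plt.imshow(filled_output, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.title(f'Gap-filled CHL on {date_to_predict}')
plt.show()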
def compute_mae(y_true, y_pred):
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    return np.mean(np.abs(y_true[mask] - y_pred[mask]))


predicted_mae = compute_mae(np.array(true_output), predicted_output)
print(f"MAE between Predicted Output and True Output: {predicted_mae}")

# Baseline: persistence (use the previous day's logged CHL field as the prediction)
prev_day_dt64 = np.datetime64(date_to_predict) - np.timedelta64(1, 'D')
prev_day = np.datetime_as_string(prev_day_dt64, unit='D')
prev_day_CHL = np.log(zarr_CHL.sel(time=prev_day).CHL)
last_frame_mae = compute_mae(np.array(prev_day_CHL), np.array(true_output))
print(f"MAE between Last Input Frame and True Output: {last_frame_mae}")
MAE between Predicted Output and True Output: 0.16125544905662537
MAE between Last Input Frame and True Output: 0.13173514604568481
input.shape
input[-1,:,:,-1].shape
(149, 181)
# t = numpy.datetime64('2012-06-30T20:00:00.000000000-0400')
# numpy.datetime_as_string(t, unit='D')
last_day = np.datetime64(date_to_predict) - np.timedelta64(1, 'D')
last_day_str = np.datetime_as_string(last_day, unit='D')
last_day_str
'2020-12-01'
# (np.datetime64('2020-01-03') - np.datetime64('2020-01-01')) / np.timedelta64(1, 'D')
(np.datetime64('2020-01-03') - np.datetime64('2020-01-01')).item().days
2