{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2f-e9eNkgmhN"
},
"source": [
"# (Exercise) Physically-Informed Climate Modeling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EXGjjvaxc1m3"
},
"source": [
"By the end of this second exercise, you will:\n",
"\n",
"1. Understand how using physical knowledge to rescale a machine learning model's inputs can make it more robust and generalizable,\n",
"2. Know how to use *custom data generators* to nonlinearly rescale inputs before feeding them to a neural network, and\n",
"3. Practice parameterization on a realistic research case."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9SBwMz0bdybq"
},
"source": [
"While this notebook's completion time may vary widely depending on your programming experience, we estimate that it will take at least 30 minutes, and much longer if you decide to explore the source code. This notebook provides a minimal reproducible example of the work described in the following preprint:\n",
"\n",
"[Beucler, Tom, Michael Pritchard, Janni Yuval, Ankitesh Gupta, Liran Peng, Stephan Rasp, Fiaz Ahmed et al. \"Climate-invariant machine learning.\"](https://arxiv.org/abs/2112.08440),\n",
"\n",
"and contains a reduced version of our data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J-RFpMxDepw7"
},
"source": [
"We will be relying on Keras, whose documentation you can find [here](https://keras.io/), and TensorFlow, whose documentation you can find [here](https://www.tensorflow.org/). The notebooks assume that you will run them on Google Colab (Google Colab tutorial [at this link](https://colab.research.google.com/drive/16pBJQePbqkz3QFV54L4NIkOn1kwpuRrj)).\n",
"\n",
"While everything can be run locally and only a handful of lines use Google-specific libraries, we encourage beginners to use Google Colab to avoid running into [Python virtual environment](https://docs.python.org/3/tutorial/venv.html) issues.\n",
"\n",
"Before we get started, if you are struggling with some of the exercises, do not hesitate to:\n",
"\n",
"\n",
"* Use a direct Internet search, or [stackoverflow](https://stackoverflow.com/)\n",
"* Debug your program, e.g. by following [this tutorial](https://swcarpentry.github.io/python-novice-inflammation/11-debugging/index.html)\n",
"* Use assertions, e.g. by following [this tutorial](https://swcarpentry.github.io/python-novice-inflammation/10-defensive/index.html)\n",
"* Ask for help on the course Forum"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_YTv-czHHopm"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ogO2oAiH-e7"
},
"source": [
"Storms rapidly transport heat and water in the atmosphere, regulating the Earth's climate. Can you predict how storms affect atmospheric temperatures using deep learning, even in a changing climate?\n",
"\n",
"*Source: Photo by John Fowler licensed under the Unsplash License.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5NDXqaY1xtCm"
},
"source": [
"## Part I: Configuration and Requirements"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "nkyI57nV8ikK"
},
"outputs": [],
"source": [
"#@title Run this cell for preliminary requirements. Double click for the source code\n",
"!pip install --no-binary 'shapely==1.6.4' 'shapely==1.6.4' --force\n",
"!pip install xarray==2023.02.0 # Install latest version of xarray in Spring 2023\n",
"!pip install keras==2.12.0 tensorflow==2.12.0 # Install latest version of keras-tensorflow in Spring 2023\n",
"!pip install h5py==3.8.0 # Install latest version of h5py in Spring 2023\n",
"!pip install scipy==1.10.1 # Install latest version of scipy in Spring 2023\n",
"!pip install matplotlib==3.7.1 # Install latest version of matplotlib in Spring 2023\n",
"!pip install cartopy==0.21.1 # Install latest version of cartopy in Spring 2023"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "IDEzyh7_5iTK"
},
"outputs": [],
"source": [
"#@title Run this cell for Python library imports. Double click for the source code\n",
"import cartopy\n",
"import cartopy.feature as cfeature\n",
"import cartopy.crs as ccrs\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"import numpy as np\n",
"import scipy.integrate as sin\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from keras.layers import *\n",
"\n",
"import h5py\n",
"import pickle\n",
"import pooch\n",
"import xarray as xr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "5ywIaV63uG6v"
},
"outputs": [],
"source": [
"#@title Run this cell for figure aesthetics. Double click for the source code\n",
"fz = 15 # Here we define the fontsize\n",
"lw = 2 # the linewidth\n",
"siz = 75 # and the scattered dots' size\n",
"\n",
"plt.rc('text', usetex=False)\n",
"mpl.rcParams['mathtext.fontset'] = 'stix'\n",
"mpl.rcParams['font.family'] = 'STIXGeneral'\n",
"plt.rc('font', family='serif', size=fz)\n",
"mpl.rcParams['lines.linewidth'] = lw"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "GRy8C3EEQdo8"
},
"outputs": [],
"source": [
"#@title Run this cell to load the data using the pooch library. Double click for the source code\n",
"path_data = 'https://unils-my.sharepoint.com/:u:/g/personal/tom_beucler_unil_ch/'\n",
"\n",
"# Load simulation data\n",
"path_cold = path_data + 'EfHoI_pZY3xAi4bLEuDobaUBjyQmoJd1AvYnoPdH01VN-w?download=1'\n",
"path_warm = path_data + 'Eeq_n6Qv0jZBuRkICaOb0VQB6J1cN7muM6MrA3zA-v7LFg?download=1'\n",
"cold_open = pooch.retrieve(path_cold, known_hash='7b793afdd866a2e9b0db8fdb5029a88d557bf98525601275f5a335e95b26ac1a')\n",
"warm_open = pooch.retrieve(path_warm, known_hash='211db8ae89904f1fa3e2f17dc623bc6f5c6156cf24f4e3a42d92660ab1790fd4')\n",
"cold_data = xr.open_dataset(cold_open)\n",
"warm_data = xr.open_dataset(warm_open)\n",
"\n",
"# Load normalization files\n",
"path_hyam_hybm = path_data + 'Eb3WRMTGuUJNmsywPZOr0HEB-ocxfu7UkFFteWU0SyVqdQ?download=1'\n",
"path_norm_raw = path_data + 'EbqnraroMS9OsYazoCKPvxoBi42jeBegusakwfbMtqUO3w?download=1'\n",
"path_norm_RH = path_data + 'Eb5Gsx1dm4dCnIASDm6Bc8gBgko9nP3GZVKdRDgleibuTA?download=1'\n",
"path_norm_B = path_data + 'EVDMLGtWwCtLpkACuzU-YaUBs-RnsdtlvREJLNpkuG1E9w?download=1'\n",
"path_norm_LHFnsDELQ = path_data + 'Edt4Mm1hBT9FrYM0Ngd273oB6K8TvGxcBco35SL_J_ZFZQ?download=1'\n",
"hyam_hybm_open = pooch.retrieve(path_hyam_hybm, known_hash='343339f9b0fd4d92a8a31aabf774c0a17b6ac904feb6a2cd03e19ae4ff2bd329')\n",
"norm_raw_open = pooch.retrieve(path_norm_raw, known_hash='ee3c669928031af1a03ec3bc61373107575173decf66ede9b0c3b8568214ca0f')\n",
"RH_open = pooch.retrieve(path_norm_RH, known_hash='4d5275746eb1aad4a2279e16784befaa4beeab5a2aa6545e0e85437c8d73476f')\n",
"B_open = pooch.retrieve(path_norm_B, known_hash='396df61a24f6111acc1b908cdda3d10e0649d3eb551de860b3ebeb4419adc514')\n",
"LHFnsDELQ_open = pooch.retrieve(path_norm_LHFnsDELQ, known_hash='514413a6ab0f33039df5f815a925cf8916454d288f384615050131f1bdc8b06f')\n",
"hyam_hybm = pickle.load(open(hyam_hybm_open,'rb'))\n",
"hyam,hybm = hyam_hybm\n",
"norm_raw = xr.open_dataset(norm_raw_open)\n",
"norm_RH = xr.open_dataset(RH_open)\n",
"norm_B = xr.open_dataset(B_open)\n",
"norm_LHFnsDELQ = xr.open_dataset(LHFnsDELQ_open)\n",
"\n",
"# Load training files used to build the normalization data generators\n",
"path_train_RH = path_data + 'EWX9way46H9OvLLgLqPEr4QB4WkyTPDwGB7b-EjhTVIHww?download=1'\n",
"RH_train_open = pooch.retrieve(path_train_RH, known_hash='082cb63e5fbf315d8072a8d1613c8f0d810f949d32c9ad9374523b22de87a539')\n",
"path_train_BMSE = path_data + 'EU-cEsEjKT1Gn-s1aOGFMKgBK3C3yrAuxzX5_zaSIVOE-w?download=1'\n",
"BMSE_train_open = pooch.retrieve(path_train_BMSE,known_hash='cbc8e1736ffbbbc4b2c6a000cb32942abff73438157cb8e82e4195d76d0c5ccd')\n",
"path_train_LHFnsDELQ = path_data + 'ERojIn0ALWFMsPsknNqcOFMB5bL9nb1vPgPUlhO56sMe-Q?download=1'\n",
"LHFnsDELQ_train_open = pooch.retrieve(path_train_LHFnsDELQ,\n",
" known_hash='afef6ba713cafda3cbf6c1189f7f96602b4f00507feae394c65f20362cb48ba7')\n",
"\n",
"# Extract the longitude and latitude grids in case they are needed\n",
"longitude = cold_data.lon[:144].values\n",
"latitude = cold_data.lat[:96*144:144].values\n",
"# SPCAM's background pressure coordinates (in hPa)\n",
"pressure_levels = np.array([ 3.643466, 7.59482 , 14.356632, 24.61222 ,\n",
" 38.2683 , 54.59548 , 72.012451, 87.82123 ,\n",
" 103.317127, 121.547241, 142.994039, 168.22508 ,\n",
" 197.908087, 232.828619, 273.910817, 322.241902,\n",
" 379.100904, 445.992574, 524.687175, 609.778695,\n",
" 691.38943 , 763.404481, 820.858369, 859.534767,\n",
" 887.020249, 912.644547, 936.198398, 957.48548 ,\n",
" 976.325407, 992.556095])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9FV7dNneIWqC"
},
"source": [
"## Part II: Visualizing the Extrapolation Problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3pJXdT5ht3wx"
},
"source": [
"We've now loaded our cold simulation data as an [Xarray Dataset](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html) called `cold_data` and our 8K-warmer simulation data as another [Xarray Dataset](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html) called `warm_data`.\n",
"\n",
"Just in case you need it, we've also extracted all possible latitude values in `latitude` and all possible longitude values in `longitude`. The vertical pressure levels are given by `pressure_levels` (in hPa).\n",
"\n",
"We will soon visualize the data to give you more intuition about the prediction problem. We aim to predict the effect of ~5km-scale storms, clouds, and turbulence on the climate from the climate conditions (specific humidity `QBP` in kg/kg, temperature `TBP` in K, surface pressure `PS` in Pa, solar insolation `SOLIN` in $W/m^{2}$, surface sensible heat flux `SHFLX` in $W/m^{2}$, and surface latent heat flux `LHFLX` in $W/m^{2}$).\n",
"\n",
"As an example, we'll focus on predicting \"subgrid heating tendencies\" or `TPHYSTND` in $K/s$ at the model's vertical levels, which is the rate at which these storms, clouds, and turbulence warm up the atmosphere. However, you could reproduce the example below with `PHQ`, which contains the \"subgrid *moistening* tendencies\".\n",
"\n",
"Therefore, our prediction problem can be mathematically phrased as a regression problem, in which we are trying to predict: \n",
"\n",
"`y = [ TPHYSTND[:30] ]`\n",
"\n",
"from\n",
"\n",
"`x = [ QBP[:30] , TBP[:30] , PS , SOLIN , SHFLX , LHFLX ]`,\n",
"\n",
"where we use a vertical grid with 30 vertical levels, which means that the profiles of specific humidity `QBP` and temperature `TBP` both have 30 vertical levels, while the other inputs are scalars."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7TALN75OpW7k"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ub3K6JMdpxVi"
},
"source": [
"*The inputs or features (left) of a machine-learning parameterization are variables representing the large-scale climate properties, while the outputs or targets (right) are the rate at which storm-scale turbulence redistributes heat, moisture, and affects radiative fluxes.*"
]
},
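{
"cell_type": "markdown",
"metadata": {
"id": "sketch-split-inputs-md"
},
"source": [
"To make the regression setup above concrete, here is a minimal sketch that splits one input vector `x` into the six variables it contains (30 + 30 + 4 = 64 values). It assumes the columns follow exactly the order written above. This matches how the notebook slices `vars` later (`[:,:30]` for `QBP`, `[:,30:60]` for `TBP`, `[:,60]` for `PS`), but the positions of `SOLIN`, `SHFLX`, and `LHFLX` are an assumption. The helper `split_inputs` is hypothetical and the data are random, so only the shapes are meaningful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-split-inputs"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"N_LEV = 30  # number of vertical levels\n",
"N_INPUTS = 2 * N_LEV + 4  # QBP and TBP profiles + PS, SOLIN, SHFLX, LHFLX\n",
"\n",
"def split_inputs(x):\n",
"    # Hypothetical helper assuming the column order [QBP, TBP, PS, SOLIN, SHFLX, LHFLX]\n",
"    if x.shape[-1] != N_INPUTS:\n",
"        raise ValueError(f'expected {N_INPUTS} input columns, got {x.shape[-1]}')\n",
"    return {\n",
"        'QBP': x[..., :N_LEV],             # specific humidity profile (kg/kg)\n",
"        'TBP': x[..., N_LEV:2 * N_LEV],    # temperature profile (K)\n",
"        'PS': x[..., 2 * N_LEV],           # surface pressure (Pa)\n",
"        'SOLIN': x[..., 2 * N_LEV + 1],    # solar insolation (W/m2)\n",
"        'SHFLX': x[..., 2 * N_LEV + 2],    # surface sensible heat flux (W/m2)\n",
"        'LHFLX': x[..., 2 * N_LEV + 3],    # surface latent heat flux (W/m2)\n",
"    }\n",
"\n",
"# Random batch of 5 samples, used only to check the shapes\n",
"x_demo = np.random.default_rng(0).random((5, N_INPUTS))\n",
"parts = split_inputs(x_demo)\n",
"print({name: values.shape for name, values in parts.items()})"
]
},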
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "eQjSOcAOYyyM"
},
"outputs": [],
"source": [
"#@title Run this cell to calculate spatial statistics of the specific humidity input.\n",
"cold_q_m = {}; warm_q_m = {};\n",
"dictionary = ['mean','max','min']\n",
"\n",
"for idic,m in enumerate(dictionary):\n",
" cold_q_m[m] = np.zeros((len(latitude),len(longitude)))\n",
" warm_q_m[m] = np.zeros((len(latitude),len(longitude)))\n",
"\n",
"Nsample = (len(longitude)*len(latitude)) # Total number of samples\n",
"\n",
"# First, convert the arrays to numpy to accelerate calculations\n",
"# and extract near-surface specific humidity q (index 29 = lowest of the 30 levels)\n",
"cold_q = cold_data['vars'][:,29].values\n",
"warm_q = warm_data['vars'][:,29].values\n",
"\n",
"# We reshape both arrays to take advantage of numpy's design,\n",
"# which performs operations on entire arrays in a single (vectorized) operation;\n",
"# this is faster than (nested) loops\n",
"\n",
"# Count the number of world maps (timesteps) in each dataset\n",
"Nt_cold, Nt_warm = len(cold_q) // Nsample, len(warm_q) // Nsample\n",
"\n",
"# Deduce the target shape of both reshaped arrays\n",
"coldq_shape = (Nt_cold, len(latitude), len(longitude))\n",
"warmq_shape = (Nt_warm, len(latitude), len(longitude))\n",
"\n",
"# Reshape both arrays\n",
"# (Eliminate incomplete world maps of q at the end of the cold_q,warm_q arrays)\n",
"cold_q_reshaped = cold_q[:np.prod(coldq_shape)].reshape(coldq_shape)\n",
"warm_q_reshaped = warm_q[:np.prod(warmq_shape)].reshape(warmq_shape)\n",
"\n",
"# Calculate means, mins, maxes along the first axis\n",
"cold_q_m['mean'] = np.mean(cold_q_reshaped, axis=0)\n",
"warm_q_m['mean'] = np.mean(warm_q_reshaped, axis=0)\n",
"\n",
"cold_q_m['min'] = np.min(cold_q_reshaped, axis=0)\n",
"warm_q_m['min'] = np.min(warm_q_reshaped, axis=0)\n",
"\n",
"cold_q_m['max'] = np.max(cold_q_reshaped, axis=0)\n",
"warm_q_m['max'] = np.max(warm_q_reshaped, axis=0)"
]
},
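{
"cell_type": "markdown",
"metadata": {
"id": "sketch-reshape-order-md"
},
"source": [
"The reshape above relies on the flattened samples being stored one world map after another, with longitude varying fastest and then latitude. This is consistent with how `longitude` (`lon[:144]`) and `latitude` (`lat[:96*144:144]`) were extracted in Part I. The toy example below checks that ordering, and the truncation of an incomplete final map, on a hypothetical 2 × 3 × 4 (time × latitude × longitude) grid. It is an illustration, not part of the exercise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-reshape-order"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"n_t, n_lat, n_lon = 2, 3, 4  # hypothetical grid size\n",
"n_complete = n_t * n_lat * n_lon\n",
"# Flat sample index: 5 extra samples mimic an incomplete final world map\n",
"flat = np.arange(n_complete + 5)\n",
"# Drop the incomplete map, then reshape to (time, latitude, longitude)\n",
"maps = flat[:n_complete].reshape((n_t, n_lat, n_lon))\n",
"\n",
"# With longitude varying fastest, sample 17 sits at time 1, latitude 1, longitude 1\n",
"t, j, i = np.unravel_index(17, maps.shape)\n",
"print(t, j, i, maps[t, j, i])\n",
"\n",
"# Time mean at each grid point, as computed for cold_q_m['mean']\n",
"time_mean = np.mean(maps, axis=0)\n",
"print(time_mean)"
]
},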
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "43cgC-r2l6My"
},
"outputs": [],
"source": [
"#@title Run this cell to define a function plotting maps of these spatial statistics.\n",
"def Input_map(cold_variable, warm_variable, cold_name, warm_name, var_name,\n",
" vmin, vmax):\n",
" '''\n",
" Plots maps of the cold_variable (with the title \"cold_name\")\n",
" next to the warm_variable (with the title \"warm_name\")\n",
" and sets the colorbar's label to \"var_name\".\n",
" The colorbar ranges from vmin to vmax.\n",
" '''\n",
" fig, ax = plt.subplots(1,2,\n",
" subplot_kw={'projection':ccrs.Robinson(central_longitude=180)},\n",
" figsize=(12.5,6))\n",
"\n",
" cold = ax[0].pcolormesh(longitude, latitude, cold_variable,\n",
" transform=ccrs.PlateCarree(),\n",
" vmin=vmin,vmax=vmax,shading='gouraud')\n",
" ax[0].set_title(cold_name)\n",
" warm = ax[1].pcolormesh(longitude, latitude, warm_variable,\n",
" transform=ccrs.PlateCarree(),\n",
" vmin=vmin,vmax=vmax,shading='gouraud')\n",
" ax[1].set_title(warm_name)\n",
"\n",
" for iplot in range(2):\n",
" ax[iplot].coastlines(linewidth=2.0,edgecolor='0.25')\n",
" ax[iplot].add_feature(cfeature.BORDERS,linewidth=0.5,edgecolor='0.25')\n",
"\n",
" cbar_ax = fig.add_axes([0.94,0.2,0.01,0.6])\n",
" cbar = fig.colorbar(warm, label=var_name,cax=cbar_ax)\n",
" cbar_ax.yaxis.set_ticks_position('right')\n",
" cbar_ax.yaxis.set_label_position('right')\n",
"\n",
" return fig,ax"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qDUvF-YuCwIe"
},
"outputs": [],
"source": [
"Input_map(1e3*cold_q_m['mean'], 1e3*warm_q_m['mean'],\n",
" '(Cold climate) Mean Input', '(Warm climate) Mean Input',\n",
" 'Near-surface specific humidity (g/kg)', 0, 30);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3bwG5mKAln5B"
},
"outputs": [],
"source": [
"Input_map(1e3*cold_q_m['min'], 1e3*warm_q_m['min'],\n",
" '(Cold climate) Minimum Input', '(Warm climate) Minimum Input',\n",
" 'Near-surface specific humidity (g/kg)', 0, 30);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TCADmsd8o5Dy"
},
"outputs": [],
"source": [
"Input_map(1e3*cold_q_m['max'], 1e3*warm_q_m['max'],\n",
" '(Cold climate) Maximum Input', '(Warm climate) Maximum Input',\n",
" 'Near-surface specific humidity (g/kg)', 0, 30);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ySS8SHJmpTcB"
},
"source": [
"As you can see, the warm climate contains values of the inputs (here, near-surface specific humidity) that were never seen during training. This means that even if we learn an excellent machine-learning parameterization in the cold climate, using this same machine-learning parameterization in a warm climate will be a challenging *extrapolation* problem.\n",
"\n",
"Two goals of this notebook will be to:\n",
"\n",
"1. Expose the failure of a deep learning parameterization when generalizing from a cold climate to a warm climate.\n",
"2. Physically rescale the inputs to minimize distribution changes across climates. This will improve the generalization ability of our deep learning algorithms.\n"
]
},
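{
"cell_type": "markdown",
"metadata": {
"id": "sketch-out-of-range-md"
},
"source": [
"One simple way to quantify the extrapolation problem described above is to count the fraction of warm-climate samples that fall outside the range seen in the cold (training) climate. The sketch below uses synthetic uniform samples as stand-ins, and the function name `fraction_outside_range` is hypothetical. With the arrays from Part II, one could, for instance, call it on `cold_q` and `warm_q`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-out-of-range"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def fraction_outside_range(train_values, test_values):\n",
"    # Fraction of test samples outside the [min, max] range of the training samples\n",
"    lo, hi = np.min(train_values), np.max(train_values)\n",
"    test_values = np.asarray(test_values)\n",
"    return np.mean((test_values < lo) | (test_values > hi))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# Synthetic stand-ins: the 'warm' sample is shifted toward larger values\n",
"cold_demo = rng.uniform(0.0, 20.0, size=10_000)\n",
"warm_demo = rng.uniform(5.0, 30.0, size=10_000)\n",
"frac_demo = fraction_outside_range(cold_demo, warm_demo)\n",
"print(f'{100 * frac_demo:.1f}% of warm samples lie outside the cold range')"
]
},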
{
"cell_type": "markdown",
"metadata": {
"id": "duR2LKsuY7mU"
},
"source": [
"☁ Let's start right away with the first physical rescaling: specific humidity into relative humidity ☁"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F71zOmBQaD8H"
},
"source": [
"## **Q1) To avoid the extrapolation problems visualized above, transform specific humidity into relative humidity**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JdiMk7DbYbT7"
},
"source": [
"You may use the equations below, which are adapted from the [System for Atmospheric Modeling (SAM)](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2021MS002968) for consistency with our storm-resolving data. Assuming you have already calculated the saturation pressure of water vapor over liquid water $e_{liq}$ and the saturation pressure of water vapor over ice $e_{ice}$, the model calculates relative humidity in two steps.\n",
"\n",
"1. It assumes that the system's combined saturation pressure is:\n",
"\n",
"* Equal to $e_{liq}$ for above-freezing temperatures ($T>273.16K$)\n",
"* Equal to $e_{ice}$ for cold temperatures ($T<253.16K$)\n",
"* A linear combination of the two in the intermediate range:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lhJ1TnetcMWy"
},
"source": [
"$e_{sat}=\\omega \\times e_{liq} + (1-\\omega) \\times e_{ice}$\n",
"\n",
"where the weight $\\omega$ is defined as:\n",
"\n",
"$\\omega = \\frac{T-253.16K}{273.16K-253.16K}$."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xXzstiqAcNUn"
},
"source": [
"2. It then combines the ideal gas law and Dalton's law with the definition of relative humidity to calculate relative humidity as:\n",
"\n",
"$RH = \\frac{e}{e_{sat}} \\approx \\frac{R_v}{R_d} \\times \\frac{p}{e_{sat}} \\times q_v$ (an approximation valid because $q_v \\ll 1$)\n",
"\n",
"where:\n",
"\n",
"* $R_v \\approx 461 J kg^{-1} K^{-1}$ is the specific ideal gas constant for water vapor\n",
"* $R_d \\approx 287 J kg^{-1} K^{-1}$ is the specific ideal gas constant for a standard dry air mixture\n",
"* $p$ is air pressure\n",
"* $q_v$ is specific humidity, or equivalently the water vapor mass concentration (in kg/kg)"
]
},
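{
"cell_type": "markdown",
"metadata": {
"id": "sketch-esat-rh-md"
},
"source": [
"The two steps above can be sketched on a few scalar values before applying them to the full dataset. This is a minimal illustration, not the notebook's solution. `blend_esat` and `rh_from_q` are hypothetical helper names. The saturation pressures are made-up constants chosen only to expose the weight $\\omega$, and the specific humidity, pressure, and saturation pressure values are illustrative. The sketch also compares the approximate formula with the exact relation $e = p q_v / (\\epsilon + (1-\\epsilon) q_v)$, where $\\epsilon = R_d/R_v$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-esat-rh"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"R_D = 287.0     # specific gas constant of dry air (J/kg/K)\n",
"R_V = 461.0     # specific gas constant of water vapor (J/kg/K)\n",
"T_ICE = 253.16  # below this temperature, e_sat = e_ice (K)\n",
"T_LIQ = 273.16  # above this temperature, e_sat = e_liq (K)\n",
"\n",
"def blend_esat(T, e_liq, e_ice):\n",
"    # Step 1: weight omega clipped to [0, 1], then linear blend of the two pressures\n",
"    omega = np.clip((T - T_ICE) / (T_LIQ - T_ICE), 0.0, 1.0)\n",
"    return omega * e_liq + (1.0 - omega) * e_ice\n",
"\n",
"def rh_from_q(q_v, p, e_sat):\n",
"    # Step 2: RH ~ (R_v / R_d) * p * q_v / e_sat, with p and e_sat in Pa\n",
"    return (R_V / R_D) * p * q_v / e_sat\n",
"\n",
"# Made-up constant saturation pressures (Pa), only to expose the weights\n",
"T = np.array([240.0, 253.16, 263.16, 273.16, 290.0])\n",
"print(blend_esat(T, e_liq=200.0, e_ice=100.0))\n",
"\n",
"# Illustrative values: q_v = 10 g/kg, p = 1000 hPa, e_sat = 3000 Pa\n",
"rh_approx = rh_from_q(q_v=0.01, p=1e5, e_sat=3000.0)\n",
"eps = R_D / R_V\n",
"rh_exact = 1e5 * 0.01 / (eps + (1.0 - eps) * 0.01) / 3000.0\n",
"print(rh_approx, rh_exact)"
]
},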
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l4WOLvqxZuoP"
},
"outputs": [],
"source": [
"# Assume you have access to specific humidity, temperature, and air pressure\n",
"# Remember Python indices are left-inclusive and right-exclusive\n",
"# (You can just run this cell)\n",
"\n",
"specific_humidity = cold_data['vars'][:,:30] # in kg/kg\n",
"temperature = cold_data['vars'][:,30:60] # in K\n",
"\n",
"P0 = 1e5 # Mean surface air pressure (Pa)\n",
"near_surface_air_pressure = cold_data['vars'][:,60]\n",
"# Formula to calculate air pressure (in Pa) using the hybrid vertical grid\n",
"# coefficients at the middle of each vertical level: hyam and hybm\n",
"air_pressure_Pa = np.outer(near_surface_air_pressure**0,P0*hyam) + \\\n",
"np.outer(near_surface_air_pressure,hybm)"
]
},
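{
"cell_type": "markdown",
"metadata": {
"id": "sketch-hybrid-pressure-md"
},
"source": [
"The pressure formula in the cell above uses `np.outer` with a column of ones (`near_surface_air_pressure**0`) to repeat the constant term $P_0 \\times$ `hyam` for every sample. The sketch below shows an equivalent NumPy broadcasting formulation on hypothetical three-level coefficients (the `*_demo` arrays are made up, not the SPCAM values). It then checks that both formulations give the same $(n_{samples}, n_{levels})$ array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-hybrid-pressure"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"P0 = 1e5  # reference pressure (Pa)\n",
"# Hypothetical 3-level hybrid coefficients and 2 surface pressures (Pa)\n",
"hyam_demo = np.array([0.01, 0.05, 0.0])\n",
"hybm_demo = np.array([0.0, 0.5, 0.99])\n",
"ps_demo = np.array([1.0e5, 9.0e4])\n",
"\n",
"# Broadcasting: (n_samples, 1) * (1, n_levels) -> (n_samples, n_levels)\n",
"p_broadcast = P0 * hyam_demo[None, :] + ps_demo[:, None] * hybm_demo[None, :]\n",
"# np.outer formulation, as in the cell above\n",
"p_outer = np.outer(np.ones_like(ps_demo), P0 * hyam_demo) + np.outer(ps_demo, hybm_demo)\n",
"print(p_broadcast)"
]
},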
{
"cell_type": "markdown",
"metadata": {
"id": "s7xlJWengt_Q"
},
"source": [
"Our goal is to calculate relative humidity using the above equations. We'll assume we already have functions (below) giving us $e_{liq}$ and $e_{ice}$ as a function of air temperature $T$. These functions can be called using:\n",
"\n",
"* `eliq(temperature)`\n",
"* `eice(temperature)`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "1Hsy9p4Ghe-G"
},
"outputs": [],
"source": [
"#@title Source code for eliq(T) and eice(T): Do not forget to execute this cell\n",
"def eliq(T):\n",
" \"\"\"\n",
" Function taking temperature (in K) and outputting liquid saturation\n",
" pressure (in Pa) using a polynomial fit\n",
" \"\"\"\n",
" a_liq = np.array([-0.976195544e-15,-0.952447341e-13,0.640689451e-10,\n",
" 0.206739458e-7,0.302950461e-5,0.264847430e-3,\n",
" 0.142986287e-1,0.443987641,6.11239921]);\n",
" c_liq = -80\n",
" T0 = 273.16\n",
" return 100*np.polyval(a_liq,np.maximum(c_liq,T-T0))\n",
"\n",
"def eice(T):\n",
" \"\"\"\n",
" Function taking temperature (in K) and outputting ice saturation\n",
" pressure (in Pa) using a polynomial fit\n",
" \"\"\"\n",
" a_ice = np.array([0.252751365e-14,0.146898966e-11,0.385852041e-9,\n",
" 0.602588177e-7,0.615021634e-5,0.420895665e-3,\n",
" 0.188439774e-1,0.503160820,6.11147274]);\n",
" c_ice = np.array([273.15,185,-100,0.00763685,0.000151069,7.48215e-07])\n",
" T0 = 273.16\n",
" return (T>c_ice[0])*eliq(T)+\\\n",
" (T<=c_ice[0])*(T>c_ice[1])*100*np.polyval(a_ice,T-T0)+\\\n",
" (T<=c_ice[1])*100*(c_ice[3]+np.maximum(c_ice[2],T-T0)*\\\n",
" (c_ice[4]+np.maximum(c_ice[2],T-T0)*c_ice[5]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "advSakAZvKgh"
},
"source": [
"Hints:\n",
"\n",
"\n",
"\n",
"* You can often accelerate your calculations by converting [xarray DataArray](https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html) into [numpy nd-arrays](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) using their [`values` method](https://docs.xarray.dev/en/stable/generated/xarray.DataArray.values.html). You will load these values into memory when you do this, so it will increase RAM usage.\n",
"* To implement $e_{sat}$, you may, e.g., use [numpy's `where` function](https://numpy.org/doc/stable/reference/generated/numpy.where.html) or booleans.\n",
"* Given that super-saturation is rare at climate timescales, you can optionally bound your relative humidity calculation by 1.\n"
]
},
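{
"cell_type": "markdown",
"metadata": {
"id": "sketch-bound-rh-md"
},
"source": [
"Following the last hint, here is a minimal sketch of bounding relative humidity by 1. It also reports the fraction of super-saturated values, so you can check that the bound affects only a small share of the data. The function `bound_rh` and the sample values are hypothetical. Once you have computed `relative_humidity`, you could call `bound_rh(relative_humidity)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-bound-rh"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def bound_rh(rh, upper=1.0):\n",
"    # Cap RH at `upper` and report the fraction of values that exceeded it\n",
"    rh = np.asarray(rh, dtype=float)\n",
"    frac_supersaturated = np.mean(rh > upper)\n",
"    return np.minimum(rh, upper), frac_supersaturated\n",
"\n",
"rh_demo = np.array([0.2, 0.8, 1.0, 1.05, 0.95])\n",
"rh_bounded, frac_super = bound_rh(rh_demo)\n",
"print(rh_bounded, frac_super)"
]
},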
{
"cell_type": "markdown",
"metadata": {
"id": "vuxTBHZ0F7Oh"
},
"source": [
"💡 **For all questions, you can write your own code or complete the proposed code by replacing the underscores with the appropriate script**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QfCs7iblF_0t"
},
"outputs": [],
"source": [
"# Here's an empty code cell to look at the data, etc.\n",
"# You can add or remove code and text cells via the \"Insert\" menu"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y4NG6SKmjHxH"
},
"outputs": [],
"source": [
"# Q1.1) Calculate the combined saturation water vapor pressure here\n",
"\n",
"omega = (temperature - ___) / (___ - ___)\n",
"\n",
"# Make sure your weight omega is always between 0 and 1\n",
"omega = ___\n",
"\n",
"esat = ___"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FRovNrI_kbO_"
},
"outputs": [],
"source": [
"# Q1.2) Calculate relative humidity here\n",
"\n",
"relative_humidity = ___"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U7w4tDsgufzi"
},
"outputs": [],
"source": [
"#@title A possible solution for Q1\n",
"\n",
"# 1) Calculating saturation water vapor pressure\n",
"T0 = 273.16 # Freezing temperature in standard conditions\n",
"T00 = 253.16 # Temperature below which we use e_ice\n",
"omega = (temperature - T00) / (T0 - T00)\n",
"omega = np.maximum( 0, np.minimum( 1, omega ))\n",
"\n",
"esat = omega * eliq(temperature) + (1-omega) * eice(temperature)\n",
"\n",
"# 2) Calculating relative humidity\n",
"Rd = 287 # Specific gas constant for dry air\n",
"Rv = 461 # Specific gas constant for water vapor\n",
"\n",
"# We use the `values` method to convert Xarray DataArray into Numpy ND-Arrays\n",
"relative_humidity = Rv/Rd * air_pressure_Pa/esat.values * specific_humidity.values"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1sHUy1BQwY4X"
},
"source": [
"## **Q2) Ensure that you mostly solved the extrapolation problem visualized in Part II**\n",
"\n",
"For this purpose, let's repeat the visualization of near-surface humidity, but this time using **near-surface** *relative* instead of *specific* humidity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "2lRDZ4uVxLRz"
},
"outputs": [],
"source": [
"#@title First, let's automate the relative humidity calculation...\n",
"\n",
"def RH_from_climate(data):\n",
"\n",
" # 0) Extract specific humidity, temperature, and air pressure\n",
" specific_humidity = data['vars'][:,:30] # in kg/kg\n",
" temperature = data['vars'][:,30:60] # in K\n",
"\n",
" P0 = 1e5 # Mean surface air pressure (Pa)\n",
" near_surface_air_pressure = data['vars'][:,60]\n",
" # Formula to calculate air pressure (in Pa) using the hybrid vertical grid\n",
" # coefficients at the middle of each vertical level: hyam and hybm\n",
" air_pressure_Pa = np.outer(near_surface_air_pressure**0,P0*hyam) + \\\n",
" np.outer(near_surface_air_pressure,hybm)\n",
"\n",
" # 1) Calculating saturation water vapor pressure\n",
" T0 = 273.16 # Freezing temperature in standard conditions\n",
" T00 = 253.16 # Temperature below which we use e_ice\n",
" omega = (temperature - T00) / (T0 - T00)\n",
" omega = np.maximum( 0, np.minimum( 1, omega ))\n",
"\n",
" esat = omega * eliq(temperature) + (1-omega) * eice(temperature)\n",
"\n",
" # 2) Calculating relative humidity\n",
" Rd = 287 # Specific gas constant for dry air\n",
" Rv = 461 # Specific gas constant for water vapor\n",
"\n",
" # We use the `values` method to convert Xarray DataArray into Numpy ND-Arrays\n",
" return Rv/Rd * air_pressure_Pa/esat.values * specific_humidity.values"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qiTGZDQ3x9im"
},
"source": [
"... so that we can calculate relative humidity for both cold and warm climates. We can then extract **near-surface** (993hPa) humidity, which is the input variable that was clearly out-of-distribution when we used *specific* humidity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BTNLofOlyD3i"
},
"outputs": [],
"source": [
"RH_cold = RH_from_climate(cold_data) # Relative humidity for the cold simulation output\n",
"RH_warm = RH_from_climate(warm_data) # Relative humidity for the warm simulation output\n",
"\n",
"near_surf_RH_cold = RH_cold[:,-1] # Near-surface relative humidity for the cold case\n",
"near_surf_RH_warm = RH_warm[:,-1] # Near-surface relative humidity for the warm case"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uPLbh86Ny5w9"
},
"outputs": [],
"source": [
"# First, calculate the spatial statistics of this new near-surface\n",
"# relative humidity field. You may complete the code below or write your own.\n",
"cold_RH_m = {}; warm_RH_m = {};\n",
"dictionary = ['mean','max','min']\n",
"\n",
"for idic,m in enumerate(dictionary):\n",
" cold_RH_m[m] = np.zeros((len(latitude),len(longitude)))\n",
" warm_RH_m[m] = np.zeros((len(latitude),len(longitude)))\n",
"\n",
"cold_RH_reshaped = near_surf_RH_cold[:np.prod(coldq_shape)].reshape(coldq_shape)\n",
"warm_RH_reshaped = near_surf_RH_warm[:np.prod(warmq_shape)].reshape(warmq_shape)\n",
"\n",
"# Complete the code below calculating the spatial statistics\n",
"cold_RH_m['mean'] = np.___(___, axis=___)\n",
"warm_RH_m['mean'] = np.___(___, axis=___)\n",
"\n",
"cold_RH_m['min'] = ___\n",
"warm_RH_m['min'] = ___\n",
"\n",
"cold_RH_m['max'] = ___\n",
"warm_RH_m['max'] = ___"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yuxnuvdkykVZ"
},
"outputs": [],
"source": [
"# Visualize the mean near-surface relative humidity at each location\n",
"# You may use the Input_map function provided above\n",
"Input_map(___,___,___,___,___,___,___);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "su3sFD7cz9Fu"
},
"outputs": [],
"source": [
"# Visualize the maximum at each location"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FSH1YQfj0CfW"
},
"outputs": [],
"source": [
"# Visualize the minimum at each location"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "y2UEAs4F0iOH"
},
"outputs": [],
"source": [
"#@title A possible solution for Q2\n",
"\n",
"# Calculate relative humidity statistics in cold/warm simulation output\n",
"cold_RH_m = {}; warm_RH_m = {};\n",
"dictionary = ['mean','max','min']\n",
"\n",
"for idic,m in enumerate(dictionary):\n",
" cold_RH_m[m] = np.zeros((len(latitude),len(longitude)))\n",
" warm_RH_m[m] = np.zeros((len(latitude),len(longitude)))\n",
"\n",
"# Calculate the spatial statistics of the newly derived RH field\n",
"cold_RH_reshaped = near_surf_RH_cold[:np.prod(coldq_shape)].reshape(coldq_shape)\n",
"warm_RH_reshaped = near_surf_RH_warm[:np.prod(warmq_shape)].reshape(warmq_shape)\n",
"\n",
"cold_RH_m['mean'] = np.mean(cold_RH_reshaped, axis=0)\n",
"warm_RH_m['mean'] = np.mean(warm_RH_reshaped, axis=0)\n",
"\n",
"cold_RH_m['min'] = np.min(cold_RH_reshaped, axis=0)\n",
"warm_RH_m['min'] = np.min(warm_RH_reshaped, axis=0)\n",
"\n",
"cold_RH_m['max'] = np.max(cold_RH_reshaped, axis=0)\n",
"warm_RH_m['max'] = np.max(warm_RH_reshaped, axis=0)\n",
"\n",
"# Visualize the new relative humidity input\n",
"Input_map(1e2*cold_RH_m['mean'], 1e2*warm_RH_m['mean'],\n",
" '(Cold climate) Mean Input', '(Warm climate) Mean Input',\n",
" 'Relative Humidity (%)', 0, 100);\n",
"Input_map(1e2*cold_RH_m['max'], 1e2*warm_RH_m['max'],\n",
" '(Cold climate) Max Input', '(Warm climate) Max Input',\n",
" 'Relative Humidity (%)', 0, 100);\n",
"Input_map(1e2*cold_RH_m['min'], 1e2*warm_RH_m['min'],\n",
" '(Cold climate) Min Input', '(Warm climate) Min Input',\n",
" 'Relative Humidity (%)', 0, 100);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twmMmP7e1K-w"
},
"source": [
"😃 It looks like our input distribution is now quite similar in the cold and warm climates! Since the support (range) of the relative humidity is $\\approx [0,1]$, we now expect to have mostly converted a difficult *extrapolation* case into an easier *interpolation* case. The rest of this notebook will explore the consequences of this **physical rescaling** for the performance and generalization ability of neural networks 🧠"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QRjxwxzj3B0Y"
},
"source": [
"It would take much more than the allotted 30 minutes to repeat the same physical rescaling in a data generator (also called data loader or data pipeline), so we directly give you the source code below. Run it if you would like to proceed, and read it if you want to dive into fun implementation details 🤓"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LoGnytgkpBEn"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YtTzeXq5pDjs"
},
"source": [
"*Our custom data generator rescales the inputs $x$ to $\\widetilde{x}$ before feeding them to the machine learning model for training.*"
]
},
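{
"cell_type": "markdown",
"metadata": {
"id": "sketch-rescaling-generator-md"
},
"source": [
"Before reading the full source code, here is a deliberately simplified stand-in for a custom data generator. It is a plain Python generator that shuffles sample indices, slices a batch, and applies a nonlinear rescaling function to the inputs only. It is not the notebook's implementation. The two-column input layout (specific humidity $q_v$ and saturation specific humidity $q_{sat}$, both in kg/kg) and the helper names `rescaling_batches` and `to_rh_like` are hypothetical. The rescaling $q_v / q_{sat}$ mirrors how the library below relates `RH`, `qv`, and `qsat`. A plain generator is exhausted after one pass. Keras generators are often written as `keras.utils.Sequence` subclasses instead, so they can be iterated over again at every epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sketch-rescaling-generator"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def to_rh_like(x_batch):\n",
"    # Hypothetical layout: column 0 = q_v, column 1 = q_sat (both kg/kg)\n",
"    return x_batch[:, :1] / x_batch[:, 1:2]\n",
"\n",
"def rescaling_batches(x, y, rescale, batch_size=32, shuffle=True, seed=0):\n",
"    # Yield (rescaled inputs, targets) batches; targets are left unchanged\n",
"    idx = np.arange(len(x))\n",
"    if shuffle:\n",
"        np.random.default_rng(seed).shuffle(idx)\n",
"    for start in range(0, len(idx), batch_size):\n",
"        batch = idx[start:start + batch_size]\n",
"        yield rescale(x[batch]), y[batch]\n",
"\n",
"# Synthetic toy data: q_v is a random fraction of q_sat\n",
"rng = np.random.default_rng(1)\n",
"q_sat = rng.uniform(0.005, 0.02, size=100)\n",
"q_v = q_sat * rng.uniform(0.1, 1.0, size=100)\n",
"x_toy = np.stack([q_v, q_sat], axis=1)\n",
"y_toy = rng.normal(size=(100, 30))  # placeholder targets, e.g. 30 heating levels\n",
"\n",
"batches = list(rescaling_batches(x_toy, y_toy, to_rh_like))\n",
"print(len(batches), batches[0][0].shape, batches[-1][0].shape)"
]
},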
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "0S6W988UaG6p"
},
"outputs": [],
"source": [
"#@title Source code for the moist thermodynamics library: Run to proceed and double click to read\n",
"\n",
"# Constants for the Community Atmosphere Model\n",
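"DT = 1800. # Model time step (s)\n",
"L_V = 2.501e6 # Latent heat of vaporization (J/kg)\n",
"L_I = 3.337e5 # Latent heat of fusion (J/kg)\n",
"L_F = L_I\n",
"L_S = L_V + L_I # Latent heat of sublimation (J/kg)\n",
"C_P = 1.00464e3 # Specific heat capacity of air at constant pressure (J/kg/K)\n",
"G = 9.80616 # Gravitational acceleration (m/s^2)\n",
"RHO_L = 1e3 # Density of liquid water (kg/m^3)\n",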
"\n",
"# Moist thermodynamics library in numpy\n",
"class CrhClass:\n",
" def __init__(self):\n",
" pass\n",
"\n",
" def eliq(self,T):\n",
" a_liq = np.array([-0.976195544e-15,-0.952447341e-13,0.640689451e-10,0.206739458e-7,0.302950461e-5,0.264847430e-3,0.142986287e-1,0.443987641,6.11239921]);\n",
" c_liq = -80\n",
" T0 = 273.16\n",
" return 100*np.polyval(a_liq,np.maximum(c_liq,T-T0))\n",
"\n",
" def eice(self,T):\n",
" a_ice = np.array([0.252751365e-14,0.146898966e-11,0.385852041e-9,0.602588177e-7,0.615021634e-5,0.420895665e-3,0.188439774e-1,0.503160820,6.11147274]);\n",
" c_ice = np.array([273.15,185,-100,0.00763685,0.000151069,7.48215e-07])\n",
" T0 = 273.16\n",
" return (T>c_ice[0])*self.eliq(T)+\\\n",
" (T<=c_ice[0])*(T>c_ice[1])*100*np.polyval(a_ice,T-T0)+\\\n",
" (T<=c_ice[1])*100*(c_ice[3]+np.maximum(c_ice[2],T-T0)*(c_ice[4]+np.maximum(c_ice[2],T-T0)*c_ice[5]))\n",
"\n",
" def esat(self,T):\n",
" T0 = 273.16\n",
" T00 = 253.16\n",
" omega = np.maximum(0,np.minimum(1,(T-T00)/(T0-T00)))\n",
"\n",
" return (T>T0)*self.eliq(T)+(T<T00)*self.eice(T)+(T<=T0)*(T>=T00)*(omega*self.eliq(T)+(1-omega)*self.eice(T))\n",
"\n",
" def RH(self,T,qv,P0,PS,hyam,hybm):\n",
" R = 287\n",
" Rv = 461\n",
" S = PS.shape\n",
" p = 1e5 * np.tile(hyam,(S[0],1))+np.transpose(np.tile(PS,(30,1)))*np.tile(hybm,(S[0],1))\n",
"\n",
" return Rv*p*qv/(R*self.esat(T))\n",
"\n",
" def qv(self,T,RH,P0,PS,hyam,hybm):\n",
" R = 287\n",
" Rv = 461\n",
" S = PS.shape\n",
" p = 1e5 * np.tile(hyam,(S[0],1))+np.transpose(np.tile(PS,(30,1)))*np.tile(hybm,(S[0],1))\n",
"\n",
" return R*self.esat(T)*RH/(Rv*p)\n",
"\n",
"\n",
" def qsat(self,T,P0,PS,hyam,hybm):\n",
" return self.qv(T,1,P0,PS,hyam,hybm)\n",
"\n",
"\n",
"\n",
" def dP(self,PS):\n",
" S = PS.shape\n",
" P = 1e5 * np.tile(hyai,(S[0],1))+np.transpose(np.tile(PS,(31,1)))*np.tile(hybi,(S[0],1))\n",
" return P[:, 1:]-P[:, :-1]\n",
"\n",
"\n",
"class ThermLibNumpy:\n",
" @staticmethod\n",
" def eliqNumpy(T):\n",
" a_liq = np.float32(np.array([-0.976195544e-15,-0.952447341e-13,\\\n",
" 0.640689451e-10,\\\n",
" 0.206739458e-7,0.302950461e-5,0.264847430e-3,\\\n",
" 0.142986287e-1,0.443987641,6.11239921]));\n",
" c_liq = np.float32(-80.0)\n",
" T0 = np.float32(273.16)\n",
" return np.float32(100.0)*np.polyval(a_liq,np.maximum(c_liq,T-T0))\n",
"\n",
"\n",
" @staticmethod\n",
" def eiceNumpy(T):\n",
" a_ice = np.float32(np.array([0.252751365e-14,0.146898966e-11,0.385852041e-9,\\\n",
" 0.602588177e-7,0.615021634e-5,0.420895665e-3,\\\n",
" 0.188439774e-1,0.503160820,6.11147274]));\n",
" c_ice = np.float32(np.array([273.15,185,-100,0.00763685,0.000151069,7.48215e-07]))\n",
" T0 = np.float32(273.16)\n",
" return np.where(T>c_ice[0],ThermLibNumpy.eliqNumpy(T),\\\n",
" np.where(T<=c_ice[1],np.float32(100.0)*(c_ice[3]+np.maximum(c_ice[2],T-T0)*\\\n",
" (c_ice[4]+np.maximum(c_ice[2],T-T0)*c_ice[5])),\\\n",
" np.float32(100.0)*np.polyval(a_ice,T-T0)))\n",
"\n",
" @staticmethod\n",
" def esatNumpy(T):\n",
" T0 = np.float32(273.16)\n",
" T00 = np.float32(253.16)\n",
" omtmp = (T-T00)/(T0-T00)\n",
" omega = np.maximum(np.float32(0.0),np.minimum(np.float32(1.0),omtmp))\n",
"\n",
" return np.where(T>T0,ThermLibNumpy.eliqNumpy(T),np.where(T