# Source code for ufs_da_diagnostics.spectra.spectra_core

"""
Spectral Diagnostics Core Engine
================================

This module implements the core spectral diagnostics used by both
``ufsda-spectra-ana-inc`` (CTRL vs EXP increment comparison) and
``ufsda-spectra-bkg-inc`` (background vs increment comparison).

The ``SpectraCore`` class provides:

- FV3 cubed-sphere tile loading
- Grid loading and corner → center conversion
- Cubed-sphere → global lat/lon regridding
- 1D isotropic spectra
- Vertical variance profiles
- 2D spectral ratios

This module contains **no plotting** and **no CLI logic** — it is a
pure computational engine used by the two driver scripts.
"""

import xarray as xr
import numpy as np
from scipy.interpolate import griddata


class SpectraCore:
    """
    Core FV3-JEDI spectral diagnostics engine.

    This class provides the computational backend for spectral analysis
    workflows, including:

    - Loading FV3 cubed-sphere tiles
    - Loading FV3 grid geometry and converting it to cell-center points
    - Regridding cubed-sphere tiles to a global lat/lon mesh
    - Computing isotropic 1D wavenumber spectra
    - Computing vertical variance profiles
    - Computing 2D spectral ratios (EXP / CTRL)

    Parameters
    ----------
    varname : str, optional
        Name of the variable to extract from FV3 increment files.
        Default is ``"T_inc"``.
    grid_prefix : str, optional
        Prefix for FV3 grid files, e.g. ``"./C96_grid.tile"``.
    suffix : str, optional
        File suffix for both grid and data files (default ``".nc"``).

    Attributes
    ----------
    nlevels : int or None
        Number of levels processed by the last ``load_fields`` call
        (``None`` until then).
    k : numpy.ndarray or None
        Spectral bin indices set by ``load_fields`` (``None`` until then).

    Notes
    -----
    This class does **not** perform any plotting. Plotting is handled by
    the driver modules:

    - ``spectra_analysis_tiles.py`` (ANA–INC)
    - ``spectra_analysis_bkg_inc.py`` (BKG–INC)

    The attributes ``Lon``, ``Lat``, ``CTRL_3D``, ``EXP_3D``,
    ``spec_ctrl_all``, ``spec_exp_all`` and ``spec_ratio_all`` only exist
    after ``load_fields`` has been called.
    """

    def __init__(self, varname="T_inc", grid_prefix="./C96_grid.tile", suffix=".nc"):
        self.varname = varname
        self.grid_prefix = grid_prefix
        self.suffix = suffix
        self.nlevels = None
        self.k = None

    # ----------------------------------------------------------------------
    def load_tile(self, prefix, tile, level):
        """Load a single FV3 cubed-sphere tile for a given vertical level.

        Parameters
        ----------
        prefix : str
            File prefix for the tile, e.g. ``"/path/to/atminc.tile"``.
        tile : int
            Tile number (1–6).
        level : int
            Vertical level index.

        Returns
        -------
        numpy.ndarray
            Array of the requested variable at ``Time=0`` and the given
            ``zaxis_1`` index.

        Notes
        -----
        Assumes FV3-JEDI tile files follow the naming pattern::

            <prefix><tile><suffix>
        """
        path = f"{prefix}{tile}{self.suffix}"
        # ``.values`` materialises the data in memory, so the file handle can
        # be released as soon as the slice has been read.
        with xr.open_dataset(path, engine="netcdf4") as ds:
            return ds[self.varname].isel(Time=0, zaxis_1=level).values

    # ----------------------------------------------------------------------
    @staticmethod
    def _block_mean(arr):
        """Average every stride-2 group of four neighbouring entries of ``arr``."""
        return 0.25 * (
            arr[0:-1:2, 0:-1:2]
            + arr[1::2, 0:-1:2]
            + arr[0:-1:2, 1::2]
            + arr[1::2, 1::2]
        )

    def load_grid(self, tile):
        """Load FV3 grid geometry and compute cell-center coordinates.

        Parameters
        ----------
        tile : int
            Tile number (1–6).

        Returns
        -------
        lon_c : numpy.ndarray
            2D array of cell-center longitudes, wrapped to ``[-180, 180)``.
        lat_c : numpy.ndarray
            2D array of cell-center latitudes.

        Notes
        -----
        The ``x`` and ``y`` variables of the grid file are reduced by a
        4-point average over stride-2 neighbours, which halves the
        resolution of the stored coordinate arrays.
        """
        path = f"{self.grid_prefix}{tile}{self.suffix}"
        with xr.open_dataset(path, engine="netcdf4") as grid:
            lon = grid["x"].values
            lat = grid["y"].values

        # NOTE(review): longitudes are averaged *before* being wrapped, so a
        # group of four points straddling the 0°/360° seam would average to a
        # wrong value — confirm the grid files never split a cell there.
        lon_c = self._block_mean(lon)
        lat_c = self._block_mean(lat)
        lon_c = (lon_c + 180) % 360 - 180
        return lon_c, lat_c

    # ----------------------------------------------------------------------
    def _tile_centers(self):
        """Return flattened cell-center lon/lat arrays for all six tiles."""
        lons, lats = [], []
        for tile in range(1, 7):
            lon_c, lat_c = self.load_grid(tile)
            lons.append(lon_c.ravel())
            lats.append(lat_c.ravel())
        return np.concatenate(lons), np.concatenate(lats)

    def _regrid_level(self, prefix, level, lon_global, lat_global):
        """Load one level from all six tiles and regrid it to the lat/lon mesh."""
        field_global = np.concatenate(
            [self.load_tile(prefix, tile, level).ravel() for tile in range(1, 7)]
        )

        # 360 longitudes at exactly 1° spacing. ``endpoint=False`` keeps the
        # ±180° meridian from appearing twice, so the mesh is periodic in
        # longitude as the FFT in ``isotropic_spectrum`` assumes.
        lon_new = np.linspace(-180, 180, 360, endpoint=False)
        lat_new = np.linspace(-90, 90, 181)
        Lon, Lat = np.meshgrid(lon_new, lat_new)

        field_interp = griddata(
            points=(lon_global, lat_global),
            values=field_global,
            xi=(Lon, Lat),
            method="nearest",
        )
        return Lon, Lat, field_interp

    def build_global(self, prefix, level):
        """Regrid all six FV3 tiles to a global lat/lon mesh for one level.

        Parameters
        ----------
        prefix : str
            Tile file prefix, e.g. ``"/path/to/atminc.tile"``.
        level : int
            Vertical level index.

        Returns
        -------
        Lon : numpy.ndarray
            2D longitude meshgrid, shape ``(181, 360)``, spanning
            ``-180 … 179`` in 1° steps.
        Lat : numpy.ndarray
            2D latitude meshgrid, shape ``(181, 360)``, spanning
            ``-90 … 90`` in 1° steps.
        field_interp : numpy.ndarray
            2D field interpolated onto the global mesh.

        Notes
        -----
        - Tile fields are flattened and concatenated.
        - ``scipy.interpolate.griddata`` performs nearest-neighbor
          interpolation onto a regular 1°×1° grid.
        """
        lon_global, lat_global = self._tile_centers()
        return self._regrid_level(prefix, level, lon_global, lat_global)

    # ----------------------------------------------------------------------
    def build_global_all_levels(self, prefix, nlevels):
        """Regrid all vertical levels for a given experiment.

        Parameters
        ----------
        prefix : str
            Tile file prefix.
        nlevels : int
            Number of vertical levels.

        Returns
        -------
        Lon : numpy.ndarray or None
            2D longitude meshgrid (``None`` when ``nlevels`` is 0).
        Lat : numpy.ndarray or None
            2D latitude meshgrid (``None`` when ``nlevels`` is 0).
        fields : numpy.ndarray
            3D array of shape ``(nlevels, ny, nx)`` containing all levels.
        """
        Lon, Lat = None, None
        fields = []
        geometry = None
        for lev in range(nlevels):
            if geometry is None:
                # The grid geometry does not depend on the level: read the
                # six grid files once instead of once per level.
                geometry = self._tile_centers()
            Lon, Lat, field = self._regrid_level(prefix, lev, *geometry)
            fields.append(field)
        return Lon, Lat, np.array(fields)

    # ----------------------------------------------------------------------
    def isotropic_spectrum(self, field):
        """Compute the 1D isotropic wavenumber spectrum of a 2D field.

        Parameters
        ----------
        field : numpy.ndarray
            2D field on a regular lat/lon grid.

        Returns
        -------
        numpy.ndarray
            1D isotropic spectrum ``E(k)`` with
            ``int(sqrt((nx/2)**2 + (ny/2)**2))`` bins.

        Raises
        ------
        ValueError
            If ``field`` is not two-dimensional.

        Notes
        -----
        - Uses 2D FFT with ``fftshift``.
        - Each Fourier coefficient is assigned to the radial bin
          ``int(K * kmax)``, where ``K`` is the magnitude of its
          ``fftfreq`` frequency vector.
        - Returns the power averaged over each bin; empty bins are 0.
        """
        field = np.asarray(field)
        if field.ndim != 2:
            raise ValueError(
                f"isotropic_spectrum expects a 2D field, got {field.ndim}D"
            )
        ny, nx = field.shape

        power = np.abs(np.fft.fftshift(np.fft.fft2(field))) ** 2

        ky = np.fft.fftshift(np.fft.fftfreq(ny))
        kx = np.fft.fftshift(np.fft.fftfreq(nx))
        KX, KY = np.meshgrid(kx, ky)
        K = np.sqrt(KX**2 + KY**2)

        kmax = int(np.sqrt((nx / 2) ** 2 + (ny / 2) ** 2))

        # K is non-negative, so truncation by ``astype(int)`` matches ``int()``.
        bins = (K * kmax).astype(int)
        inside = bins < kmax
        spectrum = np.bincount(bins[inside], weights=power[inside], minlength=kmax)
        counts = np.bincount(bins[inside], minlength=kmax)

        return spectrum / np.maximum(counts, 1)

    # ----------------------------------------------------------------------
    def load_fields(self, ctrl_prefix, exp_prefix, nlevels=127):
        """Load CTRL and EXP fields for all vertical levels and compute spectra.

        Parameters
        ----------
        ctrl_prefix : str
            File prefix for CTRL increment tiles.
        exp_prefix : str
            File prefix for EXP increment tiles.
        nlevels : int, optional
            Number of vertical levels to load (default 127).

        Returns
        -------
        None
            Results are stored as attributes:

            - ``Lon``, ``Lat`` : 2D meshgrids of the regridded fields
            - ``CTRL_3D`` : 3D array of CTRL fields
            - ``EXP_3D`` : 3D array of EXP fields
            - ``spec_ctrl_all`` : spectra for all CTRL levels
            - ``spec_exp_all`` : spectra for all EXP levels
            - ``spec_ratio_all`` : EXP/CTRL spectral ratio
            - ``nlevels`` : number of levels loaded
            - ``k`` : array of spectral bin indices
        """
        self.Lon, self.Lat, self.CTRL_3D = self.build_global_all_levels(ctrl_prefix, nlevels)
        _, _, self.EXP_3D = self.build_global_all_levels(exp_prefix, nlevels)

        self.nlevels = nlevels

        self.spec_ctrl_all = np.array(
            [self.isotropic_spectrum(self.CTRL_3D[lev]) for lev in range(nlevels)]
        )
        self.spec_exp_all = np.array(
            [self.isotropic_spectrum(self.EXP_3D[lev]) for lev in range(nlevels)]
        )

        # Floor the denominator so empty CTRL bins do not divide by zero.
        self.spec_ratio_all = self.spec_exp_all / np.maximum(self.spec_ctrl_all, 1e-12)
        self.k = np.arange(self.spec_ctrl_all.shape[1])

    # ----------------------------------------------------------------------
    def variance_profile(self):
        """Compute vertical variance ratio (EXP / CTRL).

        Returns
        -------
        numpy.ndarray
            1D array of variance ratios for each vertical level.

        Raises
        ------
        AttributeError
            If ``load_fields`` has not been called yet.

        Notes
        -----
        Variance is computed as the sum of the isotropic spectrum across
        all wavenumbers. The CTRL sum is floored at ``1e-12`` before the
        division.
        """
        var_ctrl = self.spec_ctrl_all.sum(axis=1)
        var_exp = self.spec_exp_all.sum(axis=1)
        return var_exp / np.maximum(var_ctrl, 1e-12)

    # ----------------------------------------------------------------------
    def spectral_ratio_2d(self):
        """Return the full 2D spectral ratio array (EXP / CTRL).

        Returns
        -------
        numpy.ndarray
            2D array of shape ``(nlevels, nk)`` containing spectral ratios.

        Raises
        ------
        AttributeError
            If ``load_fields`` has not been called yet.
        """
        return self.spec_ratio_all