# Source code for ufs_da_diagnostics.plots.obs_diag_plotter

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from netCDF4 import Dataset
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from .utils_loaders import (
    load_obsvalue,
    load_omb,
    load_oma_explicit,
    load_qc_universal,
)

from .atms_stats import plot_stats_atms
from .atms_stats_extended import plot_stats_atms_extended
from .atms_scan_position import plot_scan_position_atms
from .atms_latbins import plot_latbins_atms


# ============================================================
# Unified Histogram
# ============================================================

def unified_histogram(omb, oma, qc, title_label, outpath, qc_label="QC", nbins=None):
    """Generic histogram for scalar/vector/ATMS channels.

    Draws a density histogram of OMB together with kernel-density curves
    for OMB (grey) and OMA (red), then saves the figure to ``outpath``.
    Only entries whose QC flag equals 0 and whose value is finite are used.

    Parameters
    ----------
    omb : array-like or None
        OMB values; flattened before use. ``None`` skips the plot.
    oma : array-like or None
        OMA values; flattened before use. ``None`` omits the OMA curve.
    qc : array-like
        QC flags; must have as many elements as ``omb`` / ``oma``.
    title_label : str
        Text shown above the plot. If it contains "GNSSRO"
        (case-insensitive) the x-range is narrowed to the central 20 % of
        the automatically chosen range.
    outpath : str
        Output image path. The parent directory is created when the path
        has one.
    qc_label : str, optional
        Name of the QC variable shown in the title.
    nbins : int or None, optional
        Number of histogram bins. ``None`` selects 100 bins when the
        standard deviation of the valid OMB values is below 1, otherwise 80.

    Returns
    -------
    None
        A "[SKIP]" message is printed and nothing is written when there is
        no valid OMB value.
    """
    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create the directory when there is one to create.
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if omb is None:
        print(f"[SKIP] {title_label}: no valid OMB")
        return

    omb = np.ravel(omb)
    qc_pass = np.ravel(qc) == 0

    valid_omb = qc_pass & np.isfinite(omb)
    n_omb = int(np.sum(valid_omb))
    if n_omb == 0:
        print(f"[SKIP] {title_label}: no valid OMB")
        return
    omb_valid = omb[valid_omb]

    # The OMA curve is optional: it needs data and more than one valid point.
    oma_valid = None
    if oma is not None:
        oma = np.ravel(oma)
        valid_oma = qc_pass & np.isfinite(oma)
        if np.sum(valid_oma) > 1:
            oma_valid = oma[valid_oma]

    if nbins is None:
        std = np.nanstd(omb_valid)
        nbins = 100 if (np.isfinite(std) and std < 1) else 80

    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    try:
        ax.hist(omb_valid, bins=nbins, color="lightgrey", alpha=0.7, density=True)

        legend_handles = [plt.Line2D([0], [0], color="dimgray", lw=2, label="OMB")]
        if n_omb > 1:
            sns.kdeplot(omb_valid, color="dimgray", linewidth=2, ax=ax)
        if oma_valid is not None:
            sns.kdeplot(oma_valid, color="red", linewidth=2, ax=ax)
            # Only advertise OMA in the legend when its curve was drawn.
            legend_handles.append(
                plt.Line2D([0], [0], color="red", lw=2, label="OMA")
            )

        ax.set_xlabel("Value")
        ax.set_ylabel("Density")

        fig.text(0.12, 0.93, f"{title_label} ({qc_label}==0)", ha="left", fontsize=12)
        fig.text(0.18, 0.87, f"N assimilated = {n_omb}", ha="left", fontsize=9)

        ax.legend(
            handles=legend_handles,
            loc="upper right",
            fontsize=8,
            frameon=False,
        )

        # GNSSRO histograms need tighter x-range
        if "GNSSRO" in title_label.upper():
            xmin, xmax = ax.get_xlim()
            center = 0.5 * (xmin + xmax)
            half = 0.5 * (xmax - xmin)
            ax.set_xlim(center - 0.2 * half, center + 0.2 * half)

        fig.savefig(outpath, dpi=150)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

    print(f"[SAVED] {outpath}")
# ============================================================
# Main Orchestrator
# ============================================================
class ObsDiagPlotter:
    """Produce observation-space diagnostic plots from a configuration mapping.

    Configuration keys read by this class:

    * ``prefix_root`` (optional): directory joined in front of each
      observation's ``file`` entry; the result is stored under ``diag``.
    * ``output_dir`` / ``outdir`` (optional): default output directory,
      falling back to ``./plot-outputs-obs``.
    * ``observations``: list of mappings with ``label``, ``type``
      (``"atms"``, ``"scalar"`` or ``"vector"``), ``variable`` (not needed
      for ``"vector"``), ``file`` or ``diag``, optional ``output_dir`` /
      ``outdir`` and a ``diagnostics`` mapping of boolean switches
      (``hist``, ``stats``, ``extended``, ``scanpos``, ``latbins``,
      ``scatter``, ``scatter_map``) plus ``scatter_map_channels``.

    Throughout, an observation is used when its QC flag equals 0; the plot
    titles label these observations "assimilated".
    """

    # Variable names used for "vector" (wind) observations.
    _U_NAME = "windEastward"
    _V_NAME = "windNorthward"

    _KNOWN_TYPES = ("atms", "scalar", "vector")

    # Magnitudes above this are treated as fill values, not data.
    _FILL_THRESHOLD = 1e10

    def __init__(self, config):
        """Store ``config`` and resolve ``file`` entries against ``prefix_root``.

        The observation mappings inside ``config`` are updated in place.
        """
        self.config = config

        # Allow prefix_root in YAML
        prefix_root = self.config.get("prefix_root", None)
        if prefix_root is not None:
            # "observations" may be present but empty (None) in the YAML.
            for obs in self.config.get("observations") or []:
                if "file" in obs:
                    obs["diag"] = os.path.join(prefix_root, obs["file"])

    def run(self):
        """Open every configured diagnostics file and draw the enabled plots.

        Raises
        ------
        KeyError
            If an observation lacks ``label`` or ``type``, or an ``atms`` /
            ``scalar`` observation lacks ``variable``.
        """
        obs_list = self.config.get("observations") or []

        global_outdir = (
            self.config.get("output_dir")
            or self.config.get("outdir")
            or "./plot-outputs-obs"
        )

        for obs_cfg in obs_list:
            label = obs_cfg["label"]
            otype = obs_cfg["type"]
            diag = obs_cfg.get("diag", obs_cfg.get("file"))
            outdir = obs_cfg.get("output_dir") or obs_cfg.get("outdir") or global_outdir
            diags_cfg = obs_cfg.get("diagnostics") or {}

            if otype not in self._KNOWN_TYPES:
                # Previously an unknown type opened the file and did nothing.
                print(f"[WARN] {label}: unknown observation type '{otype}', skipping")
                continue

            # Vector observations use fixed wind variable names, so only the
            # other types require "variable".
            var = obs_cfg.get("variable")
            if otype != "vector" and var is None:
                raise KeyError("variable")

            print(f"[INFO] Processing {label} ({otype}) from {diag}")

            with Dataset(diag, "r") as f:

                # ---------------- ATMS ----------------
                if otype == "atms":
                    if diags_cfg.get("hist", False):
                        self._plot_atms_histograms(f, var, label, outdir)
                    if diags_cfg.get("stats", False):
                        plot_stats_atms(f, var, label, outdir)
                    if diags_cfg.get("extended", False):
                        plot_stats_atms_extended(f, var, label, outdir)
                    if diags_cfg.get("scanpos", False):
                        plot_scan_position_atms(f, var, label, outdir)
                    if diags_cfg.get("latbins", False):
                        plot_latbins_atms(f, var, label, outdir)
                    if diags_cfg.get("scatter", False):
                        self._plot_scatter(f, var, label, outdir)
                    if diags_cfg.get("scatter_map", False):
                        selected_chans = diags_cfg.get("scatter_map_channels", None)
                        self._plot_scatter_map(f, var, label, outdir, selected_chans)

                # ---------------- Scalar ----------------
                elif otype == "scalar":
                    if diags_cfg.get("hist", False):
                        self._plot_scalar_hist(f, var, label, outdir)
                    if diags_cfg.get("scatter", False):
                        self._plot_scatter(f, var, label, outdir)
                    if diags_cfg.get("scatter_map", False):
                        self._plot_scatter_map(f, var, label, outdir)

                # ---------------- Vector ----------------
                else:
                    if diags_cfg.get("hist", False):
                        self._plot_vector_hist(f, label, outdir)
                    if diags_cfg.get("scatter", False):
                        self._plot_vector_scatter(f, label, outdir)
                    if diags_cfg.get("scatter_map", False):
                        self._plot_scatter_map_vector(f, label, outdir)

        print("[INFO] Diagnostics complete.")

    # ============================================================
    # Shared helpers
    # ============================================================

    @staticmethod
    def _as_mask(condition):
        """Return ``condition`` as a plain boolean array; masked entries become False."""
        return np.asarray(np.ma.filled(condition, False), dtype=bool)

    @staticmethod
    def _expand_qc(qc, nchans):
        """Repeat a 1-D per-location QC array across ``nchans`` channels.

        Arrays that are not 1-D are returned unchanged.
        """
        if np.ndim(qc) == 1:
            return np.repeat(qc[:, None], nchans, axis=1)
        return qc

    @staticmethod
    def _clean_values(values):
        """Return ``values`` as a float array with fill values replaced by NaN.

        Masked entries and entries whose magnitude exceeds
        ``_FILL_THRESHOLD`` (positive or negative) become NaN.
        """
        arr = np.ma.filled(np.ma.asarray(values).astype(float), np.nan)
        with np.errstate(invalid="ignore"):
            is_fill = np.abs(arr) > ObsDiagPlotter._FILL_THRESHOLD
        return np.where(is_fill, np.nan, arr)

    @staticmethod
    def _both_components_pass(qc_u, qc_v):
        """Return a boolean mask that is True where BOTH wind components have QC 0."""
        return ObsDiagPlotter._as_mask((qc_u == 0) & (qc_v == 0))

    @staticmethod
    def _load_latlon(f):
        """Read 1-D latitude/longitude from the ``MetaData`` group of ``f``.

        Returns ``(lat, lon)`` as float arrays, or ``None`` when the group or
        either variable is missing. For 2-D coordinates the first column is
        used.
        """
        if "MetaData" not in f.groups:
            return None
        meta = f.groups["MetaData"]
        if "latitude" not in meta.variables or "longitude" not in meta.variables:
            return None

        lat = ObsDiagPlotter._clean_values(meta["latitude"][:])
        lon = ObsDiagPlotter._clean_values(meta["longitude"][:])
        if lat.ndim == 2:
            lat = lat[:, 0]
        if lon.ndim == 2:
            lon = lon[:, 0]
        return lat, lon

    def _draw_omb_map(self, lon, lat, values, title, outdir, filename, *,
                      size_scale, size_bounds, alpha, percentiles,
                      cbar_label="OMB", sci_colorbar=False):
        """Draw one global scatter map and save it under ``outdir/scatter_maps``.

        ``lon``, ``lat`` and ``values`` are equally long, already-filtered 1-D
        arrays with at least one element. The dot size is
        ``size_scale / N`` clipped to ``size_bounds``; the colour range spans
        the given ``percentiles`` of ``values``.
        """
        count = len(lat)
        smallest, largest = size_bounds
        dot_size = min(largest, max(smallest, size_scale / count))
        vmin, vmax = np.percentile(values, percentiles)

        # Output directory
        map_dir = os.path.join(outdir, "scatter_maps")
        os.makedirs(map_dir, exist_ok=True)
        outfile = os.path.join(map_dir, filename)

        fig = plt.figure(figsize=(12, 6))
        try:
            ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
            ax.set_global()
            ax.coastlines()
            ax.add_feature(cfeature.BORDERS, linewidth=0.3)

            gl = ax.gridlines(draw_labels=True, linewidth=0.3)
            gl.top_labels = False
            gl.right_labels = False

            sc = ax.scatter(
                lon, lat,
                c=values,
                s=dot_size,
                alpha=alpha,
                cmap="coolwarm",
                linewidths=0,
                vmin=vmin,
                vmax=vmax,
                transform=ccrs.PlateCarree(),
            )

            # Bottom colorbar
            cbar = fig.colorbar(
                sc, ax=ax, orientation="horizontal", pad=0.06, shrink=0.4
            )
            cbar.set_label(cbar_label)

            if sci_colorbar:
                from matplotlib.ticker import ScalarFormatter
                fmt = ScalarFormatter(useMathText=True)
                fmt.set_powerlimits((-3, 3))
                cbar.ax.xaxis.set_major_formatter(fmt)

            ax.set_title(title)
            fig.savefig(outfile, dpi=150, bbox_inches="tight")
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)

        print(f"[SAVED] {outfile}")

    # ============================================================
    # ATMS Histograms
    # ============================================================
    def _plot_atms_histograms(self, f, varname, label, outdir):
        """ATMS channel-by-channel histograms."""
        obs = load_obsvalue(f, varname)
        omb = load_omb(f, varname)
        oma = load_oma_explicit(f, varname)
        qc = load_qc_universal(f, varname)

        if obs is None or omb is None or oma is None:
            return

        nchans = obs.shape[1]
        qc = self._expand_qc(qc, nchans)

        for ch in range(nchans):
            unified_histogram(
                omb[:, ch],
                oma[:, ch],
                qc[:, ch],
                f"{label} Ch {ch+1:02d} Histogram",
                os.path.join(outdir, f"{label.lower()}_ch{ch+1:02d}_hist.png"),
                qc_label="QC2",
            )

    # ============================================================
    # Scalar Histograms
    # ============================================================
    def _plot_scalar_hist(self, f, varname, label, outdir):
        """Histogram for a scalar variable.

        When OMB or OMA is unavailable the observed values are plotted in its
        place; nothing is drawn when both the observations and OMB are
        missing.
        """
        obs = load_obsvalue(f, varname)
        omb = load_omb(f, varname)
        oma = load_oma_explicit(f, varname)
        qc = load_qc_universal(f, varname)

        if obs is None and omb is None:
            return

        if omb is None:
            omb = obs
        if oma is None:
            oma = obs

        unified_histogram(
            omb,
            oma,
            qc,
            f"{label} Histogram",
            os.path.join(outdir, f"{label.lower()}_hist.png"),
            qc_label="QC",
        )

    # ============================================================
    # Vector Histograms
    # ============================================================
    def _plot_vector_hist(self, f, label, outdir):
        """Histograms for the eastward and northward wind components.

        A component without OMB data is skipped; a component without OMA
        data is drawn without the OMA curve.
        """
        for comp_name, suffix in ((self._U_NAME, "u"), (self._V_NAME, "v")):
            omb = load_omb(f, comp_name)
            oma = load_oma_explicit(f, comp_name)
            qc = load_qc_universal(f, comp_name)

            if omb is None:
                print(f"[SKIP] {label} {comp_name}: no valid OMB")
                continue
            if oma is None:
                # All-NaN placeholder: no OMA point is valid, so no curve is drawn.
                oma = np.full(np.shape(omb), np.nan)

            unified_histogram(
                omb,
                oma,
                qc,
                f"{label} {comp_name}",
                os.path.join(outdir, f"{label.lower()}_{suffix}_hist.png"),
                qc_label="QC1",
            )

    # ============================================================
    # Scatter (scalar/ATMS)
    # ============================================================
    def _plot_scatter(self, f, varname, label, outdir):
        """Scatter plot of OMB against the observed value for QC==0 data.

        For 2-D (multi-channel) data all channels share one plot.
        """
        obs = load_obsvalue(f, varname)
        omb = load_omb(f, varname)
        qc = load_qc_universal(f, varname)

        if obs is None or omb is None:
            return

        # Same 1-D QC handling as the ATMS histograms.
        if np.ndim(omb) == 2:
            qc = self._expand_qc(qc, omb.shape[1])

        valid = self._as_mask((qc == 0) & np.isfinite(obs) & np.isfinite(omb))
        if not valid.any():
            return

        obs_valid = obs[valid]
        omb_valid = omb[valid]

        scatter_dir = os.path.join(outdir, "scatter_plots")
        os.makedirs(scatter_dir, exist_ok=True)

        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            ax.scatter(obs_valid, omb_valid, s=2, alpha=0.4)
            ax.set_xlabel("ObsValue")
            ax.set_ylabel("OMB")
            ax.set_title(f"{label} (assimilated, count={len(obs_valid)})")
            ax.grid(True)
            fig.savefig(
                os.path.join(scatter_dir, f"{label.lower()}_omb_scatter.png"),
                dpi=150,
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)

    # ============================================================
    # Scatter Map (scalar / ATMS per-channel)
    # ============================================================
    def _plot_scatter_map(self, f, varname, label, outdir, selected_chans=None):
        """Global scatter map for scalar and ATMS observations.

        2-D OMB data is treated as ATMS and produces one map per channel;
        ``selected_chans`` (1-based channel numbers) restricts which channels
        are drawn. Other data produces a single map.
        """
        obs = load_obsvalue(f, varname)
        omb = load_omb(f, varname)
        qc = load_qc_universal(f, varname)

        if obs is None or omb is None:
            return

        # Mask fill values
        omb = self._clean_values(omb)

        is_atms = omb.ndim == 2
        if is_atms:
            # Debug printout
            print(f"[DEBUG] Selected ATMS channels from YAML: {selected_chans}")

        # Coordinates do not depend on the channel, so read them once.
        latlon = self._load_latlon(f)
        if latlon is None:
            return
        lat, lon = latlon
        located = np.isfinite(lat) & np.isfinite(lon)

        # ---------------- ATMS: per-channel maps ----------------
        if is_atms:
            nchans = omb.shape[1]

            if selected_chans is not None:
                chan_list = []
                for chan in selected_chans:  # YAML is 1-based
                    if 1 <= chan <= nchans:
                        chan_list.append(chan - 1)
                    else:
                        # chan - 1 would otherwise wrap around or overflow.
                        print(f"[WARN] {label}: channel {chan} outside 1..{nchans}, skipping")
            else:
                chan_list = list(range(nchans))

            qc = self._expand_qc(qc, nchans)

            for ch in chan_list:
                omb_ch = omb[:, ch]
                valid = self._as_mask((qc[:, ch] == 0) & np.isfinite(omb_ch)) & located
                n_valid = int(np.sum(valid))
                if n_valid == 0:
                    continue

                self._draw_omb_map(
                    lon[valid], lat[valid], omb_ch[valid],
                    f"{label} Ch {ch+1:02d} (assimilated, count={n_valid})",
                    outdir,
                    f"{label.lower()}_ch{ch+1:02d}_omb_map.png",
                    size_scale=30000,
                    size_bounds=(1.5, 6.0),
                    alpha=0.7,
                    percentiles=[2, 98],
                )
            return

        # ---------------- Scalar: single map ----------------
        valid = self._as_mask((qc == 0) & np.isfinite(omb)) & located
        n_valid = int(np.sum(valid))
        if n_valid == 0:
            return

        # GNSSRO gets smaller, more transparent dots, a tighter colour range
        # and a scientific-notation colorbar.
        is_gnssro = label.upper() == "GNSSRO"
        self._draw_omb_map(
            lon[valid], lat[valid], omb[valid],
            f"{label} (assimilated, count={n_valid})",
            outdir,
            f"{label.lower()}_omb_map.png",
            size_scale=80000 if is_gnssro else 30000,
            size_bounds=(1.5, 6.0),
            alpha=0.30 if is_gnssro else 0.7,
            percentiles=[10, 90] if is_gnssro else [2, 98],
            sci_colorbar=is_gnssro,
        )

    # ============================================================
    # Scatter Map (vector winds)
    # ============================================================
    def _plot_scatter_map_vector(self, f, label, outdir):
        """Global scatter map of the magnitude of the (u, v) OMB vector.

        A location is drawn only when both wind components have QC 0.
        """
        omb_u = load_omb(f, self._U_NAME)
        omb_v = load_omb(f, self._V_NAME)
        qc_u = load_qc_universal(f, self._U_NAME)
        qc_v = load_qc_universal(f, self._V_NAME)

        if omb_u is None or omb_v is None:
            return

        # Mask fill values
        omb_u = self._clean_values(omb_u)
        omb_v = self._clean_values(omb_v)

        # Wind-speed OMB
        omb_speed = np.sqrt(omb_u**2 + omb_v**2)

        latlon = self._load_latlon(f)
        if latlon is None:
            return
        lat, lon = latlon

        # np.minimum(qc_u, qc_v) == 0 would accept a location where only one
        # component passed QC; both must pass.
        valid = (
            self._both_components_pass(qc_u, qc_v)
            & np.isfinite(omb_speed)
            & np.isfinite(lat)
            & np.isfinite(lon)
        )
        n_valid = int(np.sum(valid))
        if n_valid == 0:
            return

        self._draw_omb_map(
            lon[valid], lat[valid], omb_speed[valid],
            f"{label} (assimilated, count={n_valid})",
            outdir,
            f"{label.lower()}_omb_map.png",
            size_scale=25000,
            size_bounds=(1.2, 5.0),
            alpha=0.55,
            percentiles=[5, 95],
            cbar_label="OMB wind speed",
        )