"""
Extended ATMS OMB/OMA diagnostics.
This module computes and visualizes channel-by-channel statistics for ATMS
brightness temperature departures, including:
- Mean OMB / OMA
- RMS OMB / OMA
- Bias-corrected RMS (BC-RMS)
- Normalized RMS (RMS_n), using EffectiveError2 as σ_o
- RMS_n^2 debug output for chi-square consistency checks
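All computations use QC2==0 and EffectiveError2, matching the values used
in the JEDI cost function; when EffectiveError2 is absent from the file,
σ_o falls back to EffectiveError1 and then to EffectiveError0. RMS_n^2 is
mathematically equivalent to Jo/p for the same QC mask and σ_o.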
"""
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from .utils_loaders import load_omb, load_oma_explicit, load_qc_universal
# ---------------------------------------------------------------------
# ATMS channel groups (channel ranges are 1-based and inclusive)
# ---------------------------------------------------------------------
def channel_groups():
    """
    Return the ATMS channel groups used to shade the diagnostic panels.

    Returns
    -------
    list of tuple
        One ``(name, first_channel, last_channel, color)`` entry per group.
        Channel numbers are 1-based and both ends are inclusive; ``color``
        is a matplotlib color name. A new list is built on every call, so
        callers may modify the result freely.
    """
    return [
        ("Window", 1, 2, "lightgrey"),
        ("O₂ Temp", 3, 15, "lightblue"),
        ("Window", 16, 17, "lightgrey"),
        ("H₂O", 18, 22, "lightgreen"),
    ]
# ---------------------------------------------------------------------
# Extended ATMS Stats
# ---------------------------------------------------------------------
def plot_stats_atms_extended(f, varname, label, outdir):
    """
    Generate extended ATMS OMB/OMA diagnostics.

    For every channel this computes the mean, the sample standard deviation
    (ddof=1, also reported as BC-RMS), the RMS and the σ_o-normalized RMS of
    the OMB and OMA departures, prints the per-channel and all-channel
    RMS_n^2, and saves a 2x2 summary figure to
    ``<outdir>/<label.lower()>_stats_extended.png``.

    Only observations with QC == 0, finite OMB and OMA, and a finite,
    strictly positive σ_o contribute to the statistics.

    Parameters
    ----------
    f
        Open diagnostics file exposing ``.groups`` and ``f["Group/var"]``
        lookups (presumably a ``netCDF4.Dataset`` -- confirm with caller).
    varname
        Variable name forwarded unchanged to the ``load_*`` helpers.
    label
        Label used in log messages, in the figure title and (lower-cased)
        in the output file name.
    outdir
        Output directory; created if it does not exist.

    Returns
    -------
    None
        The function returns early, without plotting, when OMB, OMA or QC
        could not be loaded.
    """
    os.makedirs(outdir, exist_ok=True)

    omb = load_omb(f, varname)
    oma = load_oma_explicit(f, varname)
    qc = load_qc_universal(f, varname)

    # Bail out before touching the error groups so that a file without
    # departures is skipped instead of failing on the σ_o lookup.
    if omb is None or oma is None:
        print(f"[SKIP] {label} ATMS extended stats: missing OMB/OMA")
        return
    if qc is None:
        print(f"[SKIP] {label} ATMS extended stats: missing QC")
        return

    # Load observation error (σ_o): prefer EffectiveError2, then
    # EffectiveError1, and finally EffectiveError0.
    for err_group in ("EffectiveError2", "EffectiveError1"):
        if err_group in f.groups:
            break
    else:
        err_group = "EffectiveError0"
    raw_sigma = f[f"{err_group}/brightnessTemperature"][:]
    # NOTE(review): the read presumably returns a numpy masked array (the
    # netCDF4 default). np.asarray() would drop the mask and expose the
    # fill values as finite, positive σ_o, so masked entries are turned
    # into NaN and rejected by the isfinite() test below.
    Rstd = np.ma.filled(np.ma.asarray(raw_sigma, dtype="float64"), np.nan)

    nchan = omb.shape[1]
    chans = np.arange(1, nchan + 1)

    # A per-location QC flag applies to every channel of that location.
    if qc.ndim == 1:
        qc = np.repeat(qc[:, None], nchan, axis=1)

    # Per-channel results; NaN marks channels without any usable data.
    mean_omb = np.full(nchan, np.nan)
    mean_oma = np.full(nchan, np.nan)
    std_omb = np.full(nchan, np.nan)
    std_oma = np.full(nchan, np.nan)
    rms_omb = np.full(nchan, np.nan)
    rms_oma = np.full(nchan, np.nan)
    rms_diff = np.full(nchan, np.nan)
    nrms_omb = np.full(nchan, np.nan)
    nrms_oma = np.full(nchan, np.nan)
    bc_rms_omb = np.full(nchan, np.nan)
    bc_rms_oma = np.full(nchan, np.nan)

    # Accumulators for the all-channel RMS_n^2.
    total_omb_num = 0.0
    total_oma_num = 0.0
    total_nobs = 0

    # -----------------------------------------------------------------
    # Compute stats per channel (single pass, totals accumulated here)
    # -----------------------------------------------------------------
    for ch in range(nchan):
        valid = (
            (qc[:, ch] == 0)
            & np.isfinite(omb[:, ch])
            & np.isfinite(oma[:, ch])
            & np.isfinite(Rstd[:, ch])
            & (Rstd[:, ch] > 0)
        )
        if not np.any(valid):
            continue
        nobs = int(np.sum(valid))

        o = omb[valid, ch].astype("float64")
        a = oma[valid, ch].astype("float64")
        sigma = Rstd[valid, ch]

        # Means
        mean_omb[ch] = np.mean(o)
        mean_oma[ch] = np.mean(a)

        # BC-RMS: the sample standard deviation needs at least two values;
        # with a single value it stays NaN (without a RuntimeWarning).
        if nobs > 1:
            std_omb[ch] = np.nanstd(o, ddof=1)
            std_oma[ch] = np.nanstd(a, ddof=1)
        bc_rms_omb[ch] = std_omb[ch]
        bc_rms_oma[ch] = std_oma[ch]

        # RMS
        rms_omb[ch] = np.sqrt(np.mean(o**2))
        rms_oma[ch] = np.sqrt(np.mean(a**2))
        rms_diff[ch] = rms_oma[ch] - rms_omb[ch]

        # Normalized RMS
        z2_omb = (o / sigma) ** 2
        z2_oma = (a / sigma) ** 2
        nrms_omb[ch] = np.sqrt(np.mean(z2_omb))
        nrms_oma[ch] = np.sqrt(np.mean(z2_oma))

        total_omb_num += np.sum(z2_omb)
        total_oma_num += np.sum(z2_oma)
        total_nobs += nobs

        # Debug RMS_n^2
        rmsn2 = nrms_omb[ch] ** 2
        print(f"[DEBUG] Ch {ch+1:02d} RMS_n^2 = {rmsn2:.4f}")

    # -----------------------------------------------------------------
    # Total RMS_n^2 (OMB and OMA)
    # -----------------------------------------------------------------
    if total_nobs > 0:
        total_rmsn2_omb = total_omb_num / total_nobs
        total_rmsn2_oma = total_oma_num / total_nobs
        print("--------------------------------------------------------")
        print(f"[INFO] ATMS total obs used (QC2==0) = {total_nobs:,d}")
        print(f"[INFO] ATMS RMS_n^2 (OMB) = {total_rmsn2_omb:.6f}")
        print(f"[INFO] ATMS RMS_n^2 (OMA) = {total_rmsn2_oma:.6f}")
        print("--------------------------------------------------------")

    # -----------------------------------------------------------------
    # Plotting
    # -----------------------------------------------------------------
    fig, axes = plt.subplots(2, 2, figsize=(10, 7), constrained_layout=True)
    ax_meanstd = axes[0, 0]
    ax_rms = axes[0, 1]
    ax_rmsdiff = axes[1, 0]
    ax_norm = axes[1, 1]

    # Shading helper: one translucent band per channel group.
    def shade(ax):
        for _name, c1, c2, color in channel_groups():
            ax.axvspan(c1 - 0.5, c2 + 0.5, color=color, alpha=0.25, zorder=0)

    # Panel 1: Mean & Std (std on a secondary y-axis)
    shade(ax_meanstd)
    ax_meanstd.plot(chans, mean_omb, "o-", color="blue", label="Mean OMB")
    ax_meanstd.plot(chans, mean_oma, "o-", color="red", label="Mean OMA")
    ax_meanstd.set_ylabel("Mean")
    ax_std = ax_meanstd.twinx()
    ax_std.plot(chans, std_omb, "s--", color="orange", label="Std OMB")
    ax_std.plot(chans, std_oma, "s--", color="purple", label="Std OMA")
    ax_std.set_ylabel("Std")
    ax_meanstd.set_title("Mean & Std (OMB / OMA)")
    ax_meanstd.set_xlabel("Channel")
    lines1, labels1 = ax_meanstd.get_legend_handles_labels()
    lines2, labels2 = ax_std.get_legend_handles_labels()
    leg1 = ax_meanstd.legend(lines1 + lines2, labels1 + labels2,
                             loc="lower left", fontsize=8, frameon=True)
    # Re-attach the first legend: the next legend() call would replace it.
    ax_meanstd.add_artist(leg1)
    ax_meanstd.legend(
        handles=[
            Patch(facecolor="lightgrey", label="Window (1–2, 16–17)"),
            Patch(facecolor="lightblue", label="O₂ Temp (3–15)"),
            Patch(facecolor="lightgreen", label="H₂O (18–22)"),
        ],
        loc="lower right",
        fontsize=6,
        frameon=True,
    )

    # Panel 2: RMS
    shade(ax_rms)
    ax_rms.plot(chans, rms_omb, "^-", color="black", label="RMS OMB")
    ax_rms.plot(chans, rms_oma, "^-", color="magenta", label="RMS OMA")
    ax_rms.set_title("RMS (OMB / OMA)")
    ax_rms.set_xlabel("Channel")
    ax_rms.set_ylabel("RMS")
    ax_rms.legend(loc="upper left", fontsize=8)

    # Panel 3: RMS Difference
    shade(ax_rmsdiff)
    ax_rmsdiff.plot(chans, rms_diff, "o-", color="green",
                    label="RMS(OMA) – RMS(OMB)")
    ax_rmsdiff.set_title("RMS Difference (OMA – OMB)")
    ax_rmsdiff.set_xlabel("Channel")
    ax_rmsdiff.set_ylabel("Difference")
    ax_rmsdiff.legend(loc="lower left", fontsize=8)

    # Panel 4: Normalized RMS + BC-RMS
    shade(ax_norm)
    ax_norm.plot(chans, nrms_omb, "^-", color="black", label="NRMS OMB")
    ax_norm.plot(chans, nrms_oma, "^-", color="magenta", label="NRMS OMA")
    ax_norm.plot(chans, bc_rms_omb, "s--", color="orange", label="BC-RMS OMB")
    ax_norm.plot(chans, bc_rms_oma, "s--", color="purple", label="BC-RMS OMA")
    ax_norm.set_title("Normalized RMS & Bias-Corrected RMS")
    ax_norm.set_xlabel("Channel")
    ax_norm.set_ylabel("RMS")
    all_vals = np.concatenate([nrms_omb, nrms_oma, bc_rms_omb, bc_rms_oma])
    all_vals = all_vals[np.isfinite(all_vals)]
    # With no finite value min()/max() would raise, and with a zero range
    # the limits would be identical; keep matplotlib's autoscaling then.
    if all_vals.size > 0:
        ymin, ymax = np.min(all_vals), np.max(all_vals)
        yr = ymax - ymin
        if yr > 0:
            ax_norm.set_ylim(ymin - 0.1 * yr, ymax + 0.1 * yr)
    ax_norm.legend(loc="upper right", fontsize=8)

    # Title: no explicit y, so constrained layout positions it inside the
    # canvas (y=1.02 placed it partly outside the saved image).
    fig.suptitle(
        f"{label} Extended OMB/OMA Stats (QC2==0)",
        fontsize=13,
    )

    # Save; always release the figure, even when writing fails.
    outpath = os.path.join(outdir, f"{label.lower()}_stats_extended.png")
    try:
        fig.savefig(outpath, dpi=150)
    finally:
        plt.close(fig)
    print(f"[SAVED] {outpath}")