Source code for aeolus.plot.mpl

"""Matplotlib-related utilities."""

from collections.abc import Sequence
from pathlib import Path
from typing import Optional, Union

from iris.cube import Cube
from matplotlib.axes._axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from ..coord import get_cube_rel_days
from ..model import lfric, um
from ..model.base import Model
from ..runtime import RUNTIME

# Public API: the names exported by a star-import of this module.
__all__ = (
    "add_custom_legend",
    "capitalise",
    "figsave",
    "hcross",
    "linspace_pm1",
    "make_list_2d",
    "map_scatter",
    "timeseries_1d",
    "timeseries_2d",
)


def add_custom_legend(
    ax_or_fig: Union[Axes, Figure], styles_and_labels: dict, **leg_kw
) -> None:
    """
    Build a legend from explicit style definitions and attach it.

    One proxy `Line2D` handle is created per entry of `styles_and_labels`,
    so the legend does not depend on what has actually been plotted.

    Parameters
    ----------
    ax_or_fig: matplotlib.axes.Axes or matplotlib.figure.Figure
        Object whose `legend()` method is called.
    styles_and_labels: dict
        Legend labels (keys) mapped to dictionaries of `Line2D` keyword
        arguments (values).
    leg_kw: dict, optional
        Extra keyword arguments forwarded to `legend()`.

    Example
    -------
    >>> import matplotlib.pyplot as plt
    >>> ax = plt.axes()
    >>> my_dict = dict(foo=dict(color='C0', marker="X"), bar=dict(color='C1', marker="o"))
    >>> add_custom_legend(ax, my_dict, loc=2, title="blah")
    """
    handles = []
    for line_kw in styles_and_labels.values():
        # Single-point proxy line: only its style matters for the legend.
        handles.append(Line2D([0], [0], **line_kw))
    labels = styles_and_labels.keys()
    new_legend = ax_or_fig.legend(handles, labels, **leg_kw)

    # Objects exposing a `legend_` attribute also get the new legend registered
    # as a separate artist; anything raising AttributeError here (presumably
    # figures, which lack `legend_`) is left as it is.
    try:
        has_legend_slot = ax_or_fig.legend_ is not None
        if has_legend_slot:
            ax_or_fig.add_artist(new_legend)
    except AttributeError:
        pass
def capitalise(
    s: str, sep_old: Optional[str] = "_", sep_new: Optional[str] = " "
) -> str:
    """
    Capitalise every chunk of a delimited string and re-join the chunks.

    Parameters
    ----------
    s: str
        Input string.
    sep_old: str, optional
        Delimiter used to split `s`; None splits on runs of whitespace.
    sep_new: str, optional
        Delimiter placed between the capitalised chunks.

    Returns
    -------
    str
        Re-joined string. Each chunk goes through `str.capitalize()`, i.e. its
        first character is upper-cased and the remaining ones are lower-cased.

    Example
    -------
    >>> capitalise("sea_surface_temperature")
    'Sea Surface Temperature'
    """
    chunks = s.split(sep_old)
    return sep_new.join(map(str.capitalize, chunks))
def hcross(
    cube: Cube,
    ax: Optional[Axes] = None,
    model: Optional[Model] = um,
    **kw_plt,
) -> Optional[Axes]:
    """
    Plot a horizontal cross-section aka lat-lon map of a 2D cube.

    Parameters
    ----------
    cube: iris.cube.Cube
        Cube to plot. Its data are drawn with `pcolormesh()` against the
        points of its x- and y-coordinates.
    ax: matplotlib.axes.Axes, optional
        Axes to draw in. If omitted, new axes are created with `plt.axes()`.
    model: Model, optional
        Model class supplying the coordinate names (`model.x`, `model.y`).
        None falls back to `um`.
    **kw_plt: dict, optional
        Keyword arguments passed to `pcolormesh()`.

    Returns
    -------
    matplotlib.axes.Axes or None
        The axes created by this function, or None if `ax` was supplied.
    """
    # Look the coordinates up under the names of the requested model instead
    # of always using the `um` ones; None keeps the default naming.
    if model is None:
        model = um

    created_here = ax is None
    if created_here:
        ax = plt.axes()

    lons = cube.coord(model.x).points
    lats = cube.coord(model.y).points
    mappable = ax.pcolormesh(lons, lats, cube.data, **kw_plt)
    ax.figure.colorbar(mappable, ax=ax)

    return ax if created_here else None
def linspace_pm1(n: int) -> np.typing.ArrayLike:
    """
    Return 2n numbers from -1 to 1 that are symmetric about 0 and skip it.

    The positive half is `np.linspace(0, 1, n + 1)` without its first element
    (zero); the negative half is its mirror image.

    Example
    -------
    >>> linspace_pm1(2)
    array([-1. , -0.5,  0.5,  1. ])
    """
    positive = np.linspace(0, 1, n + 1)[1:]
    negative = -positive[::-1]
    return np.concatenate((negative, positive))
def make_list_2d(
    list_x: Sequence[str],
    list_y: Sequence[str],
    transpose: Optional[bool] = False,
) -> list:
    """
    Combine two lists of keys into a nested list of hyphen-joined labels.

    Parameters
    ----------
    list_x: sequence of str
        Keys that vary along the inner lists.
    list_y: sequence of str
        Keys that vary along the outer list (one inner list per key).
    transpose: bool, optional
        If True, each label is written as "y-x" instead of "x-y". The nesting
        of the result is the same in both cases.

    Returns
    -------
    list
        List with one entry per element of `list_y`, each being a list with
        one label per element of `list_x`.

    Example
    -------
    >>> make_list_2d(["a", "b"], ["x", "y"])
    [['a-x', 'b-x'], ['a-y', 'b-y']]
    """
    rows = []
    for key_y in list_y:
        if transpose:
            row = [f"{key_y}-{key_x}" for key_x in list_x]
        else:
            row = [f"{key_x}-{key_y}" for key_x in list_x]
        rows.append(row)
    return rows
def map_scatter(
    cube: Cube,
    ax: Optional[Axes] = None,
    model: Optional[Model] = lfric,
    **kw_plt,
) -> Optional[Axes]:
    """
    Draw a 2D cube as a lat-lon scatter plot coloured by the cube's data.

    Parameters
    ----------
    cube: iris.cube.Cube
        Cube to plot; must have the coordinates named `model.x` and `model.y`.
    ax: matplotlib.axes.Axes, optional
        Axes to draw in. If omitted, new axes are created with `plt.axes()`.
    model: Model, optional
        Model class supplying the coordinate names.
    **kw_plt: dict, optional
        Keyword arguments passed to `scatter()`.

    Returns
    -------
    matplotlib.axes.Axes or None
        The axes created by this function, or None if `ax` was supplied.
    """
    created_here = ax is None
    if created_here:
        ax = plt.axes()
    figure = ax.figure

    # NOTE: the mesh node coordinates (`cube.mesh.node_coords`) are deliberately
    # not used: according to the original author's note they hold N+2 points
    # while the data array holds N, so the cube's own coordinates are read.
    x_points = cube.coord(model.x).points
    y_points = cube.coord(model.y).points

    points = ax.scatter(x_points, y_points, c=cube.data, **kw_plt)
    figure.colorbar(points, ax=ax)

    if created_here:
        return ax
    return None
def timeseries_1d(
    cube: Cube,
    ax: Optional[Axes] = None,
    model: Optional[Model] = um,
    **kw_plt,
) -> Optional[Axes]:
    """
    Plot time series of a 1D cube.

    Parameters
    ----------
    cube: iris.cube.Cube
        Cube to plot. Its data are drawn against the values returned by
        `get_cube_rel_days()` for this cube.
    ax: matplotlib.axes.Axes, optional
        Axes to draw in. If omitted, new axes are created with `plt.axes()`.
    model: Model, optional
        Model class forwarded to `get_cube_rel_days()`. None falls back to
        `um`.
    **kw_plt: dict, optional
        Keyword arguments passed to `plot()`.

    Returns
    -------
    matplotlib.axes.Axes or None
        The axes created by this function, or None if `ax` was supplied.
    """
    # Forward the requested model to the helper instead of always using `um`;
    # None keeps the default naming.
    if model is None:
        model = um

    created_here = ax is None
    if created_here:
        ax = plt.axes()

    days = get_cube_rel_days(cube, model=model)
    ax.plot(days, cube.data, **kw_plt)

    return ax if created_here else None
def timeseries_2d(
    cube: Cube,
    ax: Optional[Axes] = None,
    model: Optional[Model] = um,
    **kw_plt,
) -> Optional[Axes]:
    """
    Plot time series of a 2D cube.

    Parameters
    ----------
    cube: iris.cube.Cube
        Cube to plot. The horizontal axis holds the values returned by
        `get_cube_rel_days()`, the vertical axis the points of the `model.z`
        coordinate.
    ax: matplotlib.axes.Axes, optional
        Axes to draw in. If omitted, new axes are created with `plt.axes()`.
    model: Model, optional
        Model class supplying the vertical coordinate name (`model.z`); it is
        also forwarded to `get_cube_rel_days()`. None falls back to `um`.
    **kw_plt: dict, optional
        Keyword arguments passed to `pcolormesh()`.

    Returns
    -------
    matplotlib.axes.Axes or None
        The axes created by this function, or None if `ax` was supplied.
    """
    # Use the requested model both for the time helper and for the vertical
    # coordinate instead of always using `um`; None keeps the default naming.
    if model is None:
        model = um

    created_here = ax is None
    if created_here:
        ax = plt.axes()

    days = get_cube_rel_days(cube, model=model)
    z = cube.coord(model.z).points
    # The data array is transposed before plotting
    # (assumes the cube is ordered (time, z) -- TODO confirm).
    mappable = ax.pcolormesh(days, z, cube.data.T, **kw_plt)
    ax.figure.colorbar(mappable, ax=ax)

    return ax if created_here else None
def figsave(fig: Figure, filename: Path, **kw_savefig) -> None:
    """
    Save figure and print relative path to it.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        Figure to save.
    filename: pathlib.Path
        Destination; missing parent directories are created. If it has no
        extension and no `format` keyword is given, matplotlib appends the
        `savefig.format` rcParam as the extension.
    **kw_savefig: dict, optional
        Keyword arguments passed to `Figure.savefig()`.

    Notes
    -----
    * If `RUNTIME.figsave_stamp` is set, the file name is written on the
      figure as a small grey suptitle before saving.
    * If `RUNTIME.figsave_reduce_size` is set, the saved image is re-encoded
      with an adaptive palette (see MetPy Mondays #273). If that step fails,
      the full-size file is put back under its name and the error is re-raised.
    * The printed path is relative to a parent of the current working
      directory; the absolute path is printed if no parent contains the file.
    """
    filename = Path(filename)

    if RUNTIME.figsave_stamp:
        fig.suptitle(
            filename.name,
            x=0.5,
            y=0.05,
            ha="center",
            fontsize="xx-small",
            color="tab:grey",
            alpha=0.5,
        )

    filename.absolute().parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(filename, **kw_savefig)

    # Work out which file matplotlib actually wrote. The name is used verbatim
    # when `format` is passed or when it already has an extension; only an
    # extension-less name gets the default format appended.
    if kw_savefig.get("format") is not None or filename.suffix:
        saved = filename
    else:
        saved = filename.with_suffix(f".{plt.rcParams['savefig.format']}")
    # Absolute, so that it can be compared with the parents of the cwd below.
    saved = saved.absolute()

    if RUNTIME.figsave_reduce_size:
        # See MetPy Mondays #273
        backup = saved.with_stem(f"{saved.stem}_original")
        saved.replace(backup)
        try:
            with Image.open(backup) as im_orig:
                im = im_orig.convert("P", palette=Image.Palette.ADAPTIVE)
                im.save(saved)
        except Exception:
            # Do not leave the figure stranded under the temporary name.
            backup.replace(saved)
            raise
        # Delete the full-size copy
        backup.unlink()

    # Express the path relative to the closest parent of the cwd that contains
    # the file, prefixed by one ".." per level climbed.
    shown = saved
    for depth, parent in enumerate(Path.cwd().parents, start=1):
        try:
            tail = saved.relative_to(parent)
        except ValueError:
            continue
        shown = Path(*([".."] * depth)) / tail
        break

    print(f"Saved to {shown}")
    print(f"Size: {saved.stat().st_size / 1024:.1f} KB")