# Source code for emlib.matplotting

"""
Routines to help draw shapes / labels within a matplotlib (pyplot) plot.
Implements the concept of a plotting profile, which makes it easier to
define sizes, colors, etc. for a series of elements.
"""
from __future__ import annotations
from functools import cache
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.patches import Rectangle

import numpy as np

import typing as _t
if _t.TYPE_CHECKING:
    from matplotlib.colors import Colormap
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    _colort: _t.TypeAlias = float | tuple[float, float, float] | tuple[float, float, float, float]


# Baseline profile: the drawing routines of this module read their defaults
# from here, and makeProfile() only accepts keys that are present in it.
defaultprofile = dict(
    # text labels
    label_font='sans-serif',
    label_size=10,
    label_alpha=0.75,
    # lines
    line_alpha=0.8,
    line_style='solid',
    # shapes
    edgecolor=1,
    facecolor=0.8,
    alpha=0.75,
    linewidth=1,
    # annotations
    annotation_color=(0, 0, 0),
    annotation_alpha=0.3,
    # axes / figure
    autoscale=True,
    background=(0, 0, 0),
    colormap='jet',
)


def makeProfile(**kws):
    """
    Build a new profile by overriding entries of the default profile

    A profile bundles the defaults (fonts, sizes, colors, alphas, ...) read by
    the drawing routines of this module.

    Args:
        **kws: the entries to override. Each key must already be present
            in ``defaultprofile``

    Returns:
        a new dict; ``defaultprofile`` itself is not modified

    Example
    -------

        >>> makeProfile(
        ...     label_font="Roboto",
        ...     background=(10, 10, 10),
        ...     linewidth=2)
    """
    unknown = [name for name in kws if name not in defaultprofile]
    assert not unknown, f"Unknown keys: {unknown}"
    profile = dict(defaultprofile)
    profile.update(kws)
    return profile
def drawLabel(ax: Axes, x: float, y: float, text: str, size=None, alpha=None,
              profile=defaultprofile) -> None:
    """
    Draw a text label at the given coordinates

    The text is horizontally centered on ``x``.

    Args:
        ax: the plot axes
        x: x coordinate
        y: y coordinate
        text: the text
        size: size of the label. If given, overrides the profile's "label_size"
        alpha: if given, overrides the profile's "label_alpha"
        profile: the profile used (None = default)
    """
    # The docstring has always advertised None as "use the default profile";
    # without this guard a None profile failed on the first key lookup.
    if profile is None:
        profile = defaultprofile
    family = profile['label_font']
    if size is None:
        size = profile['label_size']
    if alpha is None:
        alpha = profile['label_alpha']
    ax.text(x, y, text, ha="center", family=family, size=size, alpha=alpha)
def drawLine(ax: Axes, x0: float, y0: float, x1: float, y1: float,
             color: float = None, linestyle='solid', alpha: float = None,
             linewidth: float = None, label='', profile=defaultprofile,
             cmap='', autoscale: bool | None = None) -> None:
    """
    Draw a line from ``(x0, y0)`` to ``(x1, y1)``

    Args:
        ax: a Axes to draw on
        x0: x coord of the start point
        y0: y coord of the start point
        x1: x coord of the end point
        y1: y coord of the end point
        color: the color of the line as a value 0-1 within the colormap space.
            If not given, the profile's "edgecolor" is used
        linestyle: 'solid', 'dashed'
        alpha: a float 0-1. If not given, the profile's "line_alpha" is used
        linewidth: the line width. If not given, the profile's "linewidth"
            is used
        label: if given, a label is plotted next to the line (horizontally
            centered between x0 and x1, at height y0)
        profile: the profile (created via makeProfile) to use. None selects
            the default profile
        cmap: name of the colormap used to resolve *color*. If empty, the
            profile's "colormap" is used
        autoscale: autoscale axis if True. If None, the profile's "autoscale"
            entry decides

    Examples
    --------

        >>> import matplotlib.pyplot as plt
        >>> from emlib import matplotting
        >>> fig, ax = plt.subplots()
        >>> matplotting.drawLine(ax, 0, 0, 1, 1)
        >>> plt.show()
    """
    if profile is None:
        profile = defaultprofile
    # The profile key is 'linewidth'; the former 'line_width' lookup raised
    # KeyError whenever no explicit linewidth was passed.
    if linewidth is None:
        linewidth = profile['linewidth']
    if alpha is None:
        alpha = profile['line_alpha']
    if color is None:
        color = profile['edgecolor']
    assert linestyle in ('solid', 'dashed')
    X, Y = np.array([[x0, x1], [y0, y1]])
    colortup = _get_colormap(cmap or profile['colormap'])(color)
    line = mlines.Line2D(X, Y, lw=linewidth, alpha=alpha, color=colortup,
                         linestyle=linestyle)
    ax.add_line(line)
    # label defaults to '', so test truthiness: `is not None` added an empty
    # text artist to every line drawn without a label.
    if label:
        drawLabel(ax, x=(x0+x1)*0.5, y=y0, text=label, profile=profile)
    if autoscale or (autoscale is None and profile['autoscale']):
        autoscaleAxis(ax)
def _aslist(obj) -> list:
    """Return *obj* itself if it is already a list, otherwise a list built from it"""
    return obj if isinstance(obj, list) else list(obj)


@cache
def _get_colormap(name: str) -> Colormap:
    """Look up a matplotlib colormap by name (results are cached per name)"""
    return plt.get_cmap(name)


@cache
def _getcolor(color: _colort, colormap: str
              ) -> tuple[float, float, float, float]:
    """
    Resolve a color specification to a 4-element color

    Args:
        color: either a tuple (a 4-tuple is returned unchanged, any other
            tuple gets ``1`` appended) or a value passed to the colormap
        colormap: name of the colormap used for non-tuple colors

    Returns:
        the resolved color
    """
    if not isinstance(color, tuple):
        return _get_colormap(colormap)(color)
    if len(color) == 4:
        return color
    return color + (1,)
def drawConnectedLines(ax: Axes,
                       pairs: list[tuple[float, float]],
                       connectEdges=False,
                       color: _colort = None,
                       alpha: float = None,
                       linewidth: float = None,
                       label='',
                       linestyle='',
                       profile=defaultprofile,
                       cmap=''
                       ) -> None:
    """
    Draw an open / closed polygon

    Args:
        ax: the plot axes
        pairs: a sequence of (x, y) pairs. It is never modified. If empty,
            nothing is drawn
        connectEdges: close the form, connecting start and end points
        color: the color to use. A float selects a color from the
            colormap; a tuple is used as an explicit color. If not given,
            the profile's "edgecolor" is used
        alpha: alpha value of the lines (default: profile's "line_alpha")
        linewidth: the line width (default: profile's "linewidth")
        label: an optional label, drawn at the average of the plotted points
        linestyle: the line style, one of "solid", "dashed" (default:
            profile's "line_style")
        profile: the profile used, or None for default
        cmap: name of the colormap used to resolve *color*. If empty, the
            profile's "colormap" is used
    """
    if profile is None:
        profile = defaultprofile
    cmap = cmap or profile['colormap']
    # The profile key is 'linewidth'; the former 'line_width' lookup raised
    # KeyError whenever no explicit linewidth was passed.
    if linewidth is None:
        linewidth = profile['linewidth']
    if alpha is None:
        alpha = profile['line_alpha']
    if color is None:
        color = profile['edgecolor']
    if not linestyle:
        linestyle = profile['line_style']

    # Work on a private list so that any sequence is accepted and the
    # caller's data is left untouched when the polygon is closed.
    points = list(pairs)
    if not points:
        return
    if connectEdges:
        points.append(points[0])
    xs, ys = zip(*points)
    line = mlines.Line2D(xs, ys, lw=linewidth, alpha=alpha,
                         color=_getcolor(color, cmap), linestyle=linestyle)
    ax.add_line(line)
    if label:
        avgx = sum(xs) / len(xs)
        avgy = sum(ys) / len(ys)
        drawLabel(ax, x=avgx, y=avgy, text=label, profile=profile)
    if profile['autoscale']:
        autoscaleAxis(ax)
def drawRect(ax: Axes, x0: float, y0: float, x1: float, y1: float,
             color: _colort, alpha: float = None,
             edgecolor: _colort = None, label='',
             profile=defaultprofile) -> None:
    """
    Draw a rectangle from point (x0, y0) to (x1, y1)

    Args:
        ax: the plot axes
        x0: x coord of the start point
        y0: y coord of the start point
        x1: x coord of the end point
        y1: y coord of the end point
        color: the face color. None selects the profile's "facecolor"
        edgecolor: the color of the edges (default: profile's "edgecolor")
        alpha: alpha value for the rectangle (both facecolor and edgecolor).
            If not given, the profile's "alpha" is used
        label: if given, a label is plotted at the center of the rectangle
        profile: the profile used, or None for default
    """
    if profile is None:
        profile = defaultprofile
    cmap = profile['colormap']
    if color is None:
        color = profile['facecolor']
    if edgecolor is None:
        edgecolor = profile['edgecolor']
    if alpha is None:
        alpha = profile['alpha']
    rect = Rectangle((x0, y0), x1-x0, y1-y0,
                     facecolor=_getcolor(color, cmap),
                     edgecolor=_getcolor(edgecolor, cmap),
                     alpha=alpha)
    ax.add_patch(rect)
    # label defaults to '', so test truthiness: `is not None` added an empty
    # text artist to every rectangle drawn without a label.
    if label:
        drawLabel(ax, x=(x0+x1)*0.5, y=(y0+y1)*0.5, text=label, profile=profile)
    if profile['autoscale']:
        autoscaleAxis(ax)
def drawRects(ax: Axes,
              data,
              facecolor: _colort = None,
              alpha: float = None,
              edgecolor: _colort = None,
              linewidth: float | None = None,
              profile=defaultprofile,
              autolim=True
              ) -> None:
    """
    Draw multiple rectangles

    All rectangles are added to the axes as one patch collection sharing
    the same colors, alpha and line width.

    Args:
        ax: the plot axes
        data: either a 2D array of shape (num. rectangles, 4), or a list
            of tuples (x0, y0, x1, y1), where each row is a rectangle
        facecolor: the face color (default: profile's "facecolor")
        edgecolor: the color of the edges (default: profile's "edgecolor")
        alpha: alpha value for the rectangles (both facecolor and edgecolor).
            Passed to the collection as is; there is no profile fallback
        profile: the profile used, or None for default
        autolim: autoscale view
        linewidth: line width (default: profile's "linewidth")
    """
    from matplotlib.collections import PatchCollection

    if profile is None:
        profile = defaultprofile
    if facecolor is None:
        facecolor = profile['facecolor']
    if edgecolor is None:
        edgecolor = profile['edgecolor']
    if linewidth is None:
        linewidth = profile['linewidth']
    cmap = profile['colormap']
    facecolor = _getcolor(facecolor, cmap)
    edgecolor = _getcolor(edgecolor, cmap)
    rects = [Rectangle((x0, y0), x1-x0, y1-y0)
             for x0, y0, x1, y1 in data]
    coll = PatchCollection(rects, linewidth=linewidth, alpha=alpha,
                           edgecolor=edgecolor, facecolor=facecolor)
    ax.add_collection(coll, autolim=True)
    if autolim:
        ax.autoscale_view()
def autoscaleAxis(ax: Axes) -> None:
    """
    Recompute the data limits of *ax* and rescale its view accordingly

    Args:
        ax: the plot axes
    """
    ax.relim()
    ax.autoscale_view(tight=True, scalex=True, scaley=True)
def makeAxis(pixels: tuple[int, int] | None = None, dpi=96) -> Axes:
    """
    Create a new figure and return its axes

    Args:
        pixels: the size of the plot as a tuple (width, height), in pixels.
            If None, the figure is created with matplotlib's default size
        dpi: dots per inch, used to convert *pixels* to inches

    Returns:
        the Axes

    Raises:
        TypeError: if *pixels* is given but is not a tuple
    """
    if pixels is None:
        _, axes = plt.subplots()
        return axes

    if not isinstance(pixels, tuple):
        raise TypeError(f"pixels should be of the form (x, y), got {pixels}")

    # Imported here rather than at module level, as in the original code
    from emlib.misc import pixels_to_inches
    width = pixels_to_inches(pixels[0], dpi=dpi)
    height = pixels_to_inches(pixels[1], dpi=dpi)
    _, axes = plt.subplots(figsize=(width, height), dpi=dpi)
    return axes
def drawBracket(ax: Axes, x0: float, y0: float, x1: float, y1: float,
                label='', color=None, linewidth: float = None,
                alpha: float = None, profile=defaultprofile) -> None:
    """
    Draw a bracket from (x0, y0) to (x1, y1)

    The bracket is the open polyline
    ``(x0, y0) -> (x0, y1) -> (x1, y1) -> (x1, y0)``.

    Args:
        ax: the plot axes
        x0: x coord of the start point
        y0: y coord of the start point
        x1: x coord of the end point
        y1: y coord of the end point
        color: the color of the bracket (default: profile's "edgecolor")
        alpha: alpha value of the bracket (default: profile's
            "annotation_alpha")
        label: if given, a label is plotted with the bracket
        linewidth: line width (default: profile's "linewidth")
        profile: the profile used, or None for default
    """
    if profile is None:
        profile = defaultprofile
    if linewidth is None:
        linewidth = profile['linewidth']
    if color is None:
        color = profile['edgecolor']
    if alpha is None:
        alpha = profile['annotation_alpha']
    corners = [(x0, y0), (x0, y1), (x1, y1), (x1, y0)]
    # Forward the profile: otherwise line style, colormap, label settings and
    # autoscale silently fall back to the default profile.
    drawConnectedLines(ax, corners, color=color, linewidth=linewidth,
                       label=label, alpha=alpha, profile=profile)
def plotDurs(durs: list[float],
             y0=0.0,
             x0=0.0,
             height=1.0,
             labels: list[str] = None,
             color: _colort | None = None,
             ax: Axes = None,
             groupLabel='',
             profile=defaultprofile,
             stacked=False
             ) -> Axes:
    """
    Plot durations as contiguous rectangles

    Args:
        durs: the durations expressed in seconds
        y0: y of origin
        x0: x of origin
        height: the height of the drawn rectangles
        labels: if given, a label for each rectangle, drawn at its center.
            Empty labels are skipped; if there are fewer labels than
            durations, the remaining rectangles stay unlabeled
        color: the color used for the rectangles (default: profile's
            "facecolor")
        ax: the axes to draw on. If not given, a new axes is
            created (and returned)
        groupLabel: a label for the group, drawn together with a bracket
            above the rectangles. Only used when not stacked
        profile: the profile used, or None to use a default
        stacked: if True, the rectangles are drawn stacked vertically (the
            duration is still drawn horizontally). The result is then
            similar to a bars plot

    Returns:
        the plot axes. If *ax* was given, then it is returned; otherwise the
        new axes is returned.
    """
    if profile is None:
        profile = defaultprofile
    if ax is None:
        ax = makeAxis()
    if color is None:
        color = profile['facecolor']

    # Each entry is (x0, y0, x1, y1), the layout expected by drawRects
    rects = []
    if stacked:
        y = y0
        for dur in durs:
            rects.append((x0, y, x0+dur, y+height))
            y += height
    else:
        x = x0
        for dur in durs:
            rects.append((x, y0, x+dur, y0+height))
            x += dur
    drawRects(ax, rects, facecolor=color, profile=profile)

    # The labels argument was previously accepted but never drawn
    if labels:
        for (rx0, ry0, rx1, ry1), text in zip(rects, labels):
            if text:
                drawLabel(ax, (rx0+rx1)*0.5, (ry0+ry1)*0.5, text=text,
                          profile=profile)

    if groupLabel and not stacked:
        sep = height * 0.05
        y1 = y0 + height
        x1 = x0 + sum(durs)
        drawBracket(ax, x0, y1+sep, x1, y1+sep*2,
                    color=profile['annotation_color'], profile=profile)
        # Halfway between the annotation alpha and fully opaque
        alpha = (profile['annotation_alpha'] + 1) * 0.5
        drawLabel(ax, (x0+x1)*0.5, y1+sep, text=groupLabel, alpha=alpha,
                  profile=profile)
    return ax
def fig2data(fig: Figure) -> np.ndarray:
    """
    Render a matplotlib figure into a numpy array of RGBA pixels

    Args:
        fig: a matplotlib figure

    Returns:
        a numpy 3D array of shape (height, width, 4) and dtype uint8
        holding the RGBA values of the rendered figure
    """
    # Draw first so that the canvas buffer reflects the current figure
    fig.canvas.draw()

    width, height = fig.canvas.get_width_height()
    # frombuffer replaces the deprecated binary mode of np.fromstring
    argb = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
    # The buffer is laid out row by row, so rows (height) come first;
    # reshaping as (width, height, 4) scrambles any non-square figure.
    argb = argb.reshape(height, width, 4)

    # canvas.tostring_argb gives the pixmap in ARGB mode: roll the alpha
    # channel to the end to get RGBA. np.roll returns a new array.
    return np.roll(argb, 3, axis=2)