# Source code for emlib.matplotting

"""
Routines to help draw shapes / labels within a matplotlib (pyplot) plot.
Implements the concept of a plotting profile, which makes it easier to
define sizes, colors, etc. for a series of elements.
"""
from __future__ import annotations
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection

import numpy as np
from emlib.misc import isiterable, pixels_to_inches
from typing import Union


# Public API of this module
__all__ = (
    'autoscaleAxis',
    'defaultprofile',
    'drawBracket',
    'drawConnectedLines',
    'drawLabel',
    'drawLine',
    'drawRect',
    'drawRects',
    'makeAxis',
    'makeProfile',
    'plotDurs',
)


# Default drawing profile. Drawing functions look their defaults up here
# (through ``_get``) whenever no profile is given or the given profile lacks
# a key. Use ``makeProfile`` to derive a customized profile from it.
defaultprofile = {
    # Font family and size of text labels (drawLabel)
    'label_font': 'sans-serif',
    'label_size': 10,
    # Alpha of text labels (drawLabel)
    'label_alpha': 0.75,
    # Alpha of lines (drawLine, drawConnectedLines)
    'line_alpha': 0.8,
    # Default line style used by drawConnectedLines
    'line_style': 'solid',
    # Default edge / face colors. A float color is looked up in the module
    # colormap ('jet'), so these are positions within that colormap
    'edgecolor': 1,
    'facecolor': 0.8,
    # Alpha of rectangles (drawRect)
    'alpha': 0.75,
    # Default line width
    'linewidth': 1,
    # Color and alpha of annotations, e.g. the group bracket drawn by plotDurs
    'annotation_color': (0, 0, 0),
    'annotation_alpha': 0.3,
    # If True, drawing functions which honour it rescale the axes view
    # after adding their artists
    'autoscale': True,
    # NOTE(review): not read by any function visible in this module
    'background': (0, 0, 0)
}


def makeProfile(default: dict = None, **kws) -> dict:
    """
    Create a profile based on a default profile

    A profile is a dict of drawing defaults (fonts, sizes, colors, alpha
    values, ...) shared by the drawing functions of this module.

    Args:
        default: the profile to base the new profile on. If None (or empty),
            the module's ``defaultprofile`` is used
        kws: values overriding the ones in the base profile. Each key must
            already exist in the base profile

    Returns:
        a new profile dict. The base profile itself is not modified

    Raises:
        KeyError: if any of the given keys is not part of the base profile

    Example
    -------

        >>> profile = makeProfile(
        ...     label_font="Roboto",
        ...     background=(10, 10, 10),
        ...     linewidth=2)
    """
    base = default or defaultprofile
    out = base.copy()
    for key, value in kws.items():
        # Validate against the resolved base profile: checking against
        # *default* directly raised TypeError when default was None
        if key not in base:
            raise KeyError(f"Key {key} not in default profile")
        out[key] = value
    return out
# Colormap used to turn float colors (0-1) into RGBA tuples
_colormap = plt.get_cmap('jet')


def _get(profile: dict, key: str, value=None):
    """
    Look up *key* in *profile*, falling back to ``defaultprofile``

    Args:
        profile: the profile to query, or None to query only the default profile
        key: the key to look up
        value: returned when *key* is found in neither profile

    Returns:
        the value associated with *key*
    """
    if profile is None or key not in profile:
        return defaultprofile.get(key, value)
    return profile[key]
def drawLabel(ax: plt.Axes, x: float, y: float, text: str, size=None,
              alpha=None, profile=None) -> None:
    """
    Draw a text label, horizontally centered at the given coordinates

    Args:
        ax: the plot axes
        x: x coordinate
        y: y coordinate
        text: the text
        size: size of the label. If given, overrides the profile's "label_size"
        alpha: if given, overrides the profile's "label_alpha"
        profile: the profile used (None = default)
    """
    fontsize = size if size is not None else _get(profile, 'label_size')
    textalpha = alpha if alpha is not None else _get(profile, 'label_alpha')
    fontfamily = _get(profile, 'label_font')
    ax.text(x, y, text,
            ha="center",
            family=fontfamily,
            size=fontsize,
            alpha=textalpha)
def drawLine(ax: plt.Axes, x0: float, y0: float, x1: float, y1: float,
             color: Union[float, tuple] = None, linestyle: str = 'solid',
             alpha: float = None, linewidth: float = None, label: str = None,
             profile=None) -> None:
    """
    Draw a line from ``(x0, y0)`` to ``(x1, y1)``

    Args:
        ax: a plt.Axes to draw on
        x0: x coord of the start point
        y0: y coord of the start point
        x1: x coord of the end point
        y1: y coord of the end point
        color: the color of the line. A float 0-1 selects a color within the
            colormap space; a tuple is used as an RGB(A) color as is.
            If None, the profile's "edgecolor" is used
        linestyle: 'solid', 'dashed'
        alpha: a float 0-1. If None, the profile's "line_alpha" is used
        linewidth: the line width. If None, the profile's "linewidth" is used
        label: if given, a label is plotted next to the line (horizontally
            centered, at height y0)
        profile: the profile (created via makeProfile) to use. Leave None to
            use the default profile

    Raises:
        ValueError: if *linestyle* is not one of 'solid', 'dashed'

    Examples
    --------

        >>> import matplotlib.pyplot as plt
        >>> from emlib import matplotting
        >>> fig, ax = plt.subplots()
        >>> matplotting.drawLine(ax, 0, 0, 1, 1)
        >>> plt.show()
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``
    if linestyle not in ('solid', 'dashed'):
        raise ValueError(f"linestyle should be one of 'solid', 'dashed', got {linestyle}")
    # The profile key is 'linewidth' (a 'line_width' key does not exist in the
    # default profile, so the profile's width was previously never applied)
    linewidth = _fallback(linewidth, profile, 'linewidth')
    alpha = _fallback(alpha, profile, 'line_alpha')
    color = _fallback(color, profile, 'edgecolor')
    line = mlines.Line2D([x0, x1], [y0, y1], lw=linewidth, alpha=alpha,
                         color=_getcolor(color), linestyle=linestyle)
    ax.add_line(line)
    if label is not None:
        drawLabel(ax, x=(x0+x1)*0.5, y=y0, text=label, profile=profile)
    if _get(profile, 'autoscale'):
        autoscaleAxis(ax)
def _aslist(obj) -> list: if isinstance(obj, list): return obj return list(obj) def _getcolor(color: Union[float, tuple]) -> tuple[float, float, float, float]: if isinstance(color, tuple): return color return _colormap(color) def _unzip(pairs): return zip(*pairs)
def drawConnectedLines(ax: plt.Axes,
                       pairs: list[tuple[float, float]],
                       connectEdges=False,
                       color: Union[float, tuple] = None,
                       alpha: float = None,
                       linewidth: float = None,
                       label: str = None,
                       linestyle: str = None,
                       profile: dict = None
                       ) -> None:
    """
    Draw an open / closed polygon

    Args:
        ax: the plot axes
        pairs: a sequence of (x, y) pairs
        connectEdges: close the form, connecting start and end points
        color: the color to use. A float selects a color from the current
            color map, a tuple is used as is
        alpha: alpha value of the lines
        linewidth: the line width
        label: an optional label, placed at the centroid of the given points
        linestyle: the line style, one of "solid", "dashed"
        profile: the profile used, or None for default

    Raises:
        ValueError: if *pairs* is empty
    """
    points = list(pairs)
    if not points:
        raise ValueError("pairs should contain at least one (x, y) point")
    # 'linewidth' is the key used by the profile ('line_width' does not exist)
    linewidth = _fallback(linewidth, profile, 'linewidth')
    alpha = _fallback(alpha, profile, 'line_alpha')
    color = _getcolor(_fallback(color, profile, 'edgecolor'))
    linestyle = _fallback(linestyle, profile, 'line_style')
    xs, ys = (list(coords) for coords in _unzip(points))
    if connectEdges:
        # Repeat the first point at the end to close the form. The former
        # ``pairs + pairs[0]`` added a tuple to a list and raised TypeError
        linexs, lineys = xs + xs[:1], ys + ys[:1]
    else:
        linexs, lineys = xs, ys
    line = mlines.Line2D(linexs, lineys, lw=linewidth, alpha=alpha,
                         color=color, linestyle=linestyle)
    ax.add_line(line)
    if label is not None:
        # Centroid of the given points (the closing point is not counted twice)
        avgx = sum(xs) / len(xs)
        avgy = sum(ys) / len(ys)
        drawLabel(ax, x=avgx, y=avgy, text=label, profile=profile)
    if _get(profile, 'autoscale'):
        autoscaleAxis(ax)
def drawRect(ax: plt.Axes, x0: float, y0: float, x1: float, y1: float,
             color=None, alpha: float = None, edgecolor=None,
             label: str = None, profile: dict = None) -> None:
    """
    Draw a rectangle spanning from point (x0, y0) to point (x1, y1)

    Args:
        ax: the plot axes
        x0: x coord of the start point
        y0: y coord of the start point
        x1: x coord of the end point
        y1: y coord of the end point
        color: the face color (None: the profile's "facecolor")
        alpha: alpha value for the rectangle, applied to both face and edge
            (None: the profile's "alpha")
        edgecolor: the color of the edges (None: the profile's "edgecolor")
        label: if given, a label is plotted at the center of the rectangle
        profile: the profile used, or None for default
    """
    width, height = x1 - x0, y1 - y0
    patch = Rectangle(
        (x0, y0), width, height,
        facecolor=_getcolor(_fallback(color, profile, 'facecolor')),
        edgecolor=_getcolor(_fallback(edgecolor, profile, 'edgecolor')),
        alpha=_fallback(alpha, profile, 'alpha'))
    ax.add_patch(patch)
    if label is not None:
        centerx = (x0 + x1) * 0.5
        centery = (y0 + y1) * 0.5
        drawLabel(ax, x=centerx, y=centery, text=label, profile=profile)
    if _get(profile, 'autoscale'):
        autoscaleAxis(ax)
def _many(value, numitems: int, key: str = None, profile: dict = None) -> list:
    """
    Broadcast *value* to *numitems* items

    An iterable *value* is returned unchanged. Otherwise a list of *numitems*
    copies is returned, of *value* itself or, when it is None, of the
    profile's value for *key*.
    """
    if isiterable(value):
        return value
    fill = _get(profile, key) if value is None else value
    return [fill] * numitems
def drawRects(ax: plt.Axes,
              data,
              facecolor=None,
              alpha: float = None,
              edgecolor=None,
              linewidth: float = None,
              profile: dict = None,
              autolim=True) -> None:
    """
    Draw multiple rectangles as one collection

    Args:
        ax: the plot axes
        data: either a 2D array of shape (num. rectangles, 4), or a list of
            tuples (x0, y0, x1, y1), where each row is a rectangle
        facecolor: the face color (None: the profile's "facecolor")
        alpha: alpha value for the rectangles, applied to both facecolor and
            edgecolor (None: the profile's "alpha", as in drawRect)
        edgecolor: the color of the edges (None: the profile's "edgecolor")
        linewidth: the width of the edges (None: the profile's "linewidth")
        profile: the profile used, or None for default
        autolim: if True, update the axes data limits with the new
            rectangles and autoscale the view
    """
    facecolor = _fallbackColor(facecolor, profile, key='facecolor')
    edgecolor = _fallbackColor(edgecolor, profile, key='edgecolor')
    linewidth = _fallback(linewidth, profile, 'linewidth')
    # Fall back to the profile's alpha, consistent with drawRect
    alpha = _fallback(alpha, profile, 'alpha')
    rects = [Rectangle((x0, y0), x1 - x0, y1 - y0) for x0, y0, x1, y1 in data]
    coll = PatchCollection(rects, linewidth=linewidth, alpha=alpha,
                           edgecolor=edgecolor, facecolor=facecolor)
    # Honour *autolim* (it was previously hard-coded to True here)
    ax.add_collection(coll, autolim=autolim)
    if autolim:
        ax.autoscale_view()
def autoscaleAxis(ax: plt.Axes) -> None:
    """
    Recompute the data limits of *ax* and rescale its view to fit them tightly

    Args:
        ax: the axes to rescale

    NOTE(review): matplotlib's ``relim`` is documented as not taking
    Collections into account — confirm for the matplotlib version in use.
    """
    ax.relim()
    ax.autoscale_view(tight=True, scalex=True, scaley=True)
def makeAxis(pixels: tuple[int, int] = None, dpi=96) -> plt.Axes:
    """
    Create a new figure with a single plotting axes

    Args:
        pixels: the size of the plot as a tuple (width, height) in pixels.
            If None, matplotlib's default figure size is used
        dpi: dots per inch, used to convert *pixels* into inches

    Returns:
        the plt.Axes

    Raises:
        TypeError: if *pixels* is given but is not a tuple
    """
    if pixels is None:
        subplotArgs = {}
    elif isinstance(pixels, tuple):
        width = pixels_to_inches(pixels[0], dpi=dpi)
        height = pixels_to_inches(pixels[1], dpi=dpi)
        subplotArgs = {'figsize': (width, height), 'dpi': dpi}
    else:
        raise TypeError(f"pixels should be of the form (x, y), got {pixels}")
    _fig, ax = plt.subplots(**subplotArgs)
    return ax
def _fallback(value, profile: dict, key: str): return value if value is not None else _get(profile, key) def _fallbackColor(value, profile: dict, key: str) -> tuple[float, float, float, float]: if value is not None: return _getcolor(value) return _getcolor(_get(profile, key))
def drawBracket(ax: plt.Axes, x0: float, y0: float, x1: float, y1: float,
                label: str = None, color=None, linewidth: float = None,
                alpha: float = None, profile: dict = None) -> None:
    """
    Draw a bracket from (x0, y0) to (x1, y1)

    The bracket consists of a vertical leg at x0 (from y0 to y1), a
    horizontal segment at y1 (from x0 to x1) and a vertical leg at x1
    (from y1 back to y0).

    Args:
        ax: the plot axes
        x0: x coord of the start point
        y0: y coord of the start point (free end of the legs)
        x1: x coord of the end point
        y1: y coord of the end point (height of the horizontal segment)
        label: if given, a label is plotted at the centroid of the bracket's
            corner points
        color: the line color (None: the profile's "edgecolor")
        linewidth: the line width (None: the profile's "linewidth")
        alpha: alpha value of the lines (None: the profile's "annotation_alpha")
        profile: the profile used, or None for default
    """
    linewidth = _fallback(linewidth, profile, 'linewidth')
    color = _fallback(color, profile, 'edgecolor')
    alpha = _fallback(alpha, profile, 'annotation_alpha')
    corners = [(x0, y0), (x0, y1), (x1, y1), (x1, y0)]
    # Forward the profile so that line style, label font and autoscaling
    # follow the caller's profile instead of the default one
    drawConnectedLines(ax, corners, color=color, linewidth=linewidth,
                       label=label, alpha=alpha, profile=profile)
def _durationRects(durs, x0: float, y0: float, height: float,
                   stacked: bool) -> list[tuple[float, float, float, float]]:
    """Compute the (x0, y0, x1, y1) rectangle of each duration, in order"""
    rects = []
    x, y = x0, y0
    for dur in durs:
        if stacked:
            # Every row starts at x0; rows are stacked upwards
            rects.append((x0, y, x0 + dur, y + height))
            y += height
        else:
            # Contiguous rectangles along the x axis
            rects.append((x, y0, x + dur, y0 + height))
            x += dur
    return rects


def plotDurs(durs: list[float],
             y0=0.0,
             x0=0.0,
             height=1.0,
             labels: list[str] = None,
             color=None,
             ax=None,
             groupLabel: str = None,
             profile: dict = None,
             stacked=False
             ) -> plt.Axes:
    """
    Plot durations as contiguous rectangles

    Args:
        durs: the durations expressed in seconds
        y0: y of origin
        x0: x of origin
        height: the height of the drawn rectangles
        labels: if given, a label for each rectangle, drawn at its center.
            A single (non-iterable) label is used for every rectangle; a
            sequence must have one item per duration (None items are skipped)
        color: the face color used for the rectangles (None: the profile's
            "facecolor")
        ax: the axes to draw on. If not given, a new axes is created (and returned)
        groupLabel: a label for the group, drawn above a bracket spanning
            all rectangles. Only drawn when *stacked* is False
        profile: the profile used, or None to use a default
        stacked: if True, the rectangles are drawn stacked vertically (the
            duration is still drawn horizontally). The result is then similar
            to a bars plot

    Returns:
        the plot axes. If *ax* was given, then it is returned; otherwise
        the new axes is returned.

    Raises:
        ValueError: if *labels* is a sequence whose length differs from *durs*
    """
    numitems = len(durs)
    # Normalize labels before creating any figure
    if labels is None or isinstance(labels, str) or not isiterable(labels):
        labels = [labels] * numitems
    else:
        labels = list(labels)
        if len(labels) != numitems:
            raise ValueError(f"Expected {numitems} labels (one per duration), "
                             f"got {len(labels)}")
    if ax is None:
        ax = makeAxis()
    color = _fallbackColor(color, profile, 'facecolor')
    data = _durationRects(durs, x0=x0, y0=y0, height=height, stacked=stacked)
    drawRects(ax, data, facecolor=color, profile=profile)
    # The labels were previously accepted but never drawn
    for (rx0, ry0, rx1, ry1), label in zip(data, labels):
        if label is not None:
            drawLabel(ax, (rx0 + rx1) * 0.5, (ry0 + ry1) * 0.5, text=label,
                      profile=profile)
    if groupLabel is not None and not stacked:
        sep = height * 0.05
        y1 = y0 + height
        x1 = x0 + sum(durs)
        drawBracket(ax, x0, y1 + sep, x1, y1 + sep * 2,
                    color=_get(profile, 'annotation_color'), profile=profile)
        # The group label is more opaque than the bracket itself
        alpha = (_get(profile, 'annotation_alpha') + 1) * 0.5
        drawLabel(ax, (x0 + x1) * 0.5, y1 + sep, text=groupLabel, alpha=alpha,
                  profile=profile)
    return ax