# Source code for pyafv._plotutils

"""
Adapted from scipy.spatial._plotutils, with modifications to suit
pyafv's data structures.

The original code is licensed under BSD-3-Clause and can be found at:
    https://github.com/scipy/scipy/blob/v1.17.0/scipy/spatial/_plotutils.py

    Copyright (c) 2001-2002 Enthought, Inc. 2003-2025, SciPy Developers.
    All rights reserved.
"""

import numpy as np

__all__ = ['visualize_2d', 'visualize_2d_parallel']


def _get_axes():
    """Open a brand-new matplotlib figure and hand back its current Axes."""
    from matplotlib import pyplot

    new_figure = pyplot.figure()
    return new_figure.gca()


def _adjust_bounds(ax, points, r):
    """Frame *points* in a square window on *ax*, padded relative to *r*.

    The window is centred on the mean of *points*. The half-width depends on
    ``span``, the largest per-axis extent (``np.ptp``) of the points:

    - if ``span > 0.1 * r``, the half-width is ``0.5 * span + 3 * r``;
    - otherwise (a single point or a very tight cluster), it is a fixed
      ``5 * r``.

    Args:
        ax: Object exposing ``set_xlim`` / ``set_ylim`` (a matplotlib Axes).
        points: An (N, 2) array-like of coordinates. If it is empty, the
            axes limits are left untouched. Otherwise ``np.ptp`` would fail
            on a zero-size reduction.
        r: Length scale used for padding the window.
    """
    points = np.asarray(points, dtype=float)
    if points.size == 0:
        # Nothing to frame; keep whatever limits the axes already have.
        return

    span = np.ptp(points, axis=0).max()
    if span > 0.1 * r:
        half_width = 0.5 * span + 3.0 * r
    else:
        half_width = 5.0 * r

    center = points.mean(axis=0)
    ax.set_xlim(center[0] - half_width, center[0] + half_width)
    ax.set_ylim(center[1] - half_width, center[1] + half_width)


def _slice_style_value(value, cell_indices: np.ndarray, n_points: int):     # pragma: no cover
    """Restrict a per-cell style value to the cells in *cell_indices*.

    Args:
        value: ``None``, a single matplotlib color, or an iterable of
            per-cell values.
        cell_indices: Integer indices of the cells to keep.
        n_points: Length a full per-cell sequence must have.

    Returns:
        ``None`` when *value* is ``None``. *value* unchanged when it is a
        single color (accepted by ``matplotlib.colors.to_rgba``), is not
        iterable, or has a length different from *n_points*. Otherwise,
        a list of the entries picked out by *cell_indices*.
    """
    if value is None:
        return None

    from matplotlib.colors import to_rgba

    try:
        to_rgba(value)
    except Exception:
        pass  # not one color; it may still be a per-cell sequence
    else:
        return value

    try:
        entries = list(value)
    except TypeError:
        return value

    if len(entries) != n_points:
        return value
    # dtype=object keeps heterogeneous colors (names, tuples) intact.
    return np.asarray(entries, dtype=object)[cell_indices].tolist()


def _slice_sequence_value(value, cell_indices: np.ndarray, n_points: int):      # pragma: no cover
    """Restrict a full-length per-cell numeric sequence to *cell_indices*.

    Scalars (``np.ndim(value) == 0``) and sequences whose length is not
    *n_points* are returned unchanged. A sequence of length *n_points* is
    converted with ``np.asarray`` and indexed by *cell_indices*, so the
    result is a numpy array.
    """
    if np.ndim(value) != 0:
        as_array = np.asarray(value)
        if len(as_array) == n_points:
            return as_array[cell_indices]
    return value


def visualize_2d(
    pts: np.ndarray,
    diag: dict[str, object],
    r: float,
    ax=None,
    *,
    selected=None,
    **kw,
):
    r"""
    Visualize a 2D snapshot using the diagnostic dictionary `diag` generated by
    :py:meth:`pyafv.FiniteVoronoiSimulator.build`.

    This is basically a wrapper around the vectorized custom plotting functions
    from the example notebooks. It is generally preferred over the original
    :py:meth:`pyafv.FiniteVoronoiSimulator.plot_2d` method.
    The plotting style follows :py:func:`scipy.spatial.voronoi_plot_2d`.

    .. note::
        If you are visualizing a `diag` from
        :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`, you should use
        :func:`visualize_2d_parallel` instead.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes;
            otherwise create a new one.
        selected (array-like | None, optional): Cells to draw. This can be
            either a one-dimensional array of integer indices or a boolean mask
            with length equal to ``len(pts)``. If *None*, draw all cells.
        cell_colors (color or list, optional): A single color or a sequence of
            colors for filling cells, default 'C2'. Use *None* for no fill.
            If *selected* is provided, full-length per-cell color sequences are
            sliced to the selected cells.
        fill_alpha (float, optional): Specifies the alpha for cell fills,
            default 0.1.
        fill_zorder (float, optional): Specifies the z-order for cell fills,
            default 0.
        show_points (bool, optional): Add cell center points to the plot,
            default *False*.
        point_size (float, optional): Specifies the marker area for the points,
            default 4. If *selected* is provided, full-length per-cell sequences
            are sliced to the selected cells.
        point_colors (color or list, optional): A single color or a sequence of
            colors for the points, default 'C0'. If *selected* is provided,
            full-length per-cell color sequences are sliced to the selected cells.
        point_zorder (float, optional): Specifies the z-order for the points,
            default 3.
        straight_colors (color, optional): Color for straight contact edges,
            default 'C0'.
        straight_lw (float, optional): Line width for straight edges, default 1.0.
        straight_alpha (float, optional): Alpha for straight edges, default 1.0.
        straight_capstyle (str, optional): Cap style for straight edges,
            default 'butt'.
        straight_zorder (float, optional): Z-order for straight edges, default 2.
        arc_colors (color, optional): Color for arc non-contact edges,
            default 'C2'.
        arc_lw (float, optional): Line width for arc edges, default 1.0.
        arc_alpha (float, optional): Alpha for arc edges, default 1.0.
        arc_capstyle (str, optional): Cap style for arc edges, default 'butt'.
        arc_zorder (float, optional): Z-order for arc edges, default 1.
        auto_adjust_bounds (bool, optional): Whether to automatically adjust the
            plot bounds to fit the diagram, default *True*. Skipped when there
            are no points to plot.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the
        entire canvas.

    Raises:
        ValueError: If `diag` contains a ``"plot_mode"`` key (parallel output),
            if `pts` is not (N, 2), if `selected` is malformed or empty, or if a
            `cell_colors` sequence has the wrong length.
        IndexError: If integer `selected` indices are out of bounds.
        TypeError: If `cell_colors` is neither a color nor an iterable.
    """
    if "plot_mode" in diag:
        raise ValueError(
            "diag appears to come from ParallelFiniteVoronoiSimulator.build; "
            "use visualize_2d_parallel with build(plot_mode=True) output"
        )

    from matplotlib.collections import LineCollection

    # Explicit None check rather than relying on the truthiness of an Axes.
    if ax is None:
        ax = _get_axes()

    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")

    full_N = pts.shape[0]  # Number of points in the original diagram
    if selected is None:
        point_edges_type = diag["edges_type"]
        point_vertices_f_idx = diag["regions"]
    else:
        selected = np.asarray(selected)
        if selected.dtype == bool:
            if selected.shape != (full_N,):  # pragma: no cover
                raise ValueError(f"boolean selected must have shape ({full_N},)")
            selected = np.flatnonzero(selected)  # pragma: no cover
        else:
            selected = np.asarray(selected, dtype=int)
        if selected.ndim != 1:  # pragma: no cover
            raise ValueError("selected must be a one-dimensional array")
        if selected.size == 0:  # pragma: no cover
            raise ValueError("selected must not be empty")
        if np.any((selected < 0) | (selected >= full_N)):  # pragma: no cover
            raise IndexError("selected indices are out of bounds")

        pts = pts[selected]
        point_edges_type = [diag["edges_type"][i] for i in selected]
        point_vertices_f_idx = [diag["regions"][i] for i in selected]

        # Full-length per-cell style sequences must follow the same subset.
        if any(key in kw for key in ("cell_colors", "point_colors", "point_size")):
            kw = dict(kw)
            for key in ("cell_colors", "point_colors"):
                if key in kw:
                    kw[key] = _slice_style_value(kw[key], selected, full_N)
            if "point_size" in kw:
                kw["point_size"] = _slice_sequence_value(kw["point_size"], selected, full_N)

    N = pts.shape[0]  # Number of points to plot

    # Draw cell centers
    if kw.get('show_points', False):
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=kw.get('point_size', 2**2),
            c=kw.get('point_colors', 'C0'),
            marker='o',
            zorder=kw.get('point_zorder', 3),
        )

    # Coerce to an array so integer fancy indexing below also works when the
    # diagram stores its vertices as a nested list.
    vertices_all = np.asarray(diag["vertices"], dtype=float)

    # --- Classify cells ---
    # Cells with fewer than two edges are drawn as full circles of radius r.
    cell_lens = np.fromiter((len(et) for et in point_edges_type), dtype=int, count=N)
    deg_mask = cell_lens < 2
    valid_mask = ~deg_mask

    # --- Per-edge geometry (vectorized) ---
    straight_segs = None
    arc_xy = None
    valid_cells_idx = None
    offsets = None
    flat_e = None
    edge_to_arc = None
    straight_pts = None
    if valid_mask.any():
        valid_cells_idx = np.flatnonzero(valid_mask)
        valid_lens = cell_lens[valid_mask]
        flat_v = np.concatenate(
            [np.asarray(point_vertices_f_idx[i], dtype=int) for i in valid_cells_idx]
        )
        flat_e = np.concatenate(
            [np.asarray(point_edges_type[i], dtype=int) for i in valid_cells_idx]
        )
        flat_cell = np.repeat(valid_cells_idx, valid_lens)
        offsets = np.concatenate(([0], np.cumsum(valid_lens)))

        # Position of each edge's end vertex: the next entry in the same cell,
        # with the last edge of every cell wrapping back to its first vertex.
        next_idx = np.arange(flat_v.size) + 1
        next_idx[offsets[1:] - 1] = offsets[:-1]
        flat_v2 = flat_v[next_idx]

        straight_mask = flat_e == 1
        arc_mask = flat_e == 0

        if straight_mask.any():
            straight_segs = np.stack(
                [vertices_all[flat_v[straight_mask]], vertices_all[flat_v2[straight_mask]]],
                axis=1,
            )
            straight_pts = vertices_all[flat_v2]  # (E, 2); consumed only at straight positions

        if arc_mask.any():
            centers = pts[flat_cell[arc_mask]]
            V1a = vertices_all[flat_v[arc_mask]]
            V2a = vertices_all[flat_v2[arc_mask]]
            angle1 = np.arctan2(V1a[:, 1] - centers[:, 1], V1a[:, 0] - centers[:, 0])
            angle2 = np.arctan2(V2a[:, 1] - centers[:, 1], V2a[:, 0] - centers[:, 0])
            total = (angle1 - angle2) % (2 * np.pi)
            t = np.linspace(0.0, 1.0, 100)
            theta = angle1[:, None] - t[None, :] * total[:, None]  # v1 -> v2
            arc_xy = np.stack(
                [centers[:, 0:1] + r * np.cos(theta), centers[:, 1:2] + r * np.sin(theta)],
                axis=-1,
            )  # (A, 100, 2)
            # Map each flat edge position to its row in arc_xy (-1 for non-arcs).
            edge_to_arc = np.full(flat_e.size, -1, dtype=int)
            edge_to_arc[arc_mask] = np.arange(arc_mask.sum())

    # --- Full-circle polylines for degenerate cells ---
    deg_circles = None
    deg_idx = None
    if deg_mask.any():
        deg_idx = np.flatnonzero(deg_mask)
        th = np.linspace(0.0, 2 * np.pi, 100)
        circle_template = np.column_stack([np.cos(th), np.sin(th)])  # (100, 2)
        deg_circles = pts[deg_idx, None, :] + r * circle_template[None, :, :]  # (D, 100, 2)

    cell_colors = kw.get('cell_colors', 'C2')
    if cell_colors is not None:
        from matplotlib.collections import PolyCollection
        from matplotlib.colors import to_rgba

        try:  # Try treating as a single color
            to_rgba(cell_colors)
        except Exception:  # Must be a sequence of colors
            try:
                cell_colors = list(cell_colors)
            except TypeError as err:  # pragma: no cover
                raise TypeError(
                    "cell_colors must be a single color or an iterable of colors"
                ) from err
            if len(cell_colors) != N:  # pragma: no cover
                raise ValueError(f"cell_colors must have length {N}, got {len(cell_colors)}")
        else:
            cell_colors = [cell_colors] * N

        # --- Assemble per-cell fill polygons (ragged list) ---
        # A straight edge contributes its end vertex. An arc contributes its
        # sampled polyline, start and end included.
        polygons = []
        face_colors = []
        if valid_mask.any():
            for c_idx, c in enumerate(valid_cells_idx):
                s, e = offsets[c_idx], offsets[c_idx + 1]
                parts = [
                    straight_pts[p:p + 1] if flat_e[p] == 1 else arc_xy[edge_to_arc[p]]
                    for p in range(s, e)
                ]
                polygons.append(np.concatenate(parts, axis=0))
                face_colors.append(cell_colors[c])
        if deg_circles is not None:
            for k, c in enumerate(deg_idx):
                polygons.append(deg_circles[k])
                face_colors.append(cell_colors[c])

        # --- Emit fills (default zorder=0) ---
        if polygons:
            ax.add_collection(PolyCollection(
                polygons,
                facecolors=face_colors,
                alpha=kw.get('fill_alpha', 0.1),
                linewidths=0,
                zorder=kw.get('fill_zorder', 0),
            ))

    # Straight strokes (default zorder=2)
    if straight_segs is not None:
        ax.add_collection(LineCollection(
            straight_segs,
            colors=kw.get('straight_colors', 'C0'),
            lw=kw.get('straight_lw', 1.0),
            alpha=kw.get('straight_alpha', 1.0),
            capstyle=kw.get('straight_capstyle', 'butt'),
            zorder=kw.get('straight_zorder', 2),
        ))

    # Arc + full-circle strokes (default zorder=1)
    arc_polylines = [poly for poly in (arc_xy, deg_circles) if poly is not None]
    if arc_polylines:
        ax.add_collection(LineCollection(
            np.concatenate(arc_polylines, axis=0),
            colors=kw.get('arc_colors', 'C2'),
            lw=kw.get('arc_lw', 1.0),
            alpha=kw.get('arc_alpha', 1.0),
            capstyle=kw.get('arc_capstyle', 'butt'),
            zorder=kw.get('arc_zorder', 1),
        ))

    # With no points there is nothing to frame (and np.ptp would fail).
    if kw.get('auto_adjust_bounds', True) and N > 0:
        _adjust_bounds(ax, pts, r)
    ax.set_aspect("equal")
    return ax.figure
def visualize_2d_parallel(
    pts: np.ndarray,
    diag: dict[str, object],
    r: float,
    ax=None,
    *,
    selected=None,
    **kw,
):
    r"""
    Draw a 2D snapshot produced by :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`.

    The "_parallel" suffix refers to where `diag` comes from (the parallel
    simulator), not to how drawing happens. No Python multiprocessing is used
    here, although each domain is drawn with vectorized code. `diag` must have
    been built with ``plot_mode=True``. Every domain is rendered via
    :func:`visualize_2d` with automatic bounds disabled. The global axes bounds
    are then adjusted a single time at the end.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes;
            otherwise create a new one.
        selected (array-like | None, optional): Global cells to draw. This can
            be either a one-dimensional array of integer indices or a boolean
            mask with length equal to ``len(pts)``. If *None*, draw all owned
            cells from every domain.
        **kw: Additional keyword arguments passed to :func:`visualize_2d`.
            Full-length per-cell ``cell_colors``, ``point_colors`` and
            ``point_size`` values are sliced to each domain's owned cells.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the
        entire canvas.
    """
    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")

    has_plot_data = diag.get("plot_mode", False) and "diag_plot" in diag
    if not has_plot_data:
        raise ValueError("diag must come from build(plot_mode=True)")
    if "owned_global_ids" not in diag:  # pragma: no cover
        raise ValueError("diag is missing owned_global_ids")

    if not ax:
        ax = _get_axes()

    ids_per_domain = diag["owned_global_ids"]
    diag_per_domain = diag["diag_plot"]
    if len(ids_per_domain) != len(diag_per_domain):  # pragma: no cover
        raise ValueError("owned_global_ids and diag_plot must have the same length")

    # Bounds are fitted once globally, never per domain.
    auto_adjust_bounds = kw.pop("auto_adjust_bounds", True)
    n_points = pts.shape[0]

    selected_mask = None
    bounds_pts = pts
    if selected is not None:
        selected = np.asarray(selected)
        if selected.dtype != bool:
            selected = np.asarray(selected, dtype=int)
        else:
            if selected.shape != (n_points,):  # pragma: no cover
                raise ValueError(f"boolean selected must have shape ({n_points},)")
            selected = np.flatnonzero(selected)  # pragma: no cover
        if selected.ndim != 1:  # pragma: no cover
            raise ValueError("selected must be a one-dimensional array")
        if selected.size == 0:  # pragma: no cover
            raise ValueError("selected must not be empty")
        if np.any((selected < 0) | (selected >= n_points)):  # pragma: no cover
            raise IndexError("selected indices are out of bounds")
        selected_mask = np.zeros(n_points, dtype=bool)
        selected_mask[selected] = True
        bounds_pts = pts[selected]

    for raw_ids, domain_diag in zip(ids_per_domain, diag_per_domain):
        global_ids = np.asarray(raw_ids, dtype=int)
        if global_ids.size == 0:  # pragma: no cover
            continue

        # Translate the global selection into indices local to this domain.
        if selected_mask is None:
            local_selected = None
        else:
            local_selected = np.flatnonzero(selected_mask[global_ids])
            if local_selected.size == 0:
                continue

        domain_kw = dict(kw)
        for style_key in ("cell_colors", "point_colors"):
            if style_key in domain_kw:
                domain_kw[style_key] = _slice_style_value(
                    domain_kw[style_key], global_ids, n_points,
                )
        if "point_size" in domain_kw:
            domain_kw["point_size"] = _slice_sequence_value(
                domain_kw["point_size"], global_ids, n_points,
            )

        visualize_2d(
            pts[global_ids],
            domain_diag,
            r,
            ax=ax,
            selected=local_selected,
            auto_adjust_bounds=False,
            **domain_kw,
        )

    if auto_adjust_bounds:
        _adjust_bounds(ax, bounds_pts, r)
    ax.set_aspect("equal")
    return ax.figure