Source code for pyafv._plotutils

"""
Adapted from scipy.spatial._plotutils, with modifications to suit
pyafv's data structures.

The original code is licensed under BSD-3-Clause and can be found at:
    https://github.com/scipy/scipy/blob/v1.17.0/scipy/spatial/_plotutils.py

    Copyright (c) 2001-2002 Enthought, Inc. 2003-2025, SciPy Developers.
    All rights reserved.
"""

import numpy as np

# Names exported by a star-import of this module; the remaining helpers are private.
__all__ = ['visualize_2d', 'visualize_2d_parallel']


def _get_axes():
    """Open a new pyplot figure and return its current axes.

    matplotlib is imported lazily so that importing this module does not
    require (or initialise) pyplot.
    """
    from matplotlib import pyplot

    figure = pyplot.figure()
    return figure.gca()


def _adjust_bounds(ax, points, r):
    span = max(np.ptp(points, axis=0))
    if span > 0.1 * r:
        half_width = 0.5 * span + 3.0 * r
    else:
        half_width = 5.0 * r

    center = np.mean(points, axis=0)
    ax.set_xlim(center[0] - half_width, center[0] + half_width)
    ax.set_ylim(center[1] - half_width, center[1] + half_width)


def visualize_2d(pts: np.ndarray, diag: dict[str, object], r: float, ax = None, **kw):
    r"""
    Visualize a 2D snapshot using the diagnostic dictionary `diag` generated by
    :py:meth:`pyafv.FiniteVoronoiSimulator.build`.

    This is basically a wrapper around the vectorized custom plotting functions
    from the example notebooks and generally preferred over the original
    :py:meth:`pyafv.FiniteVoronoiSimulator.plot_2d` method.
    The plotting style follows :py:func:`scipy.spatial.voronoi_plot_2d`.

    .. note::
        If you are visualizing a `diag` from :py:meth:`ParallelFiniteVoronoiSimulator.build`,
        you should use :func:`visualize_2d_parallel` instead.

    All inputs are validated and all geometry is computed *before* a new
    figure is created, so a failing call does not leave an empty figure behind.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information.
            The keys ``"edges_type"``, ``"regions"`` and ``"vertices"`` are read.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        cell_colors (color or list, optional): A single color or a sequence of colors for filling cells, default 'C2'. Use *None* for no fill.
        fill_alpha (float, optional): Specifies the alpha for cell fills, default 0.1.
        fill_zorder (float, optional): Specifies the z-order for cell fills, default 0.
        show_points (bool, optional): Add cell center points to the plot, default *False*.
        point_size (float, optional): Specifies the marker area for the points, default 4.
        point_colors (color or list, optional): A single color or a sequence of colors for the points, default 'C0'.
        point_zorder (float, optional): Specifies the z-order for the points, default 3.
        straight_colors (color, optional): Color for straight contact edges, default 'C0'.
        straight_lw (float, optional): Line width for straight edges, default 1.0.
        straight_alpha (float, optional): Alpha for straight edges, default 1.0.
        straight_capstyle (str, optional): Cap style for straight edges, default 'butt'.
        straight_zorder (float, optional): Z-order for straight edges, default 2.
        arc_colors (color, optional): Color for arc non-contact edges, default 'C2'.
        arc_lw (float, optional): Line width for arc edges, default 1.0.
        arc_alpha (float, optional): Alpha for arc edges, default 1.0.
        arc_capstyle (str, optional): Cap style for arc edges, default 'butt'.
        arc_zorder (float, optional): Z-order for arc edges, default 1.
        auto_adjust_bounds (bool, optional): Whether to automatically adjust the plot bounds to fit the diagram, default *True*.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.

    Raises:
        ValueError: If `diag` contains a ``"plot_mode"`` key, if `pts` does not
            have shape (N, 2), or if a `cell_colors` sequence does not have length N.
        KeyError: If `diag` lacks one of the required keys.
        TypeError: If `cell_colors` is neither a single color nor an iterable.
    """
    if "plot_mode" in diag:
        raise ValueError(
            "diag appears to come from ParallelFiniteVoronoiSimulator.build; "
            "use visualize_2d_parallel with build(plot_mode=True) output"
        )

    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:    # pragma: no cover
        raise ValueError("pts must have shape (N,2)")
    N = pts.shape[0]    # Number of points

    point_edges_type = diag["edges_type"]
    point_vertices_f_idx = diag["regions"]
    # Coerce once so that fancy indexing below also works for list input.
    vertices_all = np.asarray(diag["vertices"], dtype=float)

    # --- Classify cells ---
    # Cells with fewer than two edges are drawn as full circles of radius r.
    cell_lens = np.fromiter((len(et) for et in point_edges_type), dtype=int, count=N)
    deg_mask = cell_lens < 2
    valid_mask = ~deg_mask

    # --- Per-edge geometry (vectorized) ---
    straight_segs = None
    arc_xy = None
    valid_cells_idx = None
    offsets = None
    flat_e = None
    edge_to_arc = None
    straight_pts = None

    if valid_mask.any():
        valid_cells_idx = np.where(valid_mask)[0]
        valid_lens = cell_lens[valid_mask]
        # Flatten the ragged per-cell lists; offsets[k]:offsets[k+1] is cell k's slice.
        flat_v = np.concatenate([np.asarray(point_vertices_f_idx[i], dtype=int) for i in valid_cells_idx])
        flat_e = np.concatenate([np.asarray(point_edges_type[i], dtype=int) for i in valid_cells_idx])
        flat_cell = np.repeat(valid_cells_idx, valid_lens)
        offsets = np.concatenate(([0], np.cumsum(valid_lens)))
        # Index of each edge's end vertex: the next entry, wrapping to the
        # cell's first entry at the end of every cell.
        next_idx = np.arange(flat_v.size) + 1
        next_idx[offsets[1:] - 1] = offsets[:-1]
        flat_v2 = flat_v[next_idx]

        straight_mask = flat_e == 1
        arc_mask = flat_e == 0

        if straight_mask.any():
            straight_segs = np.stack(
                [vertices_all[flat_v[straight_mask]], vertices_all[flat_v2[straight_mask]]],
                axis=1,
            )
            straight_pts = vertices_all[flat_v2]    # (E, 2); consumed only at straight positions

        if arc_mask.any():
            centers = pts[flat_cell[arc_mask]]
            V1a = vertices_all[flat_v[arc_mask]]
            V2a = vertices_all[flat_v2[arc_mask]]
            angle1 = np.arctan2(V1a[:, 1] - centers[:, 1], V1a[:, 0] - centers[:, 0])
            angle2 = np.arctan2(V2a[:, 1] - centers[:, 1], V2a[:, 0] - centers[:, 0])
            # Sweep from angle1 towards decreasing angle until angle2 is reached.
            total = (angle1 - angle2) % (2 * np.pi)
            t = np.linspace(0.0, 1.0, 100)
            theta = angle1[:, None] - t[None, :] * total[:, None]    # v1 -> v2
            arc_xy = np.stack(
                [centers[:, 0:1] + r * np.cos(theta), centers[:, 1:2] + r * np.sin(theta)],
                axis=-1,
            )    # (A, 100, 2)
            # Map a flat edge position to its row in arc_xy (-1 for non-arc edges).
            edge_to_arc = np.full(flat_e.size, -1, dtype=int)
            edge_to_arc[arc_mask] = np.arange(arc_mask.sum())

    # --- Full-circle polylines for degenerate cells ---
    deg_circles = None
    deg_idx = None
    if deg_mask.any():
        deg_idx = np.where(deg_mask)[0]
        th = np.linspace(0.0, 2 * np.pi, 100)
        circle_template = np.column_stack([np.cos(th), np.sin(th)])    # (100, 2)
        deg_circles = pts[deg_idx, None, :] + r * circle_template[None, :, :]    # (D, 100, 2)

    # matplotlib is imported lazily, and only once the numeric work succeeded.
    from matplotlib.collections import LineCollection, PolyCollection
    from matplotlib.colors import to_rgba

    # --- Assemble per-cell fill polygons (ragged list) ---
    polygons = []
    face_colors = []
    cell_colors = kw.get('cell_colors', 'C2')
    if cell_colors is not None:
        try:
            # Try treating as a single color
            to_rgba(cell_colors)
        except (TypeError, ValueError):
            # Must be a sequence of colors
            try:
                cell_colors = list(cell_colors)
            except TypeError as err:    # pragma: no cover
                raise TypeError("cell_colors must be a single color or an iterable of colors") from err
            if len(cell_colors) != N:    # pragma: no cover
                raise ValueError(f"cell_colors must have length {N}, got {len(cell_colors)}")
        else:
            cell_colors = [cell_colors] * N

        if valid_cells_idx is not None:
            for c_idx, c in enumerate(valid_cells_idx):
                s, e = offsets[c_idx], offsets[c_idx + 1]
                # Each edge contributes the points leading to its end vertex:
                # one point for a straight edge, the sampled polyline for an arc.
                # NOTE(review): any edge type other than 1 takes the arc branch
                # here -- assumes edges_type only contains 0 and 1; confirm.
                parts = [
                    straight_pts[p:p + 1] if flat_e[p] == 1 else arc_xy[edge_to_arc[p]]
                    for p in range(s, e)
                ]
                polygons.append(np.concatenate(parts, axis=0))
                face_colors.append(cell_colors[c])

        if deg_circles is not None:
            for k, c in enumerate(deg_idx):
                polygons.append(deg_circles[k])
                face_colors.append(cell_colors[c])

    # --- Emit collections ---
    # Only now create a figure, so invalid input never leaves an empty one behind.
    if ax is None:
        ax = _get_axes()

    # Draw cell centers
    if kw.get('show_points', False):
        point_size = kw.get('point_size', 2**2)
        point_colors = kw.get('point_colors', 'C0')
        point_zorder = kw.get('point_zorder', 3)
        ax.scatter(pts[:, 0], pts[:, 1], s=point_size, c=point_colors, marker='o', zorder=point_zorder)

    # Fills (default zorder 0)
    if polygons:
        fill_alpha = kw.get('fill_alpha', 0.1)
        fill_zorder = kw.get('fill_zorder', 0)
        ax.add_collection(PolyCollection(polygons, facecolors=face_colors, alpha=fill_alpha,
                                         linewidths=0, zorder=fill_zorder))

    # Straight strokes (default zorder 2)
    if straight_segs is not None:
        straight_colors = kw.get('straight_colors', 'C0')
        straight_lw = kw.get('straight_lw', 1.0)
        straight_alpha = kw.get('straight_alpha', 1.0)
        straight_capstyle = kw.get('straight_capstyle', 'butt')
        straight_zorder = kw.get('straight_zorder', 2)
        ax.add_collection(LineCollection(straight_segs, colors=straight_colors, lw=straight_lw,
                                         alpha=straight_alpha, capstyle=straight_capstyle,
                                         zorder=straight_zorder))

    # Arc + full-circle strokes (default zorder 1)
    arc_polylines = []
    if arc_xy is not None:
        arc_polylines.append(arc_xy)
    if deg_circles is not None:
        arc_polylines.append(deg_circles)
    if arc_polylines:
        arc_colors = kw.get('arc_colors', 'C2')
        arc_lw = kw.get('arc_lw', 1.0)
        arc_alpha = kw.get('arc_alpha', 1.0)
        arc_capstyle = kw.get('arc_capstyle', 'butt')
        arc_zorder = kw.get('arc_zorder', 1)
        ax.add_collection(LineCollection(np.concatenate(arc_polylines, axis=0), colors=arc_colors,
                                         lw=arc_lw, alpha=arc_alpha, capstyle=arc_capstyle,
                                         zorder=arc_zorder))

    if kw.get('auto_adjust_bounds', True):
        _adjust_bounds(ax, pts, r)

    ax.set_aspect("equal")

    return ax.figure
def _domain_style_value(value, owned_global_ids: np.ndarray, n_points: int): if value is None: # pragma: no cover return None from matplotlib.colors import to_rgba try: # pragma: no cover to_rgba(value) return value except Exception: pass try: values = list(value) except TypeError: # pragma: no cover return value if len(values) == n_points: return np.asarray(values, dtype=object)[owned_global_ids].tolist() return value # pragma: no cover def _domain_sequence_value(value, owned_global_ids: np.ndarray, n_points: int): # pragma: no cover if np.ndim(value) == 0: return value values = np.asarray(value) if len(values) == n_points: return values[owned_global_ids] return value
def visualize_2d_parallel(
    pts: np.ndarray,
    diag: dict[str, object],
    r: float,
    ax = None,
    **kw,
):
    r"""
    Visualize a 2D snapshot from :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`.

    The "_parallel" suffix refers to the origin of `diag` (the parallel
    simulator); the drawing itself is not parallelized with Python
    multiprocessing, although the drawing of each individual domain is vectorized.

    `diag` must have been built with ``plot_mode=True``. Every domain is
    rendered through :func:`visualize_2d`, and the global axes bounds are
    adjusted a single time after all domains have been drawn.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        **kw: Additional keyword arguments passed to :func:`visualize_2d`.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.
    """
    pts = np.asarray(pts, dtype=float)
    has_xy_shape = pts.ndim == 2 and pts.shape[1] == 2
    if not has_xy_shape:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")

    built_for_plotting = diag.get("plot_mode", False) and "diag_plot" in diag
    if not built_for_plotting:
        raise ValueError("diag must come from build(plot_mode=True)")
    if "owned_global_ids" not in diag:  # pragma: no cover
        raise ValueError("diag is missing owned_global_ids")

    if not ax:
        ax = _get_axes()

    id_groups = diag["owned_global_ids"]
    sub_diags = diag["diag_plot"]
    if len(id_groups) != len(sub_diags):  # pragma: no cover
        raise ValueError("owned_global_ids and diag_plot must have the same length")

    # Bounds are handled once at the end, never by the per-domain calls.
    rescale_at_end = kw.pop("auto_adjust_bounds", True)
    total_points = pts.shape[0]
    color_keys = ("cell_colors", "point_colors")

    for raw_ids, sub_diag in zip(id_groups, sub_diags):
        ids = np.asarray(raw_ids, dtype=int)
        if not ids.size:  # pragma: no cover
            continue

        # Per-cell style arguments are given in global order; cut them down
        # to the cells this domain owns before forwarding.
        per_domain = {}
        for name, setting in kw.items():
            if name in color_keys:
                setting = _domain_style_value(setting, ids, total_points)
            elif name == "point_size":
                setting = _domain_sequence_value(setting, ids, total_points)
            per_domain[name] = setting

        visualize_2d(
            pts[ids],
            sub_diag,
            r,
            ax=ax,
            auto_adjust_bounds=False,
            **per_domain,
        )

    if rescale_at_end:
        _adjust_bounds(ax, pts, r)

    ax.set_aspect("equal")

    return ax.figure