"""
Adapted from scipy.spatial._plotutils, with modifications to suit
pyafv's data structures.
The original code is licensed under BSD-3-Clause and can be found at:
https://github.com/scipy/scipy/blob/v1.17.0/scipy/spatial/_plotutils.py
Copyright (c) 2001-2002 Enthought, Inc. 2003-2025, SciPy Developers.
All rights reserved.
"""
import numpy as np
# Public API of this module: only the high-level plotting entry point.
__all__ = ["visualize_2d"]
def _get_axes():
import matplotlib.pyplot as plt
return plt.figure().gca()
def _adjust_bounds(ax, points, r):
span = max(np.ptp(points, axis=0))
if span > 0.1 * r:
half_width = 0.5 * span + 3.0 * r
else:
half_width = 5.0 * r
center = np.mean(points, axis=0)
ax.set_xlim(center[0] - half_width, center[0] + half_width)
ax.set_ylim(center[1] - half_width, center[1] + half_width)
def _cell_color_list(cell_colors, n):
    """Expand *cell_colors* (one color or a sequence of colors) into a list of exactly *n* colors."""
    from matplotlib.colors import is_color_like

    if is_color_like(cell_colors):
        # A single color is broadcast to every cell.
        return [cell_colors] * n
    try:
        colors = list(cell_colors)
    except TypeError:  # pragma: no cover
        raise TypeError("cell_colors must be a single color or an iterable of colors") from None
    if len(colors) != n:  # pragma: no cover
        raise ValueError(f"cell_colors must have length {n}, got {len(colors)}")
    return colors


def visualize_2d(pts: np.ndarray, diag: dict[str, object], r: float, ax=None, **kw):
    r"""
    Visualize a 2D snapshot using the diagnostic *dict* `diag` generated by :py:meth:`pyafv.FiniteVoronoiSimulator.build`.

    This is basically a wrapper around the vectorized custom plotting functions used in the example notebooks.
    The plotting style follows :py:func:`scipy.spatial.voronoi_plot_2d`.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information. The keys
            ``"edges_type"`` (per-cell edge types: 1 = straight contact edge,
            0 = arc non-contact edge), ``"regions"`` (per-cell indices into the
            vertex array) and ``"vertices"`` (vertex coordinates) are read.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        cell_colors (color or list, optional): A single color or a sequence of colors for filling cells, default 'C2'. Use *None* for no fill.
        fill_alpha (float, optional): Specifies the alpha for cell fills, default 0.1.
        fill_zorder (float, optional): Specifies the z-order for cell fills, default 0.
        show_points (bool, optional): Add cell center points to the plot, default *False*.
        point_size (float, optional): Specifies the marker area for the points, default 4.
        point_colors (color or list, optional): A single color or a sequence of colors for the points, default 'C0'.
        point_zorder (float, optional): Specifies the z-order for the points, default 3.
        straight_colors (color or list, optional): A single color or a sequence of colors for straight contact edges, default 'C0'.
        straight_lw (float, optional): Line width for straight edges, default 1.0.
        straight_alpha (float, optional): Alpha for straight edges, default 1.0.
        straight_capstyle (str, optional): Cap style for straight edges, default 'butt'.
        straight_zorder (float, optional): Z-order for straight edges, default 2.
        arc_colors (color or list, optional): A single color or a sequence of colors for arc non-contact edges, default 'C2'.
        arc_lw (float, optional): Line width for arc edges, default 1.0.
        arc_alpha (float, optional): Alpha for arc edges, default 1.0.
        arc_capstyle (str, optional): Cap style for arc edges, default 'butt'.
        arc_zorder (float, optional): Z-order for arc edges, default 1.
        auto_adjust_bounds (bool, optional): Whether to automatically adjust the plot bounds to fit the diagram, default *True*.
            Ignored when there are no points.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.

    Raises:
        ValueError: If ``pts`` does not have shape (N, 2), or if ``cell_colors`` is a
            sequence whose length differs from N.
        TypeError: If ``cell_colors`` is neither a color nor an iterable of colors.
    """
    from matplotlib.collections import LineCollection

    if ax is None:
        ax = _get_axes()

    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")
    N = pts.shape[0]  # Number of points

    # Draw cell centers
    if kw.get('show_points', False):
        point_size = kw.get('point_size', 2**2)
        point_colors = kw.get('point_colors', 'C0')
        point_zorder = kw.get('point_zorder', 3)
        ax.scatter(pts[:, 0], pts[:, 1], s=point_size, c=point_colors, marker='o', zorder=point_zorder)

    point_edges_type = diag["edges_type"]
    point_vertices_f_idx = diag["regions"]
    # Coerce to an ndarray so the fancy indexing below also works for nested lists.
    vertices_all = np.asarray(diag["vertices"], dtype=float)

    # --- Classify cells ---
    # Cells with fewer than two edges are drawn as full circles of radius r.
    cell_lens = np.fromiter((len(et) for et in point_edges_type), dtype=int, count=N)
    deg_mask = cell_lens < 2
    valid_mask = ~deg_mask

    # --- Per-edge geometry (vectorized) ---
    straight_segs = None
    straight_pts = None
    arc_xy = None
    valid_cells_idx = None
    offsets = None
    flat_e = None
    edge_to_arc = None
    if valid_mask.any():
        valid_cells_idx = np.where(valid_mask)[0]
        valid_lens = cell_lens[valid_mask]
        # Flatten the vertex indices / edge types of all valid cells into single arrays.
        flat_v = np.concatenate([np.asarray(point_vertices_f_idx[i], dtype=int) for i in valid_cells_idx])
        flat_e = np.concatenate([np.asarray(point_edges_type[i], dtype=int) for i in valid_cells_idx])
        flat_cell = np.repeat(valid_cells_idx, valid_lens)
        # offsets[k]:offsets[k + 1] is the slice of the flat arrays owned by valid cell k.
        offsets = np.concatenate(([0], np.cumsum(valid_lens)))
        # End vertex of each edge = next vertex of the same cell, wrapping to the cell's first vertex.
        next_idx = np.arange(flat_v.size) + 1
        next_idx[offsets[1:] - 1] = offsets[:-1]
        flat_v2 = flat_v[next_idx]

        straight_mask = flat_e == 1
        arc_mask = flat_e == 0

        if straight_mask.any():
            straight_segs = np.stack([vertices_all[flat_v[straight_mask]], vertices_all[flat_v2[straight_mask]]], axis=1)
            straight_pts = vertices_all[flat_v2]  # (E, 2); consumed only at straight positions

        if arc_mask.any():
            centers = pts[flat_cell[arc_mask]]
            V1a = vertices_all[flat_v[arc_mask]]
            V2a = vertices_all[flat_v2[arc_mask]]
            angle1 = np.arctan2(V1a[:, 1] - centers[:, 1], V1a[:, 0] - centers[:, 0])
            angle2 = np.arctan2(V2a[:, 1] - centers[:, 1], V2a[:, 0] - centers[:, 0])
            # Sweep by decreasing polar angle from v1 to v2; the modulo keeps the span in [0, 2*pi).
            total = (angle1 - angle2) % (2 * np.pi)
            t = np.linspace(0.0, 1.0, 100)
            theta = angle1[:, None] - t[None, :] * total[:, None]  # v1 -> v2
            arc_xy = np.stack([centers[:, 0:1] + r * np.cos(theta), centers[:, 1:2] + r * np.sin(theta)], axis=-1)  # (A, 100, 2)
            # Map each flat edge position to its row in arc_xy (-1 for non-arc edges).
            edge_to_arc = np.full(flat_e.size, -1, dtype=int)
            edge_to_arc[arc_mask] = np.arange(arc_mask.sum())

    # --- Full-circle polylines for degenerate cells ---
    deg_circles = None
    deg_idx = None
    if deg_mask.any():
        deg_idx = np.where(deg_mask)[0]
        th = np.linspace(0.0, 2 * np.pi, 100)
        circle_template = np.column_stack([np.cos(th), np.sin(th)])  # (100, 2)
        deg_circles = pts[deg_idx, None, :] + r * circle_template[None, :, :]  # (D, 100, 2)

    cell_colors = kw.get('cell_colors', 'C2')
    if cell_colors is not None:
        from matplotlib.collections import PolyCollection

        cell_colors = _cell_color_list(cell_colors, N)

        # --- Assemble per-cell fill polygons (ragged list) ---
        polygons = []
        face_colors = []
        if valid_mask.any():
            for c_idx, c in enumerate(valid_cells_idx):
                s, e = offsets[c_idx], offsets[c_idx + 1]
                # A straight edge contributes its end vertex; an arc edge contributes its
                # sampled polyline, which starts at the edge's first vertex angle.
                parts = [straight_pts[p:p + 1] if flat_e[p] == 1 else arc_xy[edge_to_arc[p]] for p in range(s, e)]
                polygons.append(np.concatenate(parts, axis=0))
                face_colors.append(cell_colors[c])
        if deg_circles is not None:
            for k, c in enumerate(deg_idx):
                polygons.append(deg_circles[k])
                face_colors.append(cell_colors[c])

        # --- Emit collections ---
        # Fills (default zorder=0)
        if polygons:
            fill_alpha = kw.get('fill_alpha', 0.1)
            fill_zorder = kw.get('fill_zorder', 0)
            ax.add_collection(PolyCollection(polygons, facecolors=face_colors, alpha=fill_alpha, linewidths=0, zorder=fill_zorder))

    # Straight strokes (default zorder=2)
    if straight_segs is not None:
        straight_colors = kw.get('straight_colors', 'C0')
        straight_lw = kw.get('straight_lw', 1.0)
        straight_alpha = kw.get('straight_alpha', 1.0)
        straight_capstyle = kw.get('straight_capstyle', 'butt')
        straight_zorder = kw.get('straight_zorder', 2)
        ax.add_collection(LineCollection(straight_segs, colors=straight_colors, lw=straight_lw, alpha=straight_alpha, capstyle=straight_capstyle, zorder=straight_zorder))

    # Arc + full-circle strokes (default zorder=1)
    arc_polylines = []
    if arc_xy is not None:
        arc_polylines.append(arc_xy)
    if deg_circles is not None:
        arc_polylines.append(deg_circles)
    if arc_polylines:
        arc_colors = kw.get('arc_colors', 'C2')
        arc_lw = kw.get('arc_lw', 1.0)
        arc_alpha = kw.get('arc_alpha', 1.0)
        arc_capstyle = kw.get('arc_capstyle', 'butt')
        arc_zorder = kw.get('arc_zorder', 1)
        ax.add_collection(LineCollection(np.concatenate(arc_polylines, axis=0), colors=arc_colors, lw=arc_lw, alpha=arc_alpha, capstyle=arc_capstyle, zorder=arc_zorder))

    # With no points there is nothing to frame (and the bounds helper cannot reduce an empty array).
    if kw.get('auto_adjust_bounds', True) and N > 0:
        _adjust_bounds(ax, pts, r)
    ax.set_aspect("equal")
    return ax.figure