"""
Adapted from scipy.spatial._plotutils, with modifications to suit
pyafv's data structures.
The original code is licensed under BSD-3-Clause and can be found at:
https://github.com/scipy/scipy/blob/v1.17.0/scipy/spatial/_plotutils.py
Copyright (c) 2001-2002 Enthought, Inc. 2003-2025, SciPy Developers.
All rights reserved.
"""
import numpy as np
# Public API of this module: the two plotting entry points.
__all__ = [
    "visualize_2d",
    "visualize_2d_parallel",
]
def _get_axes():
    """Create a new matplotlib figure and return its current axes.

    pyplot is imported inside the function rather than at module level, so
    importing this module does not import matplotlib.
    """
    from matplotlib import pyplot

    figure = pyplot.figure()
    return figure.gca()
def _adjust_bounds(ax, points, r):
    """Set square x/y limits centered on the mean of *points*, padded by *r*.

    The half-width of the view is ``0.5 * span + 3 * r``, where *span* is the
    larger of the x- and y-extents of *points*. When the points are (nearly)
    coincident (``span <= 0.1 * r``), a fixed half-width of ``5 * r`` is used
    instead.

    Args:
        ax: Axes-like object providing ``set_xlim`` and ``set_ylim``.
        points: Array-like of shape (M, 2) with point coordinates.
        r: Radius used to pad the view.

    If *points* is empty, the axes limits are left unchanged. ``np.ptp`` cannot
    reduce an empty array.
    """
    points = np.asarray(points, dtype=float)
    if points.size == 0:
        return
    span = np.ptp(points, axis=0).max()
    half_width = 0.5 * span + 3.0 * r if span > 0.1 * r else 5.0 * r
    center_x, center_y = points.mean(axis=0)
    ax.set_xlim(center_x - half_width, center_x + half_width)
    ax.set_ylim(center_y - half_width, center_y + half_width)
def _slice_style_value(value, cell_indices: np.ndarray, n_points: int):  # pragma: no cover
    """Restrict a per-cell style value (e.g. colors) to the cells in *cell_indices*.

    *value* is returned unchanged when it is ``None``, a string, a single
    matplotlib color, not iterable, or an iterable whose length differs from
    *n_points*. Otherwise it is treated as one entry per point, and the list of
    entries at *cell_indices* is returned.

    Args:
        value: A style value passed by the user (single color, sequence, ...).
        cell_indices: Integer indices of the cells to keep.
        n_points: Length a per-cell sequence must have to be sliced.
    """
    if value is None:
        return None
    # A string is either one color spec or invalid. It is never a per-cell
    # sequence, so it must not be split into characters.
    if isinstance(value, str):
        return value
    from matplotlib.colors import to_rgba

    try:
        to_rgba(value)
    except (ValueError, TypeError):
        # Not a single color; fall through and try it as a per-cell sequence.
        pass
    else:
        return value
    try:
        values = list(value)
    except TypeError:
        return value
    if len(values) == n_points:
        # Plain indexing keeps each element as given. np.asarray(..., dtype=object)
        # would infer a 2-D array from equal-length RGB(A) tuples.
        return [values[i] for i in cell_indices]
    return value
def _slice_sequence_value(value, cell_indices: np.ndarray, n_points: int):  # pragma: no cover
    """Restrict a per-cell numeric sequence (e.g. marker sizes) to *cell_indices*.

    Scalars, and sequences whose length is not *n_points*, are passed through
    untouched. A sequence with exactly *n_points* entries is converted to an
    ndarray and indexed by *cell_indices*.
    """
    if np.ndim(value) != 0:
        arr = np.asarray(value)
        if len(arr) == n_points:
            return arr[cell_indices]
    return value
def visualize_2d(
    pts: np.ndarray,
    diag: dict[str, object],
    r: float,
    ax=None,
    *,
    selected=None,
    **kw,
):
    r"""
    Visualize a 2D snapshot using the diagnostic dictionary `diag` generated by :py:meth:`pyafv.FiniteVoronoiSimulator.build`.

    This is basically a wrapper around the vectorized custom plotting functions from the example notebooks
    and generally preferred over the original :py:meth:`pyafv.FiniteVoronoiSimulator.plot_2d` method.
    The plotting style follows :py:func:`scipy.spatial.voronoi_plot_2d`.

    .. note::
        If you are visualizing a `diag` from :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`,
        you should use :func:`visualize_2d_parallel` instead.

    Cells with fewer than two edges are drawn as full circles of radius `r`
    around their point. All other cells are outlined edge by edge: straight
    contact edges (type code 1) as line segments, and non-contact edges
    (type code 0) as circular arcs of radius `r`.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information. This
            function reads ``"edges_type"`` and ``"regions"`` (one entry per
            cell: edge type codes and vertex indices) and ``"vertices"``.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        selected (array-like | None, optional): Cells to draw. This can be
            either a one-dimensional array of integer indices or a boolean mask
            with length equal to ``len(pts)``. If *None*, draw all cells.
        cell_colors (color or list, optional): A single color or a sequence of colors for filling cells, default 'C2'.
            Use *None* for no fill. If *selected* is provided, full-length
            per-cell color sequences are sliced to the selected cells.
        fill_alpha (float, optional): Specifies the alpha for cell fills, default 0.1.
        fill_zorder (float, optional): Specifies the z-order for cell fills, default 0.
        show_points (bool, optional): Add cell center points to the plot, default *False*.
        point_size (float, optional): Specifies the marker area for the points, default 4.
            If *selected* is provided, full-length per-cell sequences are
            sliced to the selected cells.
        point_colors (color or list, optional): A single color or a sequence of colors for the points, default 'C0'.
            If *selected* is provided, full-length per-cell color sequences
            are sliced to the selected cells.
        point_zorder (float, optional): Specifies the z-order for the points, default 3.
        straight_colors (color, optional): Color for straight contact edges, default 'C0'.
        straight_lw (float, optional): Line width for straight edges, default 1.0.
        straight_alpha (float, optional): Alpha for straight edges, default 1.0.
        straight_capstyle (str, optional): Cap style for straight edges, default 'butt'.
        straight_zorder (float, optional): Z-order for straight edges, default 2.
        arc_colors (color, optional): Color for arc non-contact edges, default 'C2'.
        arc_lw (float, optional): Line width for arc edges, default 1.0.
        arc_alpha (float, optional): Alpha for arc edges, default 1.0.
        arc_capstyle (str, optional): Cap style for arc edges, default 'butt'.
        arc_zorder (float, optional): Z-order for arc edges, default 1.
        auto_adjust_bounds (bool, optional): Whether to automatically adjust the plot bounds to fit the diagram, default *True*.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.

    Raises:
        ValueError: If `diag` has a ``"plot_mode"`` key (parallel-simulator
            output), if *pts* is not (N, 2), if *selected* is malformed or
            empty, or if *cell_colors* is an invalid color string or a sequence
            of the wrong length.
        IndexError: If *selected* contains out-of-range indices.
        TypeError: If *cell_colors* is neither a color nor iterable.
    """
    if "plot_mode" in diag:
        raise ValueError(
            "diag appears to come from ParallelFiniteVoronoiSimulator.build; "
            "use visualize_2d_parallel with build(plot_mode=True) output"
        )

    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")
    full_N = pts.shape[0]  # Number of points in the original diagram

    # --- Resolve the cell selection (validated before anything is drawn) ---
    if selected is None:
        point_edges_type = diag["edges_type"]
        point_vertices_f_idx = diag["regions"]
    else:
        selected = np.asarray(selected)
        if selected.dtype == bool:
            if selected.shape != (full_N,):  # pragma: no cover
                raise ValueError(f"boolean selected must have shape ({full_N},)")
            selected = np.flatnonzero(selected)
        else:
            selected = np.asarray(selected, dtype=int)
        if selected.ndim != 1:  # pragma: no cover
            raise ValueError("selected must be a one-dimensional array")
        if selected.size == 0:  # pragma: no cover
            raise ValueError("selected must not be empty")
        if np.any((selected < 0) | (selected >= full_N)):  # pragma: no cover
            raise IndexError("selected indices are out of bounds")

        pts = pts[selected]
        edges_type_all = diag["edges_type"]
        regions_all = diag["regions"]
        point_edges_type = [edges_type_all[i] for i in selected]
        point_vertices_f_idx = [regions_all[i] for i in selected]
        # ``kw`` is a fresh dict built from ``**kw``, so rewriting it in place
        # does not affect the caller. Full-length per-cell style sequences
        # follow the selection.
        for key in ("cell_colors", "point_colors"):
            if key in kw:
                kw[key] = _slice_style_value(kw[key], selected, full_N)
        if "point_size" in kw:
            kw["point_size"] = _slice_sequence_value(kw["point_size"], selected, full_N)

    N = pts.shape[0]  # Number of points to plot
    # Coerce so fancy indexing also works when the diag stores a nested list.
    vertices_all = np.asarray(diag["vertices"])

    from matplotlib.collections import LineCollection, PolyCollection
    from matplotlib.colors import to_rgba

    # --- Normalize fill colors to one entry per drawn cell (or None = no fill) ---
    cell_colors = kw.get("cell_colors", "C2")
    if cell_colors is not None:
        try:
            to_rgba(cell_colors)  # succeeds only for a single color spec
        except (ValueError, TypeError) as err:
            if isinstance(cell_colors, str):
                # An invalid color string must not be split into characters.
                raise ValueError(f"invalid cell color {cell_colors!r}") from err
            try:
                cell_colors = list(cell_colors)
            except TypeError as list_err:  # pragma: no cover
                raise TypeError("cell_colors must be a single color or an iterable of colors") from list_err
            if len(cell_colors) != N:  # pragma: no cover
                raise ValueError(f"cell_colors must have length {N}, got {len(cell_colors)}")
        else:
            cell_colors = [cell_colors] * N

    # Only create a figure once all inputs have been validated.
    if ax is None:
        ax = _get_axes()

    # --- Cell centers ---
    if kw.get("show_points", False):
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=kw.get("point_size", 2**2),
            c=kw.get("point_colors", "C0"),
            marker="o",
            zorder=kw.get("point_zorder", 3),
        )

    # --- Classify cells: fewer than two edges -> drawn as a full circle ---
    cell_lens = np.fromiter((len(et) for et in point_edges_type), dtype=int, count=N)
    deg_mask = cell_lens < 2
    valid_mask = ~deg_mask

    # --- Per-edge geometry, vectorized over all edges of non-degenerate cells ---
    straight_segs = None  # (S, 2, 2) straight edge segments
    straight_pts = None  # (E, 2) end vertex of every edge; read only at straight edges
    arc_xy = None  # (A, 100, 2) sampled arc polylines
    valid_cells_idx = None
    offsets = None  # edges of the k-th valid cell are offsets[k]:offsets[k + 1]
    flat_e = None  # (E,) edge type codes
    edge_to_arc = None  # (E,) row of arc_xy for arc edges, -1 otherwise
    if valid_mask.any():
        valid_cells_idx = np.flatnonzero(valid_mask)
        valid_lens = cell_lens[valid_mask]
        flat_v = np.concatenate([np.asarray(point_vertices_f_idx[i], dtype=int) for i in valid_cells_idx])
        flat_e = np.concatenate([np.asarray(point_edges_type[i], dtype=int) for i in valid_cells_idx])
        flat_cell = np.repeat(valid_cells_idx, valid_lens)
        offsets = np.concatenate(([0], np.cumsum(valid_lens)))

        # Edge k of a cell runs from its vertex k to vertex k+1; the last edge
        # wraps around to the cell's first vertex.
        next_idx = np.arange(flat_v.size) + 1
        next_idx[offsets[1:] - 1] = offsets[:-1]
        flat_v2 = flat_v[next_idx]

        straight_mask = flat_e == 1
        arc_mask = flat_e == 0

        if straight_mask.any():
            straight_segs = np.stack(
                [vertices_all[flat_v[straight_mask]], vertices_all[flat_v2[straight_mask]]],
                axis=1,
            )
            straight_pts = vertices_all[flat_v2]

        if arc_mask.any():
            centers = pts[flat_cell[arc_mask]]
            v1 = vertices_all[flat_v[arc_mask]]
            v2 = vertices_all[flat_v2[arc_mask]]
            angle1 = np.arctan2(v1[:, 1] - centers[:, 1], v1[:, 0] - centers[:, 0])
            angle2 = np.arctan2(v2[:, 1] - centers[:, 1], v2[:, 0] - centers[:, 0])
            # Sweep from v1 to v2 with decreasing angle; the modulo keeps the
            # sweep length in [0, 2*pi).
            total = (angle1 - angle2) % (2 * np.pi)
            t = np.linspace(0.0, 1.0, 100)
            theta = angle1[:, None] - t[None, :] * total[:, None]
            arc_xy = np.stack(
                [centers[:, 0:1] + r * np.cos(theta), centers[:, 1:2] + r * np.sin(theta)],
                axis=-1,
            )
            edge_to_arc = np.full(flat_e.size, -1, dtype=int)
            edge_to_arc[arc_mask] = np.arange(arc_mask.sum())

    # --- Full-circle polylines for degenerate cells ---
    # Same 100-sample resolution as arcs so both stack into one LineCollection.
    deg_circles = None
    deg_idx = None
    if deg_mask.any():
        deg_idx = np.flatnonzero(deg_mask)
        th = np.linspace(0.0, 2 * np.pi, 100)
        circle_template = np.column_stack([np.cos(th), np.sin(th)])  # (100, 2)
        deg_circles = pts[deg_idx, None, :] + r * circle_template[None, :, :]  # (D, 100, 2)

    # --- Fills ---
    if cell_colors is not None:
        polygons = []
        face_colors = []
        if valid_cells_idx is not None:
            for c_idx, c in enumerate(valid_cells_idx):
                s, e = offsets[c_idx], offsets[c_idx + 1]
                # A straight edge contributes its end vertex; an arc contributes
                # its full polyline (which starts at the previous end vertex).
                parts = [
                    straight_pts[p:p + 1] if flat_e[p] == 1 else arc_xy[edge_to_arc[p]]
                    for p in range(s, e)
                ]
                polygons.append(np.concatenate(parts, axis=0))
                face_colors.append(cell_colors[c])
        if deg_circles is not None:
            for k, c in enumerate(deg_idx):
                polygons.append(deg_circles[k])
                face_colors.append(cell_colors[c])
        if polygons:
            ax.add_collection(
                PolyCollection(
                    polygons,
                    facecolors=face_colors,
                    alpha=kw.get("fill_alpha", 0.1),
                    linewidths=0,
                    zorder=kw.get("fill_zorder", 0),
                )
            )

    # --- Straight strokes ---
    if straight_segs is not None:
        ax.add_collection(
            LineCollection(
                straight_segs,
                colors=kw.get("straight_colors", "C0"),
                lw=kw.get("straight_lw", 1.0),
                alpha=kw.get("straight_alpha", 1.0),
                capstyle=kw.get("straight_capstyle", "butt"),
                zorder=kw.get("straight_zorder", 2),
            )
        )

    # --- Arc and full-circle strokes ---
    arc_polylines = [a for a in (arc_xy, deg_circles) if a is not None]
    if arc_polylines:
        ax.add_collection(
            LineCollection(
                np.concatenate(arc_polylines, axis=0),
                colors=kw.get("arc_colors", "C2"),
                lw=kw.get("arc_lw", 1.0),
                alpha=kw.get("arc_alpha", 1.0),
                capstyle=kw.get("arc_capstyle", "butt"),
                zorder=kw.get("arc_zorder", 1),
            )
        )

    if kw.get("auto_adjust_bounds", True):
        _adjust_bounds(ax, pts, r)
    ax.set_aspect("equal")
    return ax.figure
def visualize_2d_parallel(
    pts: np.ndarray,
    diag: dict[str, object],
    r: float,
    ax=None,
    *,
    selected=None,
    **kw,
):
    r"""
    Visualize a 2D snapshot from :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`.

    Note that "_parallel" here means that `diag` is generated by the parallel
    simulator, not that the visualization itself is parallelized with Python
    multiprocessing (though the drawing of individual domains is vectorized).

    The diagnostic dictionary must be built with ``plot_mode=True``. Each
    domain is drawn using :func:`visualize_2d`, and the global axes bounds are
    adjusted once at the end.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information. This
            function reads ``"plot_mode"``, ``"diag_plot"`` (one per-domain diag
            each) and ``"owned_global_ids"`` (global point indices owned by
            each domain, aligned with ``"diag_plot"``).
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        selected (array-like | None, optional): Global cells to draw. This can
            be either a one-dimensional array of integer indices or a boolean
            mask with length equal to ``len(pts)``. If *None*, draw all owned
            cells from every domain.
        **kw: Additional keyword arguments passed to :func:`visualize_2d`.
            Full-length (``len(pts)``) per-cell ``cell_colors``,
            ``point_colors`` and ``point_size`` are mapped to each domain.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.

    Raises:
        ValueError: If *pts* is not (N, 2), if `diag` was not built with
            ``plot_mode=True`` or is inconsistent, or if *selected* is
            malformed or empty.
        IndexError: If *selected* contains out-of-range indices.
    """
    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")
    if not diag.get("plot_mode", False) or "diag_plot" not in diag:
        raise ValueError("diag must come from build(plot_mode=True)")
    if "owned_global_ids" not in diag:  # pragma: no cover
        raise ValueError("diag is missing owned_global_ids")

    owned_groups = diag["owned_global_ids"]
    domain_diags = diag["diag_plot"]
    if len(owned_groups) != len(domain_diags):  # pragma: no cover
        raise ValueError("owned_global_ids and diag_plot must have the same length")

    # Bounds are adjusted once globally, not per domain.
    auto_adjust_bounds = kw.pop("auto_adjust_bounds", True)
    n_points = pts.shape[0]

    if selected is None:
        selected_mask = None
        bounds_pts = pts
    else:
        selected = np.asarray(selected)
        if selected.dtype == bool:
            if selected.shape != (n_points,):  # pragma: no cover
                raise ValueError(f"boolean selected must have shape ({n_points},)")
            selected = np.flatnonzero(selected)
        else:
            selected = np.asarray(selected, dtype=int)
        if selected.ndim != 1:  # pragma: no cover
            raise ValueError("selected must be a one-dimensional array")
        if selected.size == 0:  # pragma: no cover
            raise ValueError("selected must not be empty")
        if np.any((selected < 0) | (selected >= n_points)):  # pragma: no cover
            raise IndexError("selected indices are out of bounds")
        selected_mask = np.zeros(n_points, dtype=bool)
        selected_mask[selected] = True
        bounds_pts = pts[selected]

    # Only create a figure once all inputs have been validated.
    if ax is None:
        ax = _get_axes()

    for owned_global_ids, domain_diag in zip(owned_groups, domain_diags):
        owned_global_ids = np.asarray(owned_global_ids, dtype=int)
        if owned_global_ids.size == 0:  # pragma: no cover
            continue

        domain_selected = None
        if selected_mask is not None:
            # Local (domain) indices of the globally selected owned cells.
            domain_selected = np.flatnonzero(selected_mask[owned_global_ids])
            if domain_selected.size == 0:
                continue

        # Map full-length per-cell styles from global to domain-local order;
        # visualize_2d then applies domain_selected on top of that.
        domain_kw = dict(kw)
        for key in ("cell_colors", "point_colors"):
            if key in domain_kw:
                domain_kw[key] = _slice_style_value(domain_kw[key], owned_global_ids, n_points)
        if "point_size" in domain_kw:
            domain_kw["point_size"] = _slice_sequence_value(domain_kw["point_size"], owned_global_ids, n_points)

        visualize_2d(
            pts[owned_global_ids],
            domain_diag,
            r,
            ax=ax,
            selected=domain_selected,
            auto_adjust_bounds=False,
            **domain_kw,
        )

    if auto_adjust_bounds:
        _adjust_bounds(ax, bounds_pts, r)
    ax.set_aspect("equal")
    return ax.figure