"""
Adapted from scipy.spatial._plotutils, with modifications to suit
pyafv's data structures.
The original code is licensed under BSD-3-Clause and can be found at:
https://github.com/scipy/scipy/blob/v1.17.0/scipy/spatial/_plotutils.py
Copyright (c) 2001-2002 Enthought, Inc. 2003-2025, SciPy Developers.
All rights reserved.
"""
import numpy as np
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ['visualize_2d', 'visualize_2d_parallel']
def _get_axes():
    """Create a new matplotlib figure and return its current Axes.

    ``matplotlib.pyplot`` is imported lazily so that importing this module
    does not require (or initialise) a plotting backend.
    """
    import matplotlib.pyplot as plt

    figure = plt.figure()
    return figure.gca()
def _adjust_bounds(ax, points, r):
    """Set equal-width x/y limits centred on the mean of `points`.

    The half-width is half of the largest per-axis coordinate span plus a
    margin of ``3 * r``.  When the points are (nearly) coincident
    (span <= ``0.1 * r``), a fixed half-width of ``5 * r`` is used instead.
    Both axes get the same half-width.

    Args:
        ax: Axes-like object exposing ``set_xlim`` and ``set_ylim``.
        points: (N, 2) array-like of point coordinates.
        r: Radius used to size the margins.

    If `points` is empty, the current limits are left untouched.  Otherwise
    the span reduction would fail on a zero-size array.
    """
    points = np.asarray(points, dtype=float)
    if points.size == 0:
        return
    span = np.ptp(points, axis=0).max()
    # Tiny spans (e.g. a single cell) would give a uselessly small view.
    if span > 0.1 * r:
        half_width = 0.5 * span + 3.0 * r
    else:
        half_width = 5.0 * r
    center = points.mean(axis=0)
    ax.set_xlim(center[0] - half_width, center[0] + half_width)
    ax.set_ylim(center[1] - half_width, center[1] + half_width)
def visualize_2d(pts: np.ndarray, diag: dict[str, object], r: float, ax=None, **kw):
    r"""
    Visualize a 2D snapshot using the diagnostic dictionary `diag` generated by :py:meth:`pyafv.FiniteVoronoiSimulator.build`.

    This is basically a wrapper around the vectorized custom plotting functions from the example notebooks
    and generally preferred over the original :py:meth:`pyafv.FiniteVoronoiSimulator.plot_2d` method.
    The plotting style follows :py:func:`scipy.spatial.voronoi_plot_2d`.

    .. note::
        If you are visualizing a `diag` from :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`,
        you should use :func:`visualize_2d_parallel` instead.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        cell_colors (color or list, optional): A single color or a sequence of colors for filling cells, default 'C2'. Use *None* for no fill.
        fill_alpha (float, optional): Specifies the alpha for cell fills, default 0.1.
        fill_zorder (float, optional): Specifies the z-order for cell fills, default 0.
        show_points (bool, optional): Add cell center points to the plot, default *False*.
        point_size (float, optional): Specifies the marker area for the points, default 4.
        point_colors (color or list, optional): A single color or a sequence of colors for the points, default 'C0'.
        point_zorder (float, optional): Specifies the z-order for the points, default 3.
        straight_colors (color, optional): Color for straight contact edges, default 'C0'.
        straight_lw (float, optional): Line width for straight edges, default 1.0.
        straight_alpha (float, optional): Alpha for straight edges, default 1.0.
        straight_capstyle (str, optional): Cap style for straight edges, default 'butt'.
        straight_zorder (float, optional): Z-order for straight edges, default 2.
        arc_colors (color, optional): Color for arc non-contact edges, default 'C2'.
        arc_lw (float, optional): Line width for arc edges, default 1.0.
        arc_alpha (float, optional): Alpha for arc edges, default 1.0.
        arc_capstyle (str, optional): Cap style for arc edges, default 'butt'.
        arc_zorder (float, optional): Z-order for arc edges, default 1.
        auto_adjust_bounds (bool, optional): Whether to automatically adjust the plot bounds to fit the diagram, default *True*.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.

    Raises:
        ValueError: If `diag` contains a ``"plot_mode"`` key (parallel-simulator output), if `pts`
            is not an (N, 2) array, or if ``diag["edges_type"]`` has fewer than N entries.
        TypeError: If `cell_colors` is neither a single color nor an iterable of colors.
    """
    # Reject parallel-simulator output: it must be drawn domain by domain.
    if "plot_mode" in diag:
        raise ValueError(
            "diag appears to come from ParallelFiniteVoronoiSimulator.build; "
            "use visualize_2d_parallel with build(plot_mode=True) output"
        )

    # Validate everything before importing matplotlib or creating a figure,
    # so invalid input never leaves a stray empty figure behind.
    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")
    N = pts.shape[0]  # Number of points

    point_edges_type = diag["edges_type"]
    point_vertices_f_idx = diag["regions"]
    # Only the first N entries are read below; extra trailing entries are ignored.
    if len(point_edges_type) < N:
        raise ValueError(
            f"diag['edges_type'] must have at least {N} entries (one per point), "
            f"got {len(point_edges_type)}"
        )
    # Fancy indexing below needs an ndarray, even if `diag` stores nested lists.
    vertices_all = np.asarray(diag["vertices"], dtype=float)

    from matplotlib.collections import LineCollection

    if ax is None:
        ax = _get_axes()

    # Draw cell centers
    if kw.get('show_points', False):
        point_size = kw.get('point_size', 2**2)
        point_colors = kw.get('point_colors', 'C0')
        point_zorder = kw.get('point_zorder', 3)
        ax.scatter(pts[:, 0], pts[:, 1], s=point_size, c=point_colors, marker='o', zorder=point_zorder)

    # --- Classify cells ---
    # Cells with fewer than two edges are drawn as full circles of radius r.
    cell_lens = np.fromiter((len(et) for et in point_edges_type), dtype=int, count=N)
    deg_mask = cell_lens < 2
    valid_mask = ~deg_mask

    # --- Per-edge geometry (vectorized) ---
    straight_segs = None
    arc_xy = None
    valid_cells_idx = None
    offsets = None
    flat_e = None
    edge_to_arc = None
    straight_pts = None
    if valid_mask.any():
        valid_cells_idx = np.where(valid_mask)[0]
        valid_lens = cell_lens[valid_mask]
        # Flatten the ragged per-cell vertex / edge-type lists into 1D arrays;
        # offsets[k]:offsets[k + 1] is the slice belonging to the k-th valid cell.
        flat_v = np.concatenate([np.asarray(point_vertices_f_idx[i], dtype=int) for i in valid_cells_idx])
        flat_e = np.concatenate([np.asarray(point_edges_type[i], dtype=int) for i in valid_cells_idx])
        flat_cell = np.repeat(valid_cells_idx, valid_lens)
        offsets = np.concatenate(([0], np.cumsum(valid_lens)))
        # Position of the next vertex within the same cell, wrapping the last
        # vertex of each cell back to that cell's first vertex.
        next_idx = np.arange(flat_v.size) + 1
        next_idx[offsets[1:] - 1] = offsets[:-1]
        flat_v2 = flat_v[next_idx]
        straight_mask = flat_e == 1
        arc_mask = flat_e == 0
        if straight_mask.any():
            straight_segs = np.stack([vertices_all[flat_v[straight_mask]], vertices_all[flat_v2[straight_mask]]], axis=1)
            straight_pts = vertices_all[flat_v2]  # (E, 2); consumed only at straight positions
        if arc_mask.any():
            centers = pts[flat_cell[arc_mask]]
            V1a = vertices_all[flat_v[arc_mask]]
            V2a = vertices_all[flat_v2[arc_mask]]
            angle1 = np.arctan2(V1a[:, 1] - centers[:, 1], V1a[:, 0] - centers[:, 0])
            angle2 = np.arctan2(V2a[:, 1] - centers[:, 1], V2a[:, 0] - centers[:, 0])
            # Sweep from v1 to v2 with decreasing angle, over an angle in [0, 2*pi).
            total = (angle1 - angle2) % (2 * np.pi)
            t = np.linspace(0.0, 1.0, 100)
            theta = angle1[:, None] - t[None, :] * total[:, None]  # v1 -> v2
            arc_xy = np.stack([centers[:, 0:1] + r * np.cos(theta), centers[:, 1:2] + r * np.sin(theta)], axis=-1)  # (A, 100, 2)
            # Map each flat edge position to its row in arc_xy (-1 for non-arcs).
            edge_to_arc = np.full(flat_e.size, -1, dtype=int)
            edge_to_arc[arc_mask] = np.arange(arc_mask.sum())

    # --- Full-circle polylines for degenerate cells ---
    deg_circles = None
    deg_idx = None
    if deg_mask.any():
        deg_idx = np.where(deg_mask)[0]
        th = np.linspace(0.0, 2 * np.pi, 100)
        circle_template = np.column_stack([np.cos(th), np.sin(th)])  # (100, 2)
        deg_circles = pts[deg_idx, None, :] + r * circle_template[None, :, :]  # (D, 100, 2)

    cell_colors = kw.get('cell_colors', 'C2')
    if cell_colors is not None:
        from matplotlib.collections import PolyCollection
        from matplotlib.colors import to_rgba

        try:
            # Try treating as a single color
            to_rgba(cell_colors)
            cell_colors = [cell_colors] * N
        except (TypeError, ValueError):
            # Must be a sequence of colors
            try:
                cell_colors = list(cell_colors)
            except TypeError:  # pragma: no cover
                raise TypeError("cell_colors must be a single color or an iterable of colors") from None
            if len(cell_colors) != N:  # pragma: no cover
                raise ValueError(f"cell_colors must have length {N}, got {len(cell_colors)}")

        # --- Assemble per-cell fill polygons (ragged list) ---
        # A straight edge contributes only its end vertex (its start is the
        # previous edge's end); an arc edge contributes its whole polyline.
        polygons = []
        face_colors = []
        if valid_mask.any():
            for c_idx, c in enumerate(valid_cells_idx):
                s, e = offsets[c_idx], offsets[c_idx + 1]
                parts = [straight_pts[p:p+1] if flat_e[p] == 1 else arc_xy[edge_to_arc[p]] for p in range(s, e)]
                polygons.append(np.concatenate(parts, axis=0))
                face_colors.append(cell_colors[c])
        if deg_circles is not None:
            for k, c in enumerate(deg_idx):
                polygons.append(deg_circles[k])
                face_colors.append(cell_colors[c])

        # --- Emit collections ---
        # Fills (default zorder 0)
        if polygons:
            fill_alpha = kw.get('fill_alpha', 0.1)
            fill_zorder = kw.get('fill_zorder', 0)
            ax.add_collection(PolyCollection(polygons, facecolors=face_colors, alpha=fill_alpha, linewidths=0, zorder=fill_zorder))

    # Straight strokes (default zorder 2)
    if straight_segs is not None:
        straight_colors = kw.get('straight_colors', 'C0')
        straight_lw = kw.get('straight_lw', 1.0)
        straight_alpha = kw.get('straight_alpha', 1.0)
        straight_capstyle = kw.get('straight_capstyle', 'butt')
        straight_zorder = kw.get('straight_zorder', 2)
        ax.add_collection(LineCollection(straight_segs, colors=straight_colors, lw=straight_lw, alpha=straight_alpha, capstyle=straight_capstyle, zorder=straight_zorder))

    # Arc + full-circle strokes (default zorder 1); both are 100-point polylines.
    arc_polylines = []
    if arc_xy is not None:
        arc_polylines.append(arc_xy)
    if deg_circles is not None:
        arc_polylines.append(deg_circles)
    if arc_polylines:
        arc_colors = kw.get('arc_colors', 'C2')
        arc_lw = kw.get('arc_lw', 1.0)
        arc_alpha = kw.get('arc_alpha', 1.0)
        arc_capstyle = kw.get('arc_capstyle', 'butt')
        arc_zorder = kw.get('arc_zorder', 1)
        ax.add_collection(LineCollection(np.concatenate(arc_polylines, axis=0), colors=arc_colors, lw=arc_lw, alpha=arc_alpha, capstyle=arc_capstyle, zorder=arc_zorder))

    # Bounds fitting needs at least one point (an empty reduction would fail).
    if kw.get('auto_adjust_bounds', True) and N > 0:
        _adjust_bounds(ax, pts, r)
    ax.set_aspect("equal")
    return ax.figure
def _domain_style_value(value, owned_global_ids: np.ndarray, n_points: int):
    """Restrict a color-like style option to the points owned by one domain.

    Args:
        value: *None*, a single color, or a sequence with one color per global point.
        owned_global_ids: Integer indices (into the global point array) of the
            points owned by the current domain.
        n_points: Total number of points in the global snapshot.

    Returns:
        *None* if `value` is *None*.  `value` unchanged if it is a string, a
        single color, not iterable, or a sequence whose length is not
        `n_points`.  Otherwise, a list of the entries of `value` at
        `owned_global_ids`, with each entry kept as the original object.
    """
    if value is None:  # pragma: no cover
        return None
    # A bare string is never a per-point sequence; never split it into characters.
    if isinstance(value, str):
        return value
    from matplotlib.colors import to_rgba
    try:  # pragma: no cover
        to_rgba(value)
        return value
    except (TypeError, ValueError):
        pass
    try:
        values = list(value)
    except TypeError:  # pragma: no cover
        return value
    if len(values) == n_points:
        # Index the list directly instead of going through an object ndarray.
        # numpy would turn equal-length entries (e.g. ("red", 0.5) color/alpha
        # tuples) into 2-D rows, and tolist() would return lists, which
        # matplotlib does not accept as (color, alpha).
        return [values[i] for i in owned_global_ids]
    return value  # pragma: no cover
def _domain_sequence_value(value, owned_global_ids: np.ndarray, n_points: int):  # pragma: no cover
    """Restrict a numeric per-point option (e.g. ``point_size``) to one domain.

    Scalars are returned unchanged.  An array-like with exactly `n_points`
    entries is converted with ``np.asarray`` and indexed by `owned_global_ids`.
    Any other array-like is returned unchanged.
    """
    if np.ndim(value) != 0:
        as_array = np.asarray(value)
        if len(as_array) == n_points:
            return as_array[owned_global_ids]
    return value
def visualize_2d_parallel(
    pts: np.ndarray,
    diag: dict[str, object],
    r: float,
    ax=None,
    **kw,
):
    r"""
    Visualize a 2D snapshot from :py:meth:`pyafv.ParallelFiniteVoronoiSimulator.build`.

    Note that "_parallel" here means that `diag` is generated by the parallel simulator, not that the
    visualization itself is parallelized with Python multiprocessing (though we have vectorized the
    drawing of individual domains).

    The diagnostic dictionary must be built with ``plot_mode=True``. Each
    domain is drawn using :func:`visualize_2d`, and the global axes bounds are
    adjusted once at the end.

    Args:
        pts: An (N, 2) array of point coordinates.
        diag: A diagnostic *dict* containing Voronoi diagram information.
        r: Maximum radius (or denoted as :math:`\ell`) used for drawing arcs.
        ax (matplotlib.axes.Axes | None): If provided, draw into the axes; otherwise create a new one.
        **kw: Additional keyword arguments passed to :func:`visualize_2d`. ``cell_colors``,
            ``point_colors`` and ``point_size`` given as length-N sequences are sliced to
            the points owned by each domain. ``auto_adjust_bounds`` controls the single
            global bounds adjustment.

    Returns:
        matplotlib.figure.Figure: The matplotlib figure object representing the entire canvas.

    Raises:
        ValueError: If `pts` is not an (N, 2) array, if `diag` was not built with
            ``plot_mode=True``, or if its ``owned_global_ids`` / ``diag_plot`` entries
            are missing or inconsistent.
    """
    pts = np.asarray(pts, dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 2:  # pragma: no cover
        raise ValueError("pts must have shape (N,2)")
    if not diag.get("plot_mode", False) or "diag_plot" not in diag:
        raise ValueError("diag must come from build(plot_mode=True)")
    if "owned_global_ids" not in diag:  # pragma: no cover
        raise ValueError("diag is missing owned_global_ids")
    owned_groups = diag["owned_global_ids"]
    domain_diags = diag["diag_plot"]
    if len(owned_groups) != len(domain_diags):  # pragma: no cover
        raise ValueError("owned_global_ids and diag_plot must have the same length")

    # Create the axes only after validation so bad input leaves no stray figure.
    if ax is None:
        ax = _get_axes()

    # Per-domain calls must not refit bounds; this is done once, globally, below.
    auto_adjust_bounds = kw.pop("auto_adjust_bounds", True)
    n_points = pts.shape[0]
    for owned_global_ids, domain_diag in zip(owned_groups, domain_diags):
        owned_global_ids = np.asarray(owned_global_ids, dtype=int)
        if owned_global_ids.size == 0:  # pragma: no cover
            continue
        # Slice per-point style options (given for all N points) down to this domain.
        domain_kw = dict(kw)
        for key in ("cell_colors", "point_colors"):
            if key in domain_kw:
                domain_kw[key] = _domain_style_value(
                    domain_kw[key],
                    owned_global_ids,
                    n_points,
                )
        if "point_size" in domain_kw:
            domain_kw["point_size"] = _domain_sequence_value(
                domain_kw["point_size"],
                owned_global_ids,
                n_points,
            )
        visualize_2d(
            pts[owned_global_ids],
            domain_diag,
            r,
            ax=ax,
            auto_adjust_bounds=False,
            **domain_kw,
        )
    # Bounds fitting needs at least one point (an empty reduction would fail).
    if auto_adjust_bounds and n_points > 0:
        _adjust_bounds(ax, pts, r)
    ax.set_aspect("equal")
    return ax.figure