"""Visualization module for clustering algorithms.
Provides plotting utilities for clustering analysis with consistent styling. All visual
configuration flows through :class:`PlotStyle`, which must be passed to
:class:`ClusterVisualizer` to ensure uniform appearance across all figures.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from itertools import permutations
from typing import Any, cast
from matplotlib.axes import Axes
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
from geomexp.clustering.clustering_base import ClusterResult
def _match_centers(reference_centers: np.ndarray, target_centers: np.ndarray) -> np.ndarray:
    """Return the cost-minimising matching of target centers to reference centers.

    Solves the linear assignment problem on the matrix of pairwise Euclidean distances, i.e.
    finds the permutation ``perm`` minimising
    ``sum_i ||reference_centers[i] - target_centers[perm[i]]||``.

    Args:
        reference_centers: Array of shape ``(n_clusters, dim)`` (or ``(n_clusters,)`` for 1-D
            centers).
        target_centers: Array with the same number of centers and the same dimension.

    Returns:
        Integer array ``perm`` of length ``n_clusters`` such that reference cluster ``i`` is
        matched to target cluster ``perm[i]``.
    """
    # Local import: only this helper needs scipy.optimize.
    from scipy.optimize import linear_sum_assignment

    ref = np.asarray(reference_centers, dtype=float)
    tgt = np.asarray(target_centers, dtype=float)
    # Treat 1-D centers as points in a one-dimensional space.
    if ref.ndim == 1:
        ref = ref[:, np.newaxis]
    if tgt.ndim == 1:
        tgt = tgt[:, np.newaxis]
    # cost[i, j] = distance between reference center i and target center j.
    cost = np.linalg.norm(ref[:, np.newaxis, :] - tgt[np.newaxis, :, :], axis=-1)
    # The assignment solver finds the exact optimum in polynomial time. Enumerating all k!
    # permutations becomes infeasible beyond roughly ten clusters.
    rows, cols = linear_sum_assignment(cost)
    perm = np.empty(len(ref), dtype=int)
    perm[rows] = cols
    return perm


def _align_result_to_reference(reference: ClusterResult, target: ClusterResult) -> ClusterResult:
    """Align cluster labels of ``target`` to those of ``reference`` by optimal center matching.

    Finds the permutation of cluster indices that minimises the sum of Euclidean distances
    between matched center pairs, then relabels ``target`` accordingly. This ensures
    consistent colouring when comparing results side by side.

    Args:
        reference: The result whose labelling is treated as canonical.
        target: The result to relabel.

    Returns:
        A new ``ClusterResult`` with relabelled assignments, centers, and (if present)
        ``metadata["indices"]``. The objective, iteration count and convergence flag are copied
        unchanged. If the two results have a different number of centers, ``target`` itself is
        returned unmodified.
    """
    n_clusters = len(reference.centers)
    if n_clusters != len(target.centers):
        # Labels cannot be put in one-to-one correspondence; leave ``target`` untouched.
        return target

    # perm[new_label] == old_label: reference cluster ``new_label`` matches target ``old_label``.
    perm = _match_centers(reference.centers, target.centers)

    # inv_perm[old_label] == new_label, used to translate the per-sample assignments.
    inv_perm = np.empty(n_clusters, dtype=int)
    inv_perm[perm] = np.arange(n_clusters)
    new_assignments = inv_perm[np.asarray(target.assignments, dtype=int)]

    new_centers = np.asarray(target.centers)[perm]

    new_metadata: dict[str, object] | None = None
    if target.metadata is not None:
        new_metadata = dict(target.metadata)
        if "indices" in new_metadata:
            # Index vectors are stored per cluster, so they follow the same reordering as centers.
            new_metadata["indices"] = np.asarray(new_metadata["indices"])[perm]

    return ClusterResult(
        assignments=new_assignments,
        centers=new_centers,
        objective=target.objective,
        n_iterations=target.n_iterations,
        converged=target.converged,
        metadata=new_metadata,
    )
@dataclass
class PlotStyle:
    """Central styling configuration for all cluster visualisations.

    Every plot produced by :class:`ClusterVisualizer` reads its visual parameters from this object,
    ensuring that all figures in a study share a uniform appearance. Construct a ``PlotStyle``
    instance first, then pass it to ``ClusterVisualizer(style)``.

    Attributes:
        figsize: Default figure size ``(width, height)`` in inches.
        dpi: Resolution in dots per inch.
        point_size: Scatter point size for data points.
        center_size: Scatter point size for cluster centers.
        center_linewidth: Edge line width for center markers.
        alpha_data: Opacity of data points.
        alpha_contour: Opacity of coloured decision regions.
        color_palette: Hex colour strings cycled over clusters. ``None`` selects a default
            five-colour palette. A single string is treated as a one-colour palette. The given
            sequence is copied into a new list and must not be empty.
        center_color: Colour of center markers.
        center_marker: Marker shape for centers.
        arrow_color: Colour of index vector arrows.
        arrow_linewidth: Line width of index vector arrows.
        arrow_head_width: Head width of index vector arrows.
        arrow_head_length: Head length of index vector arrows.
        arrow_scale: Multiplicative scaling applied to index vector arrows for visibility.
        contour_color: Colour of decision boundary lines.
        contour_linewidth: Line width of decision boundary contours.
        contour_alpha: Opacity of decision boundary contour lines.
        curve_linewidth: Line width of expectile level-set contours.
        curve_alpha: Opacity of expectile level-set contour lines.
        kde_color: Colour of KDE density contour lines.
        kde_linewidth: Line width of KDE density contour lines.
        kde_alpha: Opacity of KDE density contour lines.
        loss_linewidth: Line width of loss function curves.
        decision_boundary_resolution: Grid resolution for boundary computation.
        decision_boundary_show_points: Whether to overlay data points on boundary plots.
        decision_boundary_show_contour_lines: Whether to draw boundary contour lines.
        decision_boundary_show_centers: Whether to show center markers on boundary plots.
        decision_boundary_show_regions: Whether to fill decision regions with colour.
        decision_boundary_color: Override colour for boundary lines (``None`` uses palette).
        axis_linewidth: Width of axis spines and tick marks.
        grid_alpha: Opacity of grid lines.
        grid_linewidth: Width of grid lines.
        fontsize: Base font size for labels and titles.
        use_latex: Whether to enable LaTeX rendering via ``text.usetex``.

    Raises:
        ValueError: If ``color_palette`` is an empty sequence.
    """

    figsize: tuple[float, float] = (3.5, 3.5)
    dpi: int = 300
    point_size: float = 1.5
    center_size: float = 15
    center_linewidth: float = 0.8
    alpha_data: float = 0.7
    alpha_contour: float = 0.15
    color_palette: list[str] | None = None
    center_color: str = "black"
    center_marker: str = "o"
    arrow_color: str = "black"
    arrow_linewidth: float = 0.8
    arrow_head_width: float = 0.15
    arrow_head_length: float = 0.2
    arrow_scale: float = 1
    contour_color: str = "black"
    contour_linewidth: float = 0.5
    contour_alpha: float = 0.9
    curve_linewidth: float = 0.8
    curve_alpha: float = 0.7
    kde_color: str = "gray"
    kde_linewidth: float = 0.4
    kde_alpha: float = 0.6
    loss_linewidth: float = 1.0
    decision_boundary_resolution: int = 300
    decision_boundary_show_points: bool = True
    decision_boundary_show_contour_lines: bool = True
    decision_boundary_show_centers: bool = True
    decision_boundary_show_regions: bool = True
    decision_boundary_color: str | None = None
    axis_linewidth: float = 0.6
    grid_alpha: float = 0.15
    grid_linewidth: float = 0.3
    fontsize: float = 9
    use_latex: bool = True

    def __post_init__(self) -> None:
        """Fill in the default palette and normalise/validate a user-supplied one."""
        if self.color_palette is None:
            self.color_palette = ["#FF5733", "#C70039", "#900C3F", "#581845", "#FFC300"]
        elif isinstance(self.color_palette, str):
            # A bare colour string would otherwise be cycled character by character.
            self.color_palette = [self.color_palette]
        else:
            # Copy so later mutation of the caller's list cannot silently restyle plots.
            self.color_palette = list(self.color_palette)
        if not self.color_palette:
            # Plot methods index the palette modulo its length. An empty palette would fail
            # there with an opaque ZeroDivisionError.
            raise ValueError("color_palette must contain at least one colour")
class ClusterVisualizer:
    """Plotting engine for cluster analysis visualisations.

    All plotting methods read visual configuration from the ``style`` attribute.
    A :class:`PlotStyle` instance **must** be passed at construction time; there is no default.

    Construction updates the *global* ``matplotlib.pyplot.rcParams``. This covers font sizes,
    DPI, axis/tick/grid line widths and legend frame, plus LaTeX rendering when
    ``style.use_latex`` is set. These settings therefore also affect figures created outside
    this class.

    Args:
        style: A :class:`PlotStyle` instance that governs all visual parameters.

    Raises:
        TypeError: If ``style`` is not a :class:`PlotStyle`.

    Example:
        >>> style = PlotStyle(figsize=(5, 4), use_latex=False)
        >>> viz = ClusterVisualizer(style)
    """

    # Index vectors with a (scaled) norm at or below this are treated as zero: no arrow is drawn.
    _MIN_ARROW_NORM = 1e-10

    def __init__(self, style: PlotStyle) -> None:
        if not isinstance(style, PlotStyle):
            raise TypeError(
                f"ClusterVisualizer requires a PlotStyle instance, got {type(style).__name__}"
            )
        self.style = style
        self._setup_matplotlib()

    def _setup_matplotlib(self) -> None:
        """Apply ``self.style`` settings to the global matplotlib rcParams."""
        fs = self.style.fontsize
        font_params: dict[str, object] = {
            "axes.labelsize": fs,
            "font.size": fs,
            "legend.fontsize": fs - 1,
            "xtick.labelsize": fs - 1,
            "ytick.labelsize": fs - 1,
            "axes.titlesize": fs,
        }
        if self.style.use_latex:
            font_params.update(
                {
                    "text.usetex": True,
                    "font.family": "serif",
                    "text.latex.preamble": r"\usepackage{amsmath}\usepackage{fourier}",
                }
            )
        plt.rcParams.update(font_params)
        plt.rcParams.update(
            {
                "figure.dpi": self.style.dpi,
                "savefig.dpi": self.style.dpi,
                "savefig.bbox": "tight",
                "savefig.pad_inches": 0.05,
                "axes.linewidth": self.style.axis_linewidth,
                "axes.edgecolor": "black",
                "xtick.major.width": self.style.axis_linewidth,
                "ytick.major.width": self.style.axis_linewidth,
                "xtick.minor.width": self.style.axis_linewidth * 0.7,
                "ytick.minor.width": self.style.axis_linewidth * 0.7,
                "xtick.color": "black",
                "ytick.color": "black",
                "grid.linewidth": self.style.grid_linewidth,
                "legend.frameon": True,
                "legend.edgecolor": "black",
                "legend.framealpha": 1,
                "legend.fancybox": False,
            }
        )

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _validate_2d(X: np.ndarray) -> None:
        """Raise ``ValueError`` unless ``X`` is a 2-D array with exactly two columns."""
        # Checking ndim first turns 1-D input into the documented ValueError instead of an
        # IndexError from ``X.shape[1]``, and rejects 3-D input with ``shape[1] == 2``.
        if np.ndim(X) != 2 or np.shape(X)[1] != 2:
            raise ValueError("Only 2D data can be visualized")

    @staticmethod
    def _data_limits(
        X: np.ndarray,
        x_lim: tuple[float, float] | None,
        y_lim: tuple[float, float] | None,
    ) -> tuple[float, float, float, float]:
        """Return ``(x_min, x_max, y_min, y_max)`` for a boundary plot.

        Manual limits are used where given. Otherwise each axis spans the data range widened
        by 10% on both sides.
        """

        def axis_limits(
            values: np.ndarray, manual: tuple[float, float] | None
        ) -> tuple[float, float]:
            if manual is not None:
                lo, hi = manual
                return lo, hi
            lo, hi = values.min(), values.max()
            span = hi - lo
            # A constant coordinate would give identical limits (a singular axis and a collapsed
            # grid), so fall back to a unit-wide window centred on the value.
            margin = span * 0.1 if span > 0 else 0.5
            return lo - margin, hi + margin

        x_min, x_max = axis_limits(X[:, 0], x_lim)
        y_min, y_max = axis_limits(X[:, 1], y_lim)
        return x_min, x_max, y_min, y_max

    @staticmethod
    def _padded_grid(
        x_min: float, x_max: float, y_min: float, y_max: float, resolution: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return a ``resolution x resolution`` meshgrid over the limits, padded by 2% per side.

        The padding makes the coloured regions and contours reach the visible axis edges.
        """
        pad_x, pad_y = (x_max - x_min) * 0.02, (y_max - y_min) * 0.02
        xx, yy = np.meshgrid(
            np.linspace(x_min - pad_x, x_max + pad_x, resolution),
            np.linspace(y_min - pad_y, y_max + pad_y, resolution),
        )
        return xx, yy

    def _resolve_axes(self, ax: Axes | None) -> tuple[Figure, Axes]:
        """Return ``(figure, axes)``; create a ``style.figsize`` figure when ``ax`` is ``None``."""
        if ax is None:
            fig, new_ax = plt.subplots(figsize=self.style.figsize)
            return fig, new_ax
        fig = cast(Figure, ax.get_figure())
        if fig is None:
            raise ValueError("ax is not attached to a figure")
        return fig, ax

    def _color(self, k: int) -> str:
        """Return palette colour ``k``, cycling through ``style.color_palette``."""
        palette = self.style.color_palette
        if not palette:
            # Guards against a palette set to None/[] after construction (would divide by zero).
            raise ValueError("style.color_palette must contain at least one colour")
        return palette[k % len(palette)]

    def _draw_centers(self, ax: Axes, centers: np.ndarray) -> None:
        """Draw hollow center markers for an array of shape ``(n_centers, 2)``."""
        s = self.style
        ax.scatter(
            centers[:, 0],
            centers[:, 1],
            facecolors="none",
            edgecolors=s.center_color,
            s=s.center_size,
            marker=s.center_marker,
            linewidths=s.center_linewidth,
            zorder=10,
        )

    def _draw_index_arrow(self, ax: Axes, origin: np.ndarray, vector: np.ndarray) -> None:
        """Draw ``vector * style.arrow_scale`` as an arrow starting at ``origin``.

        Arrows whose scaled length is numerically zero are skipped.
        """
        s = self.style
        tip = np.asarray(vector) * s.arrow_scale
        if np.linalg.norm(tip) > self._MIN_ARROW_NORM:
            ax.annotate(
                "",
                xy=(origin[0] + tip[0], origin[1] + tip[1]),
                xytext=(origin[0], origin[1]),
                arrowprops={
                    "arrowstyle": (
                        f"->,head_width={s.arrow_head_width}"
                        f",head_length={s.arrow_head_length}"
                    ),
                    "lw": s.arrow_linewidth,
                    "color": s.arrow_color,
                    "shrinkA": 0,
                    "shrinkB": 0,
                },
                zorder=8,
            )

    def _draw_kde_contours(
        self,
        ax: Axes,
        X: np.ndarray,
        xx: np.ndarray,
        yy: np.ndarray,
        levels: int,
        bandwidth: float,
    ) -> None:
        """Fit a Gaussian KDE to ``X`` and draw its density contours on the ``(xx, yy)`` grid."""
        kde = gaussian_kde(X.T, bw_method=bandwidth)
        density = kde(np.vstack([xx.ravel(), yy.ravel()])).reshape(xx.shape)
        ax.contour(
            xx,
            yy,
            density,
            levels=levels,
            colors=self.style.kde_color,
            linewidths=self.style.kde_linewidth,
            alpha=self.style.kde_alpha,
        )

    def _finish_axes(
        self, ax: Axes, title: str | None, xlabel: str = r"$x$", ylabel: str = r"$y$"
    ) -> None:
        """Apply grid, axis labels and (if non-empty) the title."""
        s = self.style
        ax.grid(True, alpha=s.grid_alpha, linewidth=s.grid_linewidth)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if title:
            ax.set_title(title, fontsize=s.fontsize)

    def _compute_costs(self, grid_points: np.ndarray, result: ClusterResult) -> np.ndarray:
        """Compute per-cluster costs for every grid point.

        With ``result.metadata["indices"]`` present, the cost for cluster ``k`` is the
        expectile-type loss ``0.5 * |r|^2 + 0.5 * |r| * <u_k, r>`` with ``r = x - center_k``.
        Otherwise it is the squared Euclidean distance to ``center_k``.

        Args:
            grid_points: Array of shape ``(n_grid, 2)``.
            result: Clustering result with centers and optional index metadata.

        Returns:
            Cost array of shape ``(n_grid, n_clusters)``.
        """
        n_clusters = len(result.centers)
        costs = np.zeros((len(grid_points), n_clusters))
        if result.metadata is not None and "indices" in result.metadata:
            indices = cast(np.ndarray, result.metadata["indices"])
            for k in range(n_clusters):
                residuals = grid_points - result.centers[k]
                dists = np.linalg.norm(residuals, axis=1)
                costs[:, k] = 0.5 * dists**2 + 0.5 * dists * np.sum(indices[k] * residuals, axis=1)
        else:
            for k in range(n_clusters):
                costs[:, k] = np.sum((grid_points - result.centers[k]) ** 2, axis=1)
        return costs

    def _draw_boundary_lines(
        self,
        ax: Axes,
        xx: np.ndarray,
        yy: np.ndarray,
        costs: np.ndarray,
        n_clusters: int,
        linewidth: float,
        color: str | None = None,
    ) -> None:
        """Draw decision boundary lines via pairwise cost-difference contours.

        Consider every pair of clusters that each own at least one grid cell. For each pair,
        the zero level of their cost difference is contoured. The difference is set to NaN
        outside the cells assigned to either cluster of the pair, so the line is restricted to
        where one of the two clusters is the closest.

        Args:
            ax: Matplotlib axes to draw on.
            xx: Meshgrid x-coordinates.
            yy: Meshgrid y-coordinates.
            costs: Flattened cost array of shape ``(xx.size, n_clusters)``.
            n_clusters: Number of clusters.
            linewidth: Line width for boundary contours.
            color: Override boundary colour (``None`` uses ``self.style.contour_color``).
        """
        line_color = self.style.contour_color if color is None else color
        assignments = np.argmin(costs, axis=1).reshape(xx.shape)
        # Hoisted out of the pair loop: one mask per cluster instead of one per pair.
        masks = [assignments == k for k in range(n_clusters)]
        present = [mask.any() for mask in masks]
        for i in range(n_clusters):
            for j in range(i + 1, n_clusters):
                if not (present[i] and present[j]):
                    continue
                diff = (costs[:, i] - costs[:, j]).reshape(xx.shape)
                masked_diff = np.where(masks[i] | masks[j], diff, np.nan)
                ax.contour(
                    xx,
                    yy,
                    masked_diff,
                    levels=[0.0],
                    colors=line_color,
                    linewidths=linewidth,
                    alpha=self.style.contour_alpha,
                )

    # ------------------------------------------------------------------ public plots

    def plot_cluster_assignments(
        self,
        X: np.ndarray,
        result: ClusterResult,
        ax: Axes | None = None,
        show_centers: bool = True,
        show_indices: bool = True,
        show_legend: bool = False,
        title: str | None = None,
    ) -> Figure:
        """Plot data points coloured by cluster assignment.

        Args:
            X: Data array of shape ``(n_samples, 2)``.
            result: Clustering result.
            ax: Matplotlib axes (created if ``None``).
            show_centers: Whether to overlay center markers.
            show_indices: Whether to draw index vector arrows (scaled by ``style.arrow_scale``)
                when ``result.metadata`` contains ``"indices"``.
            show_legend: Whether to show a cluster legend.
            title: Optional axes title.

        Returns:
            The matplotlib ``Figure`` containing the plot.

        Raises:
            ValueError: If ``X`` is not of shape ``(n_samples, 2)``.
        """
        self._validate_2d(X)
        fig, ax = self._resolve_axes(ax)
        s = self.style
        n_clusters = len(result.centers)
        for k in range(n_clusters):
            mask = result.assignments == k
            ax.scatter(
                X[mask, 0],
                X[mask, 1],
                c=self._color(k),
                s=s.point_size,
                alpha=s.alpha_data,
                edgecolors="none",
                label=f"{k}" if show_legend else None,
            )
        if show_indices and result.metadata is not None and "indices" in result.metadata:
            indices = cast(np.ndarray, result.metadata["indices"])
            for k in range(n_clusters):
                self._draw_index_arrow(ax, result.centers[k], indices[k])
        if show_centers:
            self._draw_centers(ax, result.centers)
        self._style_axes(ax, X, title, show_legend)
        return fig

    def plot_decision_boundaries(
        self,
        X: np.ndarray,
        result: ClusterResult,
        ax: Axes | None = None,
        title: str | None = None,
        resolution: int | None = None,
        show_points: bool | None = None,
        show_contour_lines: bool | None = None,
        show_centers: bool | None = None,
        show_regions: bool | None = None,
        x_lim: tuple[float, float] | None = None,
        y_lim: tuple[float, float] | None = None,
        boundary_color: str | None = None,
    ) -> Figure:
        """Plot decision boundaries (and optionally regions, data points, centers).

        Args:
            X: Data array of shape ``(n_samples, 2)``.
            result: Clustering result.
            ax: Matplotlib axes (created if ``None``).
            title: Optional axes title.
            resolution: Grid resolution for boundary computation. Falsy values (``None``/``0``)
                use ``self.style.decision_boundary_resolution``.
            show_points: Overlay data points. Defaults to style setting.
            show_contour_lines: Draw boundary contour lines. Defaults to style setting.
            show_centers: Show center markers. Defaults to style setting.
            show_regions: Fill decision regions with colour. Defaults to style setting.
            x_lim: Manual x-axis limits.
            y_lim: Manual y-axis limits.
            boundary_color: Override colour for boundary lines. Defaults to
                ``style.decision_boundary_color``, then to the first palette colour.

        Returns:
            The matplotlib ``Figure`` containing the plot.

        Raises:
            ValueError: If ``X`` is not of shape ``(n_samples, 2)``.
        """
        self._validate_2d(X)
        s = self.style
        resolution = resolution or s.decision_boundary_resolution
        show_points = s.decision_boundary_show_points if show_points is None else show_points
        show_contour_lines = (
            s.decision_boundary_show_contour_lines
            if show_contour_lines is None
            else show_contour_lines
        )
        show_centers = s.decision_boundary_show_centers if show_centers is None else show_centers
        show_regions = s.decision_boundary_show_regions if show_regions is None else show_regions
        if boundary_color is None:
            boundary_color = s.decision_boundary_color
        if boundary_color is None:
            boundary_color = self._color(0)

        fig, ax = self._resolve_axes(ax)
        x_min, x_max, y_min, y_max = self._data_limits(X, x_lim, y_lim)
        xx, yy = self._padded_grid(x_min, x_max, y_min, y_max, resolution)

        n_clusters = len(result.centers)
        costs = self._compute_costs(np.c_[xx.ravel(), yy.ravel()], result)
        colors = [self._color(k) for k in range(n_clusters)]

        if show_regions:
            region_labels = np.argmin(costs, axis=1).reshape(xx.shape)
            ax.pcolormesh(
                xx,
                yy,
                region_labels,
                cmap=ListedColormap(colors),
                # Bin edges at k - 0.5 map integer label k to colour k.
                norm=BoundaryNorm(np.arange(n_clusters + 1) - 0.5, n_clusters),
                alpha=s.alpha_contour,
                shading="auto",
                zorder=0,
            )
        if show_contour_lines:
            self._draw_boundary_lines(
                ax,
                xx,
                yy,
                costs,
                n_clusters,
                linewidth=s.contour_linewidth,
                color=boundary_color,
            )
        if show_points:
            for k, color in enumerate(colors):
                mask = result.assignments == k
                ax.scatter(
                    X[mask, 0],
                    X[mask, 1],
                    c=color,
                    s=s.point_size,
                    alpha=s.alpha_data,
                    edgecolors="none",
                )
        if show_centers:
            self._draw_centers(ax, result.centers)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_aspect("equal", adjustable="box")
        self._finish_axes(ax, title)
        return fig

    def plot_density_contours(
        self,
        X: np.ndarray,
        ax: Axes | None = None,
        title: str | None = None,
        levels: int = 20,
        bandwidth: float = 0.2,
        resolution: int = 150,
    ) -> Figure:
        """Plot KDE density contours of the data.

        Args:
            X: Data array of shape ``(n_samples, 2)``.
            ax: Matplotlib axes (created if ``None``).
            title: Optional axes title.
            levels: Number of contour levels.
            bandwidth: KDE bandwidth parameter (passed as ``bw_method``).
            resolution: Number of grid points per axis at which the density is evaluated.

        Returns:
            The matplotlib ``Figure`` containing the plot.

        Raises:
            ValueError: If ``X`` is not of shape ``(n_samples, 2)``.
        """
        self._validate_2d(X)
        fig, ax = self._resolve_axes(ax)
        x_min, x_max, y_min, y_max = self._data_limits(X, None, None)
        xx, yy = np.meshgrid(
            np.linspace(x_min, x_max, resolution),
            np.linspace(y_min, y_max, resolution),
        )
        self._draw_kde_contours(ax, X, xx, yy, levels, bandwidth)
        self._style_axes(ax, X, title, show_legend=False)
        return fig

    def plot_density_with_boundaries(
        self,
        X: np.ndarray,
        result: ClusterResult,
        ax: Axes | None = None,
        title: str | None = None,
        resolution: int | None = None,
        kde_levels: int = 20,
        kde_bandwidth: float = 0.2,
        show_centers: bool = True,
        x_lim: tuple[float, float] | None = None,
        y_lim: tuple[float, float] | None = None,
        boundary_color: str | None = None,
    ) -> Figure:
        """Overlay KDE density contours with decision boundary lines.

        Args:
            X: Data array of shape ``(n_samples, 2)``.
            result: Clustering result.
            ax: Matplotlib axes (created if ``None``).
            title: Optional axes title.
            resolution: Grid resolution for KDE and boundary computation. Falsy values
                (``None``/``0``) use ``self.style.decision_boundary_resolution``.
            kde_levels: Number of KDE contour levels.
            kde_bandwidth: KDE bandwidth parameter.
            show_centers: Whether to overlay center markers.
            x_lim: Manual x-axis limits.
            y_lim: Manual y-axis limits.
            boundary_color: Colour of boundary lines (``None`` uses ``style.contour_color``).

        Returns:
            The matplotlib ``Figure`` containing the plot.

        Raises:
            ValueError: If ``X`` is not of shape ``(n_samples, 2)``.
        """
        self._validate_2d(X)
        resolution = resolution or self.style.decision_boundary_resolution
        fig, ax = self._resolve_axes(ax)
        x_min, x_max, y_min, y_max = self._data_limits(X, x_lim, y_lim)
        xx, yy = self._padded_grid(x_min, x_max, y_min, y_max, resolution)

        self._draw_kde_contours(ax, X, xx, yy, kde_levels, kde_bandwidth)
        costs = self._compute_costs(np.c_[xx.ravel(), yy.ravel()], result)
        self._draw_boundary_lines(
            ax,
            xx,
            yy,
            costs,
            len(result.centers),
            linewidth=self.style.contour_linewidth,
            color=boundary_color,
        )
        if show_centers:
            self._draw_centers(ax, result.centers)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_aspect("equal", adjustable="box")
        self._finish_axes(ax, title)
        return fig

    def plot_expectile_curves(
        self,
        center: np.ndarray,
        index_vector: np.ndarray,
        cost_levels: np.ndarray | None = None,
        ax: Axes | None = None,
        title: str | None = None,
        show_index_arrow: bool = True,
        curve_color: str | None = None,
        extent: float = 3.0,
        resolution: int = 300,
    ) -> Figure:
        """Plot level curves of the expectile loss around a single center.

        The loss is ``0.5 * (|r|^2 + |r| * <u, r>)`` with ``r = x - center`` and
        ``u = index_vector``, the same cost used for decision boundaries.

        Args:
            center: Center point of shape ``(2,)``.
            index_vector: Index vector of shape ``(2,)``.
            cost_levels: Loss values at which to draw contours (default: 8 levels in
                ``[0.5, 3.0]``).
            ax: Matplotlib axes (created if ``None``).
            title: Optional axes title.
            show_index_arrow: Whether to draw the index vector arrow, scaled by
                ``style.arrow_scale`` like in :meth:`plot_cluster_assignments`.
            curve_color: Override colour for contour lines (default: first palette colour).
            extent: Half-width of the plotting window around ``center``.
            resolution: Grid resolution.

        Returns:
            The matplotlib ``Figure`` containing the plot.
        """
        fig, ax = self._resolve_axes(ax)
        if curve_color is None:
            curve_color = self._color(0)
        if cost_levels is None:
            cost_levels = np.linspace(0.5, 3.0, 8)

        cx, cy = center[0], center[1]
        xx, yy = np.meshgrid(
            np.linspace(cx - extent, cx + extent, resolution),
            np.linspace(cy - extent, cy + extent, resolution),
        )
        rx, ry = xx - cx, yy - cy
        norms = np.sqrt(rx**2 + ry**2)
        loss = 0.5 * (norms**2 + norms * (index_vector[0] * rx + index_vector[1] * ry))
        ax.contour(
            xx,
            yy,
            loss,
            levels=cost_levels,
            colors=curve_color,
            linewidths=self.style.curve_linewidth,
            alpha=self.style.curve_alpha,
        )
        self._draw_centers(ax, np.atleast_2d(center))
        if show_index_arrow:
            self._draw_index_arrow(ax, center, index_vector)

        ax.set_xlim(cx - extent, cx + extent)
        ax.set_ylim(cy - extent, cy + extent)
        ax.set_aspect("equal")
        self._finish_axes(ax, title)
        return fig

    def plot_loss_comparison(
        self,
        loss_functions: list[tuple[str, Callable[[np.ndarray], np.ndarray]]],
        x_range: tuple[float, float] = (-3, 3),
        n_points: int = 300,
        ax: Axes | None = None,
        title: str | None = None,
        show_legend: bool = True,
    ) -> Figure:
        """Plot one-dimensional loss function comparisons.

        Args:
            loss_functions: List of ``(name, callable)`` pairs. Each callable is evaluated on
                a 1-D array of distances.
            x_range: Domain ``(x_min, x_max)`` for the plot.
            n_points: Number of evaluation points.
            ax: Matplotlib axes (created if ``None``).
            title: Optional axes title.
            show_legend: Whether to display a legend.

        Returns:
            The matplotlib ``Figure`` containing the plot.
        """
        fig, ax = self._resolve_axes(ax)
        x = np.linspace(x_range[0], x_range[1], n_points)
        for idx, (name, loss_fn) in enumerate(loss_functions):
            ax.plot(
                x,
                loss_fn(x),
                label=name,
                color=self._color(idx),
                linewidth=self.style.loss_linewidth,
            )
        ax.axhline(y=0, color="black", linewidth=0.4, linestyle="--", alpha=0.3)
        ax.axvline(x=0, color="black", linewidth=0.4, linestyle="--", alpha=0.3)
        self._finish_axes(ax, title, xlabel=r"Distance $d$", ylabel=r"Loss $L(d)$")
        if show_legend:
            legend = ax.legend(frameon=True, loc="best", edgecolor="black", fancybox=False)
            legend.get_frame().set_linewidth(self.style.axis_linewidth)
        return fig

    def _style_axes(
        self, ax: Axes, X: np.ndarray, title: str | None = None, show_legend: bool = False
    ) -> None:
        """Apply common axis styling: square limits, equal aspect, grid, labels, title, legend.

        Both axes use the larger of the two data ranges plus a 15% margin, centred on the data,
        so the plotting window is square.
        """
        x_min, x_max = X[:, 0].min(), X[:, 0].max()
        y_min, y_max = X[:, 1].min(), X[:, 1].max()
        max_range = max(x_max - x_min, y_max - y_min)
        if max_range <= 0:
            # All points coincide: avoid identical (singular) axis limits.
            max_range = 1.0
        x_center, y_center = (x_min + x_max) / 2, (y_min + y_max) / 2
        half_width = max_range / 2 + 0.15 * max_range
        ax.set_xlim(x_center - half_width, x_center + half_width)
        ax.set_ylim(y_center - half_width, y_center + half_width)
        ax.set_aspect("equal", adjustable="box")
        self._finish_axes(ax, title)
        if show_legend:
            legend = ax.legend(
                frameon=True,
                loc="upper left",
                bbox_to_anchor=(0.02, 0.98),
                handletextpad=0.3,
                columnspacing=0.5,
                edgecolor="black",
                fancybox=False,
            )
            for text in legend.get_texts():
                text.set_fontsize(self.style.fontsize - 2)
            legend.get_frame().set_linewidth(self.style.axis_linewidth)