API Reference

vizopt.base

Core abstractions for the optimization framework.

A term in an objective function.

Attributes:

Name Type Description
name str

A name for the term, e.g. "total distance".

compute Callable[[OptimVars, Any], Array]

A function that computes the value of the term. It is called as compute(optim_vars, input_parameters).

multiplier float

A multiplicative factor for the term.

schedule Callable[[Array], Array] | None

Optional JAX-compatible callable (step: Array) -> Array that returns a scalar multiplier for the given iteration step. The effective weight is multiplier * schedule(step). Must use JAX ops (e.g. jnp.minimum, jnp.where) so that it can be traced through without recompilation. None means constant 1.0 (no scheduling).

Build a composite objective function from a list of terms.

Parameters:

Name Type Description Default
terms list[ObjectiveTerm]

Objective terms to sum, each weighted by its multiplier.

required
input_parameters Any

Fixed data passed to each term's compute function.

required

Returns:

Type Description
Callable[[OptimVars, Array], Array]

A callable fun(optim_vars, step) -> scalar suitable for gradient descent. step is the current iteration as a JAX int32 array and is passed to each term's schedule (if any).
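
The weighting rule described above (effective weight = multiplier * schedule(step), with None meaning a constant 1.0) can be sketched in plain Python. This is a minimal illustration, not the library's implementation: Term and build_objective are hypothetical names, and plain floats stand in for JAX arrays, so the schedule here uses min rather than the JAX ops a real schedule requires.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Term:
    """Hypothetical stand-in for ObjectiveTerm, using plain floats."""
    name: str
    compute: Callable[[Any, Any], float]
    multiplier: float = 1.0
    schedule: Optional[Callable[[int], float]] = None


def build_objective(terms, input_parameters):
    """Return fun(optim_vars, step) that sums the weighted terms."""
    def fun(optim_vars, step):
        total = 0.0
        for term in terms:
            # A missing schedule means a constant factor of 1.0.
            scale = 1.0 if term.schedule is None else term.schedule(step)
            total += term.multiplier * scale * term.compute(optim_vars, input_parameters)
        return total
    return fun


terms = [
    Term("square", lambda x, p: (x - p["target"]) ** 2, multiplier=2.0),
    Term("late", lambda x, p: abs(x), schedule=lambda step: min(step / 10, 1.0)),
]
objective = build_objective(terms, {"target": 3.0})
value_at_start = objective(1.0, 0)   # the scheduled term contributes nothing yet
value_at_end = objective(1.0, 10)    # the scheduled term is at full weight
```

The input parameters are captured once when the objective is built, so the returned callable only takes the quantities that change during optimization.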

Bases: Generic[InputParams, OptimVars]

A template for a class of optimization problems.

An instance represents a specific type of optimization problem (e.g. bubble layout optimization), independently of any particular input data. Call :meth:instantiate with concrete input parameters to obtain a runnable :class:OptimizationProblem.

If input_params_class is provided, it must be a Pydantic model class. instantiate will call model_validate on the supplied parameters, triggering Pydantic validation and coercion before the problem is created.

Attributes:

Name Type Description
terms list[ObjectiveTerm]

Objective terms defining the loss function.

initialize Callable[[InputParams, int], OptimVars]

Callable that produces initial optimization variables from input_parameters.

input_params_class type[InputParams] | None

Optional Pydantic model class for input parameters. When set, validation is performed at instantiation time.

plot_configuration Callable[[OptimVars, InputParams], None] | None

Optional callable to visualize a configuration. Signature: plot_configuration(optim_vars, input_parameters).

svg_configuration Callable[[list, InputParams, int], list[dict]] | None

Optional callable to produce SVG element specs for animation. Signature: svg_configuration(snapshots, input_parameters, size) -> list[dict] where each dict has a "tag" key and SVG attribute keys; list values are animated per-frame, scalar values are static.

instantiate(input_parameters, weight_overrides=None)

Create a runnable problem instance from concrete input parameters.

If input_params_class is set, validates input_parameters via model_validate before creating the problem. The plain dict is passed through to the problem unchanged (Pydantic is used for validation only, so that input_parameters remains a JAX-compatible pytree).

Parameters:

Name Type Description Default
input_parameters InputParams

Fixed data for this problem instance.

required
weight_overrides

Optional mapping of term name to multiplier. Overrides the default multiplier for the named terms. Unknown names raise KeyError.

None

Returns:

Type Description
OptimizationProblem[InputParams, OptimVars]

An :class:OptimizationProblem ready to optimize.

Raises:

Type Description
KeyError

If a name in weight_overrides does not match any term.

ValidationError

If input_params_class is set and validation fails.
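
The override behaviour documented for weight_overrides can be sketched as follows. This is an assumption about how such a lookup might work, not the library's code; apply_weight_overrides and the term names are hypothetical.

```python
def apply_weight_overrides(default_multipliers, weight_overrides=None):
    """Return a copy of default_multipliers with the overrides applied."""
    multipliers = dict(default_multipliers)
    for name, value in (weight_overrides or {}).items():
        if name not in multipliers:
            # Mirrors the documented behaviour: unknown names raise KeyError.
            raise KeyError(f"unknown objective term: {name!r}")
        multipliers[name] = value
    return multipliers


defaults = {"collision": 10.0, "area": 1.0}
overridden = apply_weight_overrides(defaults, {"area": 0.5})
```

Working on a copy keeps the defaults unchanged, so each instantiated problem can carry its own weights.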

Bases: Generic[InputParams, OptimVars]

An optimization problem.

Attributes:

Name Type Description
input_parameters InputParams

Fixed data for the problem (not optimized).

terms list[ObjectiveTerm]

Objective terms defining the loss function.

initialize Callable[[InputParams, int], OptimVars]

Callable that produces initial optimization variables from input_parameters.

plot_configuration Callable[[OptimVars, InputParams], None] | None

Optional callable to visualize a configuration. Signature: plot_configuration(optim_vars, input_parameters).

svg_configuration Callable[[list, InputParams, int], list[dict]] | None

Optional callable to produce SVG element specs for animation. Signature: svg_configuration(snapshots, input_parameters, size) -> list[dict].

optimize(optim_config=None, callback=None, track_every=10)

Run gradient descent to minimize the objective.

When optim_config.n_restarts > 1, the optimization is run that many times with seeds seed, seed + 1, …. The result with the lowest final loss is returned.

Parameters:

Name Type Description Default
optim_config OptimConfig | None

Optimizer settings (iterations, learning rate, seeds, restarts). Uses :class:OptimConfig defaults when None.

None
callback Callback | None

Optional callback called after each iteration with (iteration, loss, optim_vars, grads).

None
track_every int

Record per-term history every this many iterations.

10

Returns:

Type Description
tuple[OptimVars, list[dict]]

Tuple of (optimized variables, history). History is a list of dicts with keys "iteration", "total", and one key per term name containing the weighted term value at that iteration. When using multiple restarts, history corresponds to the best run.
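
The restart rule (seeds seed, seed + 1, …, keep the run with the lowest final loss) can be sketched like this. The helper names are hypothetical, and reading the final loss from the last history entry's "total" key is an assumption made for the illustration; the library may track it differently.

```python
def best_of_restarts(run_once, seed, n_restarts):
    """Run run_once(seed + k) for k in range(n_restarts); keep the best run."""
    best = None
    for offset in range(n_restarts):
        optim_vars, history = run_once(seed + offset)
        final_loss = history[-1]["total"]
        if best is None or final_loss < best[0]:
            best = (final_loss, optim_vars, history)
    return best[1], best[2]


def toy_run(seed):
    """Toy stand-in for one optimization run; the loss is lowest for seed 7."""
    loss = float(abs(seed - 7))
    history = [
        {"iteration": 0, "total": loss + 1.0},
        {"iteration": 10, "total": loss},
    ]
    return seed, history


best_vars, best_history = best_of_restarts(toy_run, seed=5, n_restarts=4)
```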


vizopt.animation

Optimization progress visualization.

Callback that saves a numpy copy of optim_vars at regular intervals.

Pass an instance as the callback argument to :meth:OptimizationProblem.optimize. Snapshots accumulate in :attr:snapshots and can be passed to :func:animate.

Parameters:

Name Type Description Default
every int

Save a snapshot every this many iterations.

10

Attributes:

Name Type Description
snapshots list[tuple[int, Any]]

List of (iteration, optim_vars) tuples, one per recorded step.

Example::

cb = SnapshotCallback(every=100)
optim_vars_opt, history = problem.optimize(callback=cb)
anim = animate(problem, cb.snapshots)
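
For readers who want to write their own callback, the callback protocol (iteration, loss, optim_vars, grads) can be illustrated with a small stand-in. EveryNCallback is hypothetical and is not the library's class; recording when iteration % every == 0 and copying with copy.deepcopy are assumptions made for this sketch.

```python
import copy


class EveryNCallback:
    """Hypothetical snapshot callback; stores a copy every `every` iterations."""

    def __init__(self, every=10):
        self.every = every
        self.snapshots = []

    def __call__(self, iteration, loss, optim_vars, grads):
        if iteration % self.every == 0:
            # Copy so that later in-place updates do not alter stored frames.
            self.snapshots.append((iteration, copy.deepcopy(optim_vars)))


cb = EveryNCallback(every=100)
state = [0.0]
for iteration in range(250):
    state[0] += 1.0  # toy in-place update standing in for an optimizer step
    cb(iteration, loss=0.0, optim_vars=state, grads=None)
recorded = [it for it, _ in cb.snapshots]
```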

Create an animation of the optimization process.

Renders each optim_vars snapshot via problem.plot_configuration and assembles the frames into a FuncAnimation.

Parameters:

Name Type Description Default
problem OptimizationProblem

The optimization problem; must have plot_configuration set.

required
snapshots list[tuple[int, Any]]

List of (iteration, optim_vars) tuples as produced by :class:SnapshotCallback.

required
interval int

Delay between frames in milliseconds.

200

Returns:

Type Description
Any

A matplotlib.animation.FuncAnimation. In a Jupyter notebook, display with IPython.display.HTML(anim.to_jshtml()).

Raises:

Type Description
ValueError

If problem.plot_configuration is not set or snapshots is empty.

Create an animated SVG from optimization snapshots.

Uses problem.svg_configuration to obtain per-element SVG specs, then builds SMIL <animate> elements for attributes that vary across frames.

Parameters:

Name Type Description Default
problem OptimizationProblem

The optimization problem; must have svg_configuration set.

required
snapshots list[tuple[int, Any]]

List of (iteration, optim_vars) tuples as produced by :class:SnapshotCallback.

required
fps int

Frames per second.

10
size int

Width and height of the output SVG in pixels.

500
calc_mode str

"linear" for smooth interpolation or "discrete" for instant jumps between frames.

'linear'
history list[dict] | None

Optional list of history dicts as returned by :meth:OptimizationProblem.optimize (each dict has an "iteration" key and a "total" key with the aggregate loss). When provided, a loss curve is rendered below the animation with an animated marker tracking the current frame.

None
loss_curve_height int

Height in pixels of the loss curve panel, used only when history is provided.

120
log_scale bool

If True, the loss axis uses a log10 scale.

False

Returns:

Type Description
str

An SVG string. Save with Path("out.svg").write_text(svg) or display in Jupyter with IPython.display.SVG(data=svg).

Raises:

Type Description
ValueError

If problem.svg_configuration is not set or snapshots is empty.
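
To make the element-spec convention concrete (a "tag" key plus attribute keys, where list values are animated per frame and scalar values are static), here is a rough sketch of turning one spec into an SVG element with a SMIL <animate> child. render_element is a hypothetical helper, and the duration of n_frames / fps seconds is an assumption; the library's actual output may differ.

```python
from xml.sax.saxutils import quoteattr


def render_element(spec, duration_s, calc_mode="linear"):
    """Render one element spec; list values become SMIL <animate> children."""
    tag = spec["tag"]
    attributes, animations = [], []
    for key, value in spec.items():
        if key == "tag":
            continue
        if isinstance(value, list):
            frames = ";".join(str(v) for v in value)
            animations.append(
                f"<animate attributeName={quoteattr(key)} values={quoteattr(frames)} "
                f'dur="{duration_s}s" calcMode={quoteattr(calc_mode)} '
                'repeatCount="indefinite"/>'
            )
            # Use the first frame as the initial attribute value.
            attributes.append(f"{key}={quoteattr(str(value[0]))}")
        else:
            attributes.append(f"{key}={quoteattr(str(value))}")
    attribute_text = " ".join(attributes)
    animation_text = "".join(animations)
    return f"<{tag} {attribute_text}>{animation_text}</{tag}>"


circle = {"tag": "circle", "cx": [10, 20, 30], "cy": 50, "r": 5}
n_frames, fps = 3, 10
svg_fragment = render_element(circle, duration_s=n_frames / fps)
```

With calc_mode="discrete", a SMIL player jumps between the listed values instead of interpolating between them.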


vizopt.components

Reusable JAX loss components.

Calculate the pairwise intersections of two sets of bounding boxes.

The implementation is vectorized, which is more efficient than a nested for loop over both sets.

Parameters:

Name Type Description Default
bbox_matrix ndarray

NumPy array of shape (n, 2, 2). Axes: box index, min/max corner, xy coordinate.

required
other_bbox_matrix ndarray

NumPy array of shape (m, 2, 2). Axes: box index, min/max corner, xy coordinate.

required

Returns:

Type Description

NumPy array of shape (n, m).
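
As a reference for what a vectorized version replaces, the nested-loop equivalent can be written out directly. This sketch assumes that "intersection" means overlap area; the page does not say whether the result holds areas, booleans, or something else, so treat the returned quantity as an assumption. pairwise_intersection_area is a hypothetical name, and nested lists stand in for arrays.

```python
def pairwise_intersection_area(bboxes, other_bboxes):
    """Nested-loop reference: overlap area for every pair of boxes.

    Each box is [[min_x, min_y], [max_x, max_y]], matching the (n, 2, 2) layout.
    """
    result = []
    for min_a, max_a in bboxes:
        row = []
        for min_b, max_b in other_bboxes:
            overlap_w = min(max_a[0], max_b[0]) - max(min_a[0], min_b[0])
            overlap_h = min(max_a[1], max_b[1]) - max(min_a[1], min_b[1])
            # Negative extents mean the boxes do not overlap on that axis.
            row.append(max(overlap_w, 0.0) * max(overlap_h, 0.0))
        result.append(row)
    return result


boxes = [[[0.0, 0.0], [2.0, 2.0]], [[5.0, 5.0], [6.0, 6.0]]]
others = [[[1.0, 1.0], [3.0, 3.0]]]
areas = pairwise_intersection_area(boxes, others)
```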

A penalty for the overall width and height of the drawing with circular nodes.

Parameters:

Name Type Description Default
node_xys ndarray

Array of node positions with shape (n, 2).

required
node_radii ndarray

Array of node radii with shape (n,).

required
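
One plausible form of such a penalty is the width plus the height of the box that encloses every circle. The page does not give the formula, so this is only an assumed illustration with a hypothetical function name and plain Python lists in place of arrays.

```python
def extent_penalty_circles(node_xys, node_radii):
    """Width plus height of the box enclosing all circles (assumed form)."""
    lefts = [x - r for (x, _), r in zip(node_xys, node_radii)]
    rights = [x + r for (x, _), r in zip(node_xys, node_radii)]
    bottoms = [y - r for (_, y), r in zip(node_xys, node_radii)]
    tops = [y + r for (_, y), r in zip(node_xys, node_radii)]
    width = max(rights) - min(lefts)
    height = max(tops) - min(bottoms)
    return width + height


penalty = extent_penalty_circles([(0.0, 0.0), (4.0, 1.0)], [1.0, 2.0])
```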

A penalty for the overall width and height of the drawing.

Parameters:

Name Type Description Default
node_xys ndarray

Array of node positions with shape (n, 2).

required

A penalty for negative values.


vizopt.schedules

Loss term weight scheduling.

Linear warmup: ramps from 0.01 to 1.0 over a window of the run.

Parameters:

Name Type Description Default
delay_frac float

Fraction of n_iters before ramping starts.

required
ramp_frac float

Fraction of n_iters over which to ramp up.

required
n_iters int

Total number of optimization iterations.

required

Returns:

Type Description

JAX-compatible callable (step: Array) -> Array.
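
The shape of a linear warmup can be sketched with plain floats. The exact formula is an assumption (hold at 0.01 until delay_frac * n_iters, then ramp linearly to 1.0 over ramp_frac * n_iters), linear_warmup is a hypothetical name, and ramp_frac is assumed to be greater than zero. A schedule passed to an objective term would need JAX ops in place of the Python min and max used here.

```python
def linear_warmup(delay_frac, ramp_frac, n_iters, floor=0.01):
    """Return schedule(step): floor until the delay ends, then ramp to 1.0."""
    start = delay_frac * n_iters
    ramp = ramp_frac * n_iters  # assumed > 0

    def schedule(step):
        progress = min(max((step - start) / ramp, 0.0), 1.0)
        return floor + (1.0 - floor) * progress

    return schedule


warmup = linear_warmup(delay_frac=0.2, ramp_frac=0.5, n_iters=1000)
weights = [warmup(step) for step in (0, 450, 700, 1000)]
```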

Linear cooldown: ramps from 1.0 down to 0.01 over a window of the run.

Parameters:

Name Type Description Default
peak_frac float

Fraction of n_iters at which the weight is still 1.0.

required
ramp_frac float

Fraction of n_iters over which to ramp down.

required
n_iters int

Total number of optimization iterations.

required

Returns:

Type Description

JAX-compatible callable (step: Array) -> Array.
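
A cooldown is the mirror image of a warmup. The sketch below assumes the weight stays at 1.0 until peak_frac * n_iters and then falls linearly to 0.01 over ramp_frac * n_iters; linear_cooldown is a hypothetical name, ramp_frac is assumed to be greater than zero, and plain floats are used instead of JAX arrays.

```python
def linear_cooldown(peak_frac, ramp_frac, n_iters, floor=0.01):
    """Return schedule(step): 1.0 until the peak ends, then ramp to floor."""
    peak = peak_frac * n_iters
    ramp = ramp_frac * n_iters  # assumed > 0

    def schedule(step):
        progress = min(max((step - peak) / ramp, 0.0), 1.0)
        return 1.0 - (1.0 - floor) * progress

    return schedule


cooldown = linear_cooldown(peak_frac=0.3, ramp_frac=0.4, n_iters=1000)
weights = [cooldown(step) for step in (0, 500, 700, 1000)]
```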

Build a term_schedules dict from a flat parameter dict.

Parameters:

Name Type Description Default
params dict

Dict with fractional schedule parameters: collision_delay, collision_ramp, exclusion_delay, exclusion_ramp, area_delay, area_ramp, perimeter_delay, perimeter_ramp, attraction_peak, attraction_ramp. All values are fractions of n_iters.

required
n_iters int

Total number of optimization iterations. Schedules scale automatically: the same params work for any run length.

required

Returns:

Type Description
dict

Dict mapping term name to a JAX-compatible schedule callable.
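
The claim that the same fractional params work for any run length can be checked with a small sketch. It assumes the term names are the key prefixes (collision, exclusion, area, perimeter, attraction), that the *_delay / *_ramp pairs drive warmups, and that attraction_peak / attraction_ramp drive a cooldown. The helper names are hypothetical and plain floats stand in for JAX arrays.

```python
def _ramp(start_frac, ramp_frac, n_iters, begin, end):
    """Linear ramp from begin to end, starting at start_frac * n_iters."""
    start, length = start_frac * n_iters, ramp_frac * n_iters

    def schedule(step):
        progress = min(max((step - start) / length, 0.0), 1.0)
        return begin + (end - begin) * progress

    return schedule


def build_term_schedules(params, n_iters):
    """Map each assumed term name to a schedule built from fractional params."""
    schedules = {}
    for term in ("collision", "exclusion", "area", "perimeter"):
        schedules[term] = _ramp(
            params[f"{term}_delay"], params[f"{term}_ramp"], n_iters, 0.01, 1.0
        )
    schedules["attraction"] = _ramp(
        params["attraction_peak"], params["attraction_ramp"], n_iters, 1.0, 0.01
    )
    return schedules


params = {
    "collision_delay": 0.1, "collision_ramp": 0.4,
    "exclusion_delay": 0.2, "exclusion_ramp": 0.4,
    "area_delay": 0.3, "area_ramp": 0.4,
    "perimeter_delay": 0.3, "perimeter_ramp": 0.4,
    "attraction_peak": 0.5, "attraction_ramp": 0.4,
}
short_run = build_term_schedules(params, n_iters=100)
long_run = build_term_schedules(params, n_iters=1000)
```

Evaluating short_run["collision"](30) and long_run["collision"](300) gives the same weight, because both steps sit at 30% of their respective runs.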