Skip to content

Soft interval overlap for collision losses

In gradient-based layout optimization, collision between shapes can be penalized by measuring how much their bounding intervals overlap. The overlap between two 1D intervals is min(r1, r2) - max(l1, l2), clamped to zero.

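For the loss to drive gradient descent, it should provide a useful gradient everywhere. Three problems get in the way:

  • relu has zero gradient whenever the intervals are separated.
  • min and max pass gradient to only one of their arguments.
  • All three are non-differentiable at their kinks.

Replacing them with smooth approximations fixes this: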

  • Softplus smooths the final relu clamp: softplus(x) = log(1 + exp(βx)) / β
  • LogSumExp smooths the inner min/max: soft_max(a, b) = log(exp(βa) + exp(βb)) / β, and soft_min(a, b) = -soft_max(-a, -b)

The soft min/max also solves a subtler problem: when one interval is entirely contained inside the other, the hard overlap equals the inner interval's length — a constant with respect to position, so its gradient is zero. The optimizer receives no signal about which direction to move the shapes apart. The soft versions remain sensitive to position even in this fully-contained case, providing a gradient that pushes the shapes toward separation.

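A sharpness parameter β controls the trade-off. Larger β gives a tighter approximation of the hard overlap, so the signal is strong only when shapes actually collide. Smaller β creates a longer-range repulsive force that pushes shapes apart even before they touch. The code below uses separate sharpness values for the two approximations: beta_minmax for the soft min/max and beta_clamp for the softplus clamp.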

import numpy as np
from matplotlib import pyplot as plt
from vizopt.base import OptimConfig
# Demo intervals [left, right]: a unit interval and a shorter one starting at
# the same point, so that shifting the second one sweeps through the
# fully-contained case.
interval_1, interval_2 = np.array([0, 1]), np.array([0, 0.5])

def calculate_hard_overlap(interval_1, interval_2):
    """Return the exact (non-smooth) overlap length of two 1D intervals.

    Each interval is ``[left, right]`` stored along the last axis, so the
    inputs may also be stacked arrays of shape ``(..., 2)`` whose leading
    dimensions broadcast against each other.

    Args:
        interval_1: First interval(s); array-like with a last axis of size 2.
        interval_2: Second interval(s); array-like with a last axis of size 2.

    Returns:
        ``max(0, min(r1, r2) - max(l1, l2))``. This is a NumPy scalar for
        single intervals, otherwise an array with the broadcast leading shape.
    """
    interval_1 = np.asarray(interval_1)
    interval_2 = np.asarray(interval_2)
    # Element-wise min/max (rather than Python's min/max) so stacked
    # intervals are handled in one vectorised call.
    overlap_before_relu = np.minimum(interval_1[..., 1], interval_2[..., 1]) - np.maximum(
        interval_1[..., 0], interval_2[..., 0]
    )
    return np.maximum(0, overlap_before_relu)

def calculate_softplus_overlap(interval_1, interval_2, beta=1.0):
    """Return a smooth, strictly positive approximation of the interval overlap.

    The hard ``relu`` clamp is replaced by a softplus of sharpness ``beta``,
    ``log(1 + exp(beta * x)) / beta``. It is evaluated with ``np.logaddexp``,
    so large overlaps (or large ``beta``) do not overflow ``exp``.

    Each interval is ``[left, right]`` stored along the last axis, so the
    inputs may also be stacked arrays of shape ``(..., 2)``.

    Args:
        interval_1: First interval(s); array-like with a last axis of size 2.
        interval_2: Second interval(s); array-like with a last axis of size 2.
        beta: Positive sharpness. Larger values approach the hard overlap.
            The default 1.0 reproduces plain ``log(1 + exp(x))``.

    Returns:
        The softplus overlap, as a scalar or an array with the broadcast
        leading shape.

    Raises:
        ValueError: If ``beta`` is not positive.
    """
    if not beta > 0:
        raise ValueError(f"beta must be positive, got {beta!r}")
    interval_1 = np.asarray(interval_1)
    interval_2 = np.asarray(interval_2)
    overlap_before_relu = np.minimum(interval_1[..., 1], interval_2[..., 1]) - np.maximum(
        interval_1[..., 0], interval_2[..., 0]
    )
    # logaddexp(0, z) == log(exp(0) + exp(z)) == log(1 + exp(z)), overflow-free.
    return np.logaddexp(0.0, beta * overlap_before_relu) / beta

beta_clamp = 3.0  # sharpness of the softplus zero-clamp

# Figure 1: the two intervals drawn as segments at heights 0 (black) and 1 (red).
_, ax = plt.subplots(figsize=(4, 3))
for interval, height, colour in ((interval_1, 0, "k"), (interval_2, 1, "r")):
    ax.plot(interval, [height, height], colour, marker=".")

# Figure 2: slide interval_2 by dx and compare the hard overlap with its
# softplus-clamped counterpart (min/max are still exact here).
dx_list = np.linspace(-1, 2, 100)
shifted = interval_2 + dx_list[:, None]  # shape (100, 2)
right_edge = np.minimum(interval_1[1], shifted[:, 1])
left_edge = np.maximum(interval_1[0], shifted[:, 0])
raw_overlap = right_edge - left_edge

hard_overlap_list = np.maximum(0, raw_overlap)
soft_overlap_list = np.log(1 + np.exp(beta_clamp * raw_overlap)) / beta_clamp

_, ax = plt.subplots(figsize=(4, 3))
for values, label in ((hard_overlap_list, "hard overlap"), (soft_overlap_list, "soft overlap")):
    ax.plot(dx_list, values, label=label)
ax.legend()

output

output

dx_list = np.linspace(-1, 2, 100)
shifted = interval_2 + dx_list[:, None]  # shape (100, 2)

beta_minmax = 5.0  # sharpness of the soft min/max boundary detection


def soft_max(a, b):
    """Smooth maximum: log(exp(beta*a) + exp(beta*b)) / beta, evaluated overflow-free."""
    return np.logaddexp(beta_minmax * a, beta_minmax * b) / beta_minmax


def soft_min(a, b):
    """Smooth minimum, equal to -soft_max(-a, -b)."""
    return -np.logaddexp(-beta_minmax * a, -beta_minmax * b) / beta_minmax


# Reference curve: the true hard overlap (exact min/max, then relu). While
# interval_2 sits inside interval_1 this is flat, i.e. it has zero gradient.
hard_overlap_list = np.maximum(
    0, np.minimum(interval_1[1], shifted[:, 1]) - np.maximum(interval_1[0], shifted[:, 0])
)

# Fully smooth overlap: soft min/max for the edges, softplus for the clamp.
raw_overlap = soft_min(interval_1[1], shifted[:, 1]) - soft_max(interval_1[0], shifted[:, 0])

beta_clamp = 3.0  # sharpness of the softplus zero-clamp
# logaddexp(0, z) == log(1 + exp(z)) without overflow for large z.
soft_overlap_list = np.logaddexp(0, beta_clamp * raw_overlap) / beta_clamp

_, ax = plt.subplots(figsize=(4, 3))
ax.plot(dx_list, hard_overlap_list, label="hard overlap")
ax.plot(dx_list, soft_overlap_list, label="soft overlap")
ax.legend()

output

Example: simple "bubble plot"

from numpy.random import default_rng
import pandas as pd

rng = default_rng()  # unseeded; pass a seed, e.g. default_rng(0), for reproducible figures
n_bubbles = 100
xy_init = 5*rng.uniform(size=(n_bubbles, 2))
bubble_df = pd.DataFrame(xy_init, columns=["x", "y"])
bubble_df["radius"] = rng.lognormal(size=n_bubbles) / 5

_, ax = plt.subplots()
for row in bubble_df.itertuples(index=False):
    ax.add_patch(plt.Circle((row.x, row.y), row.radius, color="k", alpha=0.2))
# Frame the actual data: centres padded by the largest radius. The previous
# symmetric limits around the origin left most of the plot empty (all
# centres lie in [0, 5)) and could clip large bubbles near the edges.
margin = bubble_df["radius"].max()
ax.set_xlim(bubble_df["x"].min() - margin, bubble_df["x"].max() + margin)
ax.set_ylim(bubble_df["y"].min() - margin, bubble_df["y"].max() + margin)
# set_aspect keeps the limits above; axis("equal") would re-fit them to the data limits.
ax.set_aspect("equal")

output

from jax import numpy as jnp
import numpy as np
from vizopt.base import ObjectiveTerm, OptimizationProblemTemplate


def _soft_circle_overlap(optim_vars, input_params, *, beta=10.0):
    """Sum of soft pairwise circle overlaps over every unordered pair.

    For each pair ``i < j`` the raw overlap is ``r_i + r_j - |p_i - p_j|``.
    It is passed through a softplus of sharpness ``beta``, so circles that
    are close but not touching still receive a small repulsive gradient.

    Args:
        optim_vars: Mapping with ``"positions"`` of shape ``(n, 2)``.
        input_params: Mapping with ``"radii"`` of shape ``(n,)``.
        beta: Softplus sharpness; larger values approach the hard overlap.

    Returns:
        Scalar JAX array (0 when there are fewer than two circles).
    """
    pos = optim_vars["positions"]               # (n, 2)
    radii = input_params["radii"]               # (n,)
    # Index only the i < j pairs instead of building the full n x n matrix.
    rows, cols = np.triu_indices(radii.shape[0], 1)
    diff = pos[rows] - pos[cols]                 # (n_pairs, 2)
    # The epsilon keeps sqrt differentiable when two centres coincide.
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-8)
    raw_overlap = radii[rows] + radii[cols] - dist
    # logaddexp(0, z) == log(1 + exp(z)). Unlike the naive form it does not
    # overflow for large z (float32 exp overflows near 88), which would make
    # the loss inf and its gradient NaN.
    soft_overlap = jnp.logaddexp(0.0, beta * raw_overlap) / beta
    return jnp.sum(soft_overlap)


def _hard_circle_overlap(optim_vars, input_params):
    """Sum of hard (relu-clamped) pairwise circle overlaps over every unordered pair.

    For each pair ``i < j`` the overlap is ``max(0, r_i + r_j - |p_i - p_j|)``.
    Its gradient is zero for separated pairs.

    Args:
        optim_vars: Mapping with ``"positions"`` of shape ``(n, 2)``.
        input_params: Mapping with ``"radii"`` of shape ``(n,)``.

    Returns:
        Scalar JAX array (0 when there are fewer than two circles).
    """
    pos = optim_vars["positions"]               # (n, 2)
    radii = input_params["radii"]               # (n,)
    # Index only the i < j pairs instead of building the full n x n matrix.
    rows, cols = np.triu_indices(radii.shape[0], 1)
    diff = pos[rows] - pos[cols]                 # (n_pairs, 2)
    # The epsilon keeps sqrt differentiable when two centres coincide.
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-8)
    raw_overlap = radii[rows] + radii[cols] - dist
    # jnp.maximum instead of jnp.clip(..., a_min=0): the a_min keyword is
    # deprecated in recent JAX releases.
    hard_overlap = jnp.maximum(raw_overlap, 0.0)
    return jnp.sum(hard_overlap)


def _packing_loss(optim_vars, input_params):
    """Sum of squared distances from each bubble centre to the mean centre.

    Minimising it draws the bubbles together. ``input_params`` is unused; it
    is accepted so the signature matches the other objective terms.
    """
    centres = optim_vars["positions"]
    offsets = centres - jnp.mean(centres, axis=0)
    return jnp.sum(offsets ** 2)

def _total_size_loss(optim_vars, input_params):
    """Width plus height of the axis-aligned box enclosing every bubble."""
    centres = optim_vars["positions"]
    r = input_params["radii"].reshape(-1, 1)
    upper = jnp.max(centres + r, axis=0)
    lower = jnp.min(centres - r, axis=0)
    extent = upper - lower  # (width, height)
    return jnp.sum(extent)


def _initialize(input_params, seed):
    """Start from a copy of the supplied initial positions (``seed`` is unused)."""
    initial = input_params["initial_positions"]
    return dict(positions=initial.copy())


def _plot_bubbles(optim_vars, input_params):
    """Draw the initial and optimised bubble layouts in two side-by-side panels.

    Each panel shows every bubble as a translucent circle. Its limits are the
    bubble centres padded by the largest radius.
    """
    _, (ax_init, ax_opt) = plt.subplots(1, 2, figsize=(8, 4))
    radii = input_params["radii"]
    pad = radii.max()
    panels = (
        (ax_init, input_params["initial_positions"], "Initial"),
        (ax_opt, np.array(optim_vars["positions"]), "Optimized"),
    )
    for panel, centres, title in panels:
        for (cx, cy), radius in zip(centres, radii):
            panel.add_patch(plt.Circle((cx, cy), radius, color="steelblue", alpha=0.4, linewidth=0))
        panel.set_xlim(centres[:, 0].min() - pad, centres[:, 0].max() + pad)
        panel.set_ylim(centres[:, 1].min() - pad, centres[:, 1].max() + pad)
        panel.set_aspect("equal")
        panel.set_title(title)
    plt.tight_layout()


def _svg_bubbles(snapshots, input_params, size):
    """Build animated-SVG circle specs, one per bubble, across all snapshots.

    Every frame shares one square viewport of ``size`` units. The viewport
    encloses all snapshot positions padded by the largest radius, so the
    scale stays fixed over the animation. The y axis is flipped because SVG
    y grows downwards.

    Returns:
        A list of dicts, one per bubble. ``"cx"`` and ``"cy"`` hold one
        formatted coordinate string per snapshot.
    """
    radii = input_params["radii"]
    stacked = np.stack([state["positions"] for _, state in snapshots])  # (frames, n, 2)
    pad = radii.max()
    lo_x = stacked[:, :, 0].min() - pad
    hi_x = stacked[:, :, 0].max() + pad
    lo_y = stacked[:, :, 1].min() - pad
    hi_y = stacked[:, :, 1].max() + pad
    span = max(hi_x - lo_x, hi_y - lo_y)

    def fmt(value):
        return f"{value:.2f}"

    return [
        {
            "tag": "circle",
            "r": fmt(radius / span * size),
            "fill": "steelblue",
            "fill-opacity": "0.4",
            "cx": [fmt((state["positions"][idx, 0] - lo_x) / span * size) for _, state in snapshots],
            "cy": [fmt((1 - (state["positions"][idx, 1] - lo_y) / span) * size) for _, state in snapshots],
        }
        for idx, radius in enumerate(radii)
    ]


def _make_bubble_packing_template(collision_fn):
    """Build a bubble-packing template around the given collision loss.

    Terms: collision (weight 10.0), centroid packing (weight 0.0, i.e.
    disabled) and bounding-box size (weight 1.0).
    """
    return OptimizationProblemTemplate(
        terms=[
            ObjectiveTerm(name="collision", compute=collision_fn, multiplier=10.0),
            ObjectiveTerm(name="packing", compute=_packing_loss, multiplier=0.0),
            ObjectiveTerm(name="total_size", compute=_total_size_loss, multiplier=1.0),
        ],
        initialize=_initialize,
        plot_configuration=_plot_bubbles,
        svg_configuration=_svg_bubbles,
    )


# Two otherwise identical problems that differ only in the collision loss.
bubble_packing_soft_template = _make_bubble_packing_template(_soft_circle_overlap)
bubble_packing_hard_template = _make_bubble_packing_template(_hard_circle_overlap)
input_parameters = dict(
    radii=bubble_df["radius"].values,
    initial_positions=bubble_df[["x", "y"]].values,
)

# Optimise the same starting layout once with each collision loss.
problem_soft = bubble_packing_soft_template.instantiate(input_parameters)
optim_cfg = OptimConfig(n_iters=8000, learning_rate=0.01)
optim_vars_soft, history_soft = problem_soft.optimize(optim_cfg, track_every=100)

problem_hard = bubble_packing_hard_template.instantiate(input_parameters)
optim_vars_hard, history_hard = problem_hard.optimize(optim_cfg, track_every=100)

# Plot each run's tracked history (indexed by "iteration") on a shared y axis.
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)

df_soft = pd.DataFrame(history_soft).set_index("iteration")
df_hard = pd.DataFrame(history_hard).set_index("iteration")

for panel, frame, title in zip(axes, (df_soft, df_hard), ("Soft overlap", "Hard overlap")):
    frame.plot(ax=panel, marker=".")
    panel.set_title(title)
    if panel is axes[0]:
        panel.set_ylabel("Loss value")

plt.tight_layout()

output

def plot_bubbles_comparison(optim_vars_soft, optim_vars_hard, input_params):
    """Show the initial, soft-optimised and hard-optimised layouts side by side.

    Each panel draws every bubble as a translucent circle. Its limits are the
    bubble centres padded by the largest radius.
    """
    _, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = {
        "Initial": input_params["initial_positions"],
        "Soft overlap": np.array(optim_vars_soft["positions"]),
        "Hard overlap": np.array(optim_vars_hard["positions"]),
    }
    radii = input_params["radii"]
    pad = radii.max()
    for panel, (title, centres) in zip(axes, panels.items()):
        for (cx, cy), radius in zip(centres, radii):
            panel.add_patch(plt.Circle((cx, cy), radius, color="steelblue", alpha=0.4, linewidth=0))
        panel.set_xlim(centres[:, 0].min() - pad, centres[:, 0].max() + pad)
        panel.set_ylim(centres[:, 1].min() - pad, centres[:, 1].max() + pad)
        panel.set_aspect("equal")
        panel.set_title(title)
    plt.tight_layout()


plot_bubbles_comparison(
    optim_vars_soft=optim_vars_soft,
    optim_vars_hard=optim_vars_hard,
    input_params=input_parameters,
)

output

from vizopt.animation import SnapshotCallback, snapshots_to_animated_svg
from IPython.display import SVG

# Re-run the soft problem for 2000 iterations with a snapshot callback
# (SnapshotCallback(every=50)), then render the snapshots as an animated SVG.
optim_cfg = OptimConfig(n_iters=2000, learning_rate=0.01)
snapshot_cb_soft = SnapshotCallback(every=50)
problem_soft_anim = bubble_packing_soft_template.instantiate(input_parameters)
problem_soft_anim.optimize(optim_cfg, callback=snapshot_cb_soft)

svg_soft = snapshots_to_animated_svg(
    problem_soft_anim,
    snapshot_cb_soft.snapshots,
    fps=12,
)
SVG(data=svg_soft)

output

# Same animation for the hard-overlap problem. optim_cfg here is the
# 2000-iteration config rebound in the soft-animation cell.
problem_hard_anim = bubble_packing_hard_template.instantiate(input_parameters)
snapshot_cb_hard = SnapshotCallback(every=50)
problem_hard_anim.optimize(optim_cfg, callback=snapshot_cb_hard)

svg_hard = snapshots_to_animated_svg(
    problem_hard_anim,
    snapshot_cb_hard.snapshots,
    fps=12,
)
SVG(data=svg_hard)

output