Soft interval overlap for collision losses
In gradient-based layout optimization, collisions between shapes can be penalized by measuring how much their bounding intervals overlap. For two 1D intervals $[l_1, r_1]$ and $[l_2, r_2]$, the overlap is $\max\bigl(0,\ \min(r_1, r_2) - \max(l_1, l_2)\bigr)$, i.e. the raw difference clamped to zero.
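For the loss to drive gradient descent, it needs informative gradients everywhere. However, `min`, `max`, and `relu` are non-differentiable at their kinks and have zero gradient on their flat side: `relu` for negative inputs, and `min`/`max` with respect to whichever argument is not selected. Replacing them with smooth approximations fixes this: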
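- Softplus smooths the final `relu` clamp: $\mathrm{softplus}_\beta(x) = \frac{1}{\beta}\log\left(1 + e^{\beta x}\right)$ (the plain softplus $\log(1 + e^{x})$ is the case $\beta = 1$).
- LogSumExp smooths the inner `min`/`max`: $\mathrm{soft\_max}_\beta(a, b) = \frac{1}{\beta}\log\left(e^{\beta a} + e^{\beta b}\right)$ and $\mathrm{soft\_min}_\beta(a, b) = -\frac{1}{\beta}\log\left(e^{-\beta a} + e^{-\beta b}\right)$.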
The soft min/max also solves a subtler problem: when one interval is entirely contained inside the other, the hard overlap equals the inner interval's length — a constant with respect to position, so its gradient is zero. The optimizer receives no signal about which direction to move the shapes apart. The soft versions remain sensitive to position even in this fully-contained case, providing a gradient that pushes the shapes toward separation.
A sharpness parameter β controls the trade-off: larger β gives a tighter approximation of the hard overlap (stronger signal only when shapes actually collide), while smaller β creates a longer-range repulsive force that pushes shapes apart even before they touch.
# Two 1D intervals given as [left, right]; interval_2 lies inside interval_1.
interval_1, interval_2 = np.array([0, 1]), np.array([0, 0.5])
def calculate_hard_overlap(interval_1, interval_2):
    """Return the length of the intersection of two 1D intervals.

    Each interval is a ``(left, right)`` pair. Disjoint or merely touching
    intervals give 0.
    """
    shared_left = max(interval_1[0], interval_2[0])
    shared_right = min(interval_1[1], interval_2[1])
    return max(0, shared_right - shared_left)
def calculate_softplus_overlap(interval_1, interval_2, beta=1.0):
    """Return a smooth, softplus-clamped overlap of two 1D intervals.

    The raw overlap ``min(r1, r2) - max(l1, l2)`` is negative for disjoint
    intervals. Instead of clamping it with ``max(0, x)``, which has zero
    gradient for disjoint intervals, it is passed through a softplus with
    sharpness ``beta``: ``log(1 + exp(beta * x)) / beta``.

    Args:
        interval_1: ``(left, right)`` pair.
        interval_2: ``(left, right)`` pair.
        beta: Positive sharpness. The default ``1.0`` is the plain softplus
            ``log(1 + exp(x))``. Larger values approach the hard clamp
            ``max(0, x)``.

    Returns:
        The soft overlap as a NumPy float.
    """
    overlap_before_relu = min(interval_1[1], interval_2[1]) - max(interval_1[0], interval_2[0])
    # np.logaddexp(0, z) == log(exp(0) + exp(z)) == log(1 + exp(z)), evaluated
    # without overflowing exp(z) when the overlap (or beta) is large.
    return np.logaddexp(0.0, beta * overlap_before_relu) / beta
beta_clamp = 3.0  # sharpness of the softplus zero-clamp

# Show the two intervals; interval_2 is slid along the x-axis below.
_, ax = plt.subplots(figsize=(4, 3))
ax.plot(interval_1, [0, 0], "k", marker=".", label="interval_1")
ax.plot(interval_2, [1, 1], "r", marker=".", label="interval_2")
ax.legend()

# Shift interval_2 by each dx and compare the hard overlap with the
# softplus-clamped overlap (hard min/max, soft zero-clamp).
dx_list = np.linspace(-1, 2, 100)
shifted = interval_2 + dx_list[:, None]  # shape (100, 2): one shifted copy per dx
raw_overlap = np.minimum(interval_1[1], shifted[:, 1]) - np.maximum(interval_1[0], shifted[:, 0])
hard_overlap_list = np.maximum(0, raw_overlap)
# Stable softplus: logaddexp(0, z) == log(1 + exp(z)) without overflowing exp(z).
soft_overlap_list = np.logaddexp(0.0, beta_clamp * raw_overlap) / beta_clamp
_, ax = plt.subplots(figsize=(4, 3))
ax.plot(dx_list, hard_overlap_list, label="hard overlap")
ax.plot(dx_list, soft_overlap_list, label="soft overlap (softplus clamp)")
ax.set_xlabel("shift dx of interval_2")
ax.set_ylabel("overlap")
ax.legend()


dx_list = np.linspace(-1, 2, 100)
shifted = interval_2 + dx_list[:, None]  # shape (100, 2): one shifted copy per dx
beta_minmax = 5.0  # sharpness of the soft min/max boundary detection


def soft_max(a, b):
    """Smooth maximum (LogSumExp): log(exp(beta*a) + exp(beta*b)) / beta."""
    # logaddexp computes log(exp(x) + exp(y)) without overflow.
    return np.logaddexp(beta_minmax * a, beta_minmax * b) / beta_minmax


def soft_min(a, b):
    """Smooth minimum: -log(exp(-beta*a) + exp(-beta*b)) / beta."""
    return -np.logaddexp(-beta_minmax * a, -beta_minmax * b) / beta_minmax


# True hard overlap, for reference: it is flat (zero gradient) while
# interval_2 sits entirely inside interval_1.
hard_raw_overlap = np.minimum(interval_1[1], shifted[:, 1]) - np.maximum(interval_1[0], shifted[:, 0])
hard_overlap_list = np.maximum(0, hard_raw_overlap)
# Overlap built from the soft min/max: still position-dependent in the
# fully-contained case.
raw_overlap = soft_min(interval_1[1], shifted[:, 1]) - soft_max(interval_1[0], shifted[:, 0])
soft_minmax_hard_clamp_list = np.maximum(0, raw_overlap)
beta_clamp = 3.0  # sharpness of the softplus zero-clamp
# Stable softplus: logaddexp(0, z) == log(1 + exp(z)).
soft_overlap_list = np.logaddexp(0.0, beta_clamp * raw_overlap) / beta_clamp
_, ax = plt.subplots(figsize=(4, 3))
ax.plot(dx_list, hard_overlap_list, label="hard overlap")
ax.plot(dx_list, soft_minmax_hard_clamp_list, "--", label="soft min/max, hard clamp")
ax.plot(dx_list, soft_overlap_list, label="soft min/max + softplus")
ax.set_xlabel("shift dx of interval_2")
ax.set_ylabel("overlap")
ax.legend()

Example: simple "bubble plot"
n_bubbles = 100
# Centres scaled by 5 and log-normal radii scaled by 1/5.
# Presumably rng is a NumPy Generator defined in an earlier cell (so centres
# are uniform in [0, 5) x [0, 5)); confirm.
xy_init = 5 * rng.uniform(size=(n_bubbles, 2))
bubble_df = pd.DataFrame(xy_init, columns=["x", "y"]).assign(
    radius=rng.lognormal(size=n_bubbles) / 5,
)
bubble_df.head()
# Draw every bubble as a translucent circle.
_, ax = plt.subplots()
for row in bubble_df.itertuples(index=False):
    ax.add_patch(plt.Circle((row.x, row.y), row.radius, color="k", alpha=0.2))
# Set explicit limits covering each circle's full extent (centre +/- radius),
# so bubbles near the edges are not clipped and no empty quadrant is shown.
ax.set_xlim((bubble_df["x"] - bubble_df["radius"]).min(), (bubble_df["x"] + bubble_df["radius"]).max())
ax.set_ylim((bubble_df["y"] - bubble_df["radius"]).min(), (bubble_df["y"] + bubble_df["radius"]).max())
# Equal aspect keeps circles round; the default adjustable="box" keeps the limits above.
ax.set_aspect("equal")

from jax import numpy as jnp
import numpy as np
from vizopt.base import ObjectiveTerm, OptimizationProblemTemplate
def _soft_circle_overlap(optim_vars, input_params):
"""Sum of soft pairwise circle overlaps (upper triangle only)."""
pos = optim_vars["positions"] # (n, 2)
radii = input_params["radii"] # (n,)
beta = 10.0
diff = pos[:, None, :] - pos[None, :, :] # (n, n, 2)
dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-8) # (n, n)
sum_radii = radii[:, None] + radii[None, :] # (n, n)
raw_overlap = sum_radii - dist
soft_overlap = jnp.log(1 + jnp.exp(beta * raw_overlap)) / beta
n = radii.shape[0]
rows, cols = np.triu_indices(n, 1)
return jnp.sum(soft_overlap[rows, cols])
def _hard_circle_overlap(optim_vars, input_params):
"""Sum of hard pairwise circle overlaps (upper triangle only)."""
pos = optim_vars["positions"] # (n, 2)
radii = input_params["radii"] # (n,)
diff = pos[:, None, :] - pos[None, :, :] # (n, n, 2)
dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-8) # (n, n)
sum_radii = radii[:, None] + radii[None, :] # (n, n)
raw_overlap = sum_radii - dist
hard_overlap = jnp.clip(raw_overlap, a_min=0)
n = radii.shape[0]
rows, cols = np.triu_indices(n, 1)
return jnp.sum(hard_overlap[rows, cols])
def _packing_loss(optim_vars, input_params):
"""Pull bubbles toward their centroid."""
pos = optim_vars["positions"]
centroid = jnp.mean(pos, axis=0)
return jnp.sum((pos - centroid) ** 2)
def _total_size_loss(optim_vars, input_params):
"""Minimize total drawing size width and height."""
pos = optim_vars["positions"]
radii = input_params["radii"]
max_coord = jnp.max(pos + radii.reshape(-1, 1), axis=0)
min_coord = jnp.min(pos - radii.reshape(-1, 1), axis=0)
return jnp.sum(max_coord - min_coord)
def _initialize(input_params, seed):
return {"positions": input_params["initial_positions"].copy()}
def _plot_bubbles(optim_vars, input_params):
    """Plot the starting layout next to the optimized one.

    Args:
        optim_vars: dict with the optimized ``"positions"`` (one (x, y) row per bubble).
        input_params: dict with ``"radii"`` and ``"initial_positions"``.
    """
    _, axes = plt.subplots(1, 2, figsize=(8, 4))
    layouts = (
        ("Initial", input_params["initial_positions"]),
        ("Optimized", np.array(optim_vars["positions"])),
    )
    radii = input_params["radii"]
    # Pad the limits by the largest radius so edge bubbles are fully visible.
    pad = radii.max()
    for ax, (title, xy) in zip(axes, layouts):
        for (cx, cy), radius in zip(xy, radii):
            ax.add_patch(plt.Circle((cx, cy), radius, color="steelblue", alpha=0.4, linewidth=0))
        ax.set_xlim(xy[:, 0].min() - pad, xy[:, 0].max() + pad)
        ax.set_ylim(xy[:, 1].min() - pad, xy[:, 1].max() + pad)
        ax.set_aspect("equal")
        ax.set_title(title)
    plt.tight_layout()
def _svg_bubbles(snapshots, input_params, size):
radii = input_params["radii"]
all_positions = np.stack([s["positions"] for _, s in snapshots]) # (frames, n, 2)
margin = radii.max()
x_min = all_positions[:, :, 0].min() - margin
x_max = all_positions[:, :, 0].max() + margin
y_min = all_positions[:, :, 1].min() - margin
y_max = all_positions[:, :, 1].max() + margin
span = max(x_max - x_min, y_max - y_min)
def to_x(x): return (x - x_min) / span * size
def to_y(y): return (1 - (y - y_min) / span) * size
def to_r(r): return r / span * size
elements = []
for i, r in enumerate(radii):
elements.append({
"tag": "circle",
"r": f"{to_r(r):.2f}",
"fill": "steelblue",
"fill-opacity": "0.4",
"cx": [f"{to_x(s['positions'][i, 0]):.2f}" for _, s in snapshots],
"cy": [f"{to_y(s['positions'][i, 1]):.2f}" for _, s in snapshots],
})
return elements
def _make_bubble_packing_template(collision_compute):
    """Return a bubble-packing template that uses ``collision_compute`` as its collision term."""
    return OptimizationProblemTemplate(
        terms=[
            ObjectiveTerm(name="collision", compute=collision_compute, multiplier=10.0),
            # NOTE(review): a 0.0 multiplier presumably disables the centroid
            # packing term entirely; confirm how vizopt combines terms.
            ObjectiveTerm(name="packing", compute=_packing_loss, multiplier=0.0),
            ObjectiveTerm(name="total_size", compute=_total_size_loss, multiplier=1.0),
        ],
        initialize=_initialize,
        plot_configuration=_plot_bubbles,
        svg_configuration=_svg_bubbles,
    )


# Two otherwise-identical templates that differ only in the collision penalty.
bubble_packing_soft_template = _make_bubble_packing_template(_soft_circle_overlap)
bubble_packing_hard_template = _make_bubble_packing_template(_hard_circle_overlap)
# Problem data taken from bubble_df: per-bubble radii and starting centres.
input_parameters = dict(
    radii=bubble_df["radius"].to_numpy(),
    initial_positions=bubble_df[["x", "y"]].to_numpy(),
)
# Soft-overlap run; the same config object is reused for the hard run below.
problem_soft = bubble_packing_soft_template.instantiate(input_parameters)
optim_cfg = OptimConfig(n_iters=8000, learning_rate=0.01)
optim_vars_soft, history_soft = problem_soft.optimize(optim_cfg, track_every=100)
# Hard-overlap run from the same starting data.
problem_hard = bubble_packing_hard_template.instantiate(input_parameters)
optim_vars_hard, history_hard = problem_hard.optimize(optim_cfg, track_every=100)
# Loss histories of both runs, side by side on a shared y scale.
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
df_soft = pd.DataFrame(history_soft).set_index("iteration")
df_hard = pd.DataFrame(history_hard).set_index("iteration")
for panel_ax, history_df, panel_title in (
    (axes[0], df_soft, "Soft overlap"),
    (axes[1], df_hard, "Hard overlap"),
):
    history_df.plot(ax=panel_ax, marker=".")
    panel_ax.set_title(panel_title)
axes[0].set_ylabel("Loss value")
plt.tight_layout()

def plot_bubbles_comparison(optim_vars_soft, optim_vars_hard, input_params):
    """Show the initial layout and both optimized layouts in three panels.

    Args:
        optim_vars_soft: optimized variables (``"positions"``) from the soft run.
        optim_vars_hard: optimized variables (``"positions"``) from the hard run.
        input_params: dict with ``"radii"`` and ``"initial_positions"``.
    """
    _, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        ("Initial", input_params["initial_positions"]),
        ("Soft overlap", np.array(optim_vars_soft["positions"])),
        ("Hard overlap", np.array(optim_vars_hard["positions"])),
    )
    radii = input_params["radii"]
    # Pad the limits by the largest radius so edge bubbles are fully visible.
    pad = radii.max()
    for ax, (title, xy) in zip(axes, panels):
        for (cx, cy), radius in zip(xy, radii):
            ax.add_patch(plt.Circle((cx, cy), radius, color="steelblue", alpha=0.4, linewidth=0))
        ax.set_xlim(xy[:, 0].min() - pad, xy[:, 0].max() + pad)
        ax.set_ylim(xy[:, 1].min() - pad, xy[:, 1].max() + pad)
        ax.set_aspect("equal")
        ax.set_title(title)
    plt.tight_layout()


plot_bubbles_comparison(optim_vars_soft, optim_vars_hard, input_parameters)

# Re-run both problems while recording snapshots, and render each run as an animated SVG.
from vizopt.animation import SnapshotCallback, snapshots_to_animated_svg
from IPython.display import SVG
# Soft-overlap run; every=50 presumably means one snapshot per 50 iterations (confirm in vizopt).
snapshot_cb_soft = SnapshotCallback(every=50)
problem_soft_anim = bubble_packing_soft_template.instantiate(input_parameters)
# NOTE: rebinds optim_cfg (previously n_iters=8000) to a shorter 2000-iteration run for the animations.
optim_cfg = OptimConfig(n_iters=2000, learning_rate=0.01)
# The return value of optimize is discarded here; only the recorded snapshots are used.
problem_soft_anim.optimize(optim_cfg, callback=snapshot_cb_soft)
svg_soft = snapshots_to_animated_svg(problem_soft_anim, snapshot_cb_soft.snapshots, fps=12)
# Bare expression: in a notebook this displays the SVG as the cell's output.
SVG(data=svg_soft)
# Hard-overlap run with the same 2000-iteration config.
snapshot_cb_hard = SnapshotCallback(every=50)
problem_hard_anim = bubble_packing_hard_template.instantiate(input_parameters)
problem_hard_anim.optimize(optim_cfg, callback=snapshot_cb_hard)
svg_hard = snapshots_to_animated_svg(problem_hard_anim, snapshot_cb_hard.snapshots, fps=12)
SVG(data=svg_hard)