Basic projection examples
The following cell defines a hypersurface starting from a graph (f_paraboloid) and then defines an implicit function (f_implicit_paraboloid) that is zero exactly on the surface; its magnitude is the vertical gap between a point and the surface (not the Euclidean distance). The solver is then used to project points onto the zero set of the implicit function. The graph is defined as

$$z = f(x, y) = x^2 + y^2$$
and the implicit function is defined as

$$f_{\text{implicit}}(x, y, z) = f(x, y) - z = x^2 + y^2 - z.$$
When f_implicit(x, y, z) = 0, the point (x, y, z) lies on the surface of the paraboloid. The solver uses the f_implicit formulation to project points onto the surface.
# Grid resolution and figure size shared by every plot_3d_projection call.
NGRID, PLOT_W, PLOT_H = 50, 400, 400

# Set JAX config before importing JAX: these environment variables are set
# here, ahead of any `import jax` further down.
import os

os.environ.update(
    JAX_PLATFORMS="cpu",  # use "cuda" if you have a GPU
    JAX_ENABLE_X64="1",  # enable 64-bit (float64/int64) arrays in JAX
)
# Numerical libraries and the jnlr projection solver (imported after the JAX environment is configured above)
import numpy as np
import jax
from jnlr.reconcile import make_solver_alm_optax as make_solver
import jax.numpy as jnp
# Draw n_samples points uniformly from the cube [-1, 1)^3 (uniform, not
# Gaussian) and reproject them onto the paraboloid defined below. A fixed seed
# keeps the printed residuals reproducible from run to run.
n_samples = 100
X = np.random.default_rng(0).uniform(-1.0, 1.0, size=(n_samples, 3))
# Graph of the paraboloid: maps a planar point (x, y) to its height on the
# surface. The implicit (m=1) form is built from it in the next definition.
def f_paraboloid(v):
    """Return the paraboloid height ``x**2 + y**2`` for a 2-element point ``v``."""
    px, py = v
    return px**2 + py**2
def f_implicit_paraboloid(v):
    """Implicit form of the paraboloid: ``f_paraboloid(x, y) - z`` for ``v = (x, y, z)``.

    The result is zero exactly when ``v`` lies on the surface z = x**2 + y**2,
    positive when the point is below the surface and negative when above it.
    """
    surface_height = f_paraboloid(v[:2])
    return surface_height - v[2]
# Project every sample onto the zero set of the implicit paraboloid, then
# compare the mean absolute residual before and after projection.
solver = make_solver(f_implicit_paraboloid, n_iterations=30)
X_proj = solver(X)

batched_residual = jax.vmap(f_implicit_paraboloid)
mean_abs_before = jnp.mean(jnp.abs(batched_residual(X)))
mean_abs_after = jnp.mean(jnp.abs(batched_residual(X_proj)))
print("mean abs f before projection: {:0.2e}".format(mean_abs_before))
print("mean abs f after projection: {:0.2e}".format(mean_abs_after))
mean abs f before projection: 9.34e-01
mean abs f after projection: 8.36e-10
The jnlr.utils.manifolds module contains other example surfaces, written as graph functions, that can be used to define hypersurfaces. Additionally, the jnlr.utils.function_utils module provides a utility function, f_impl, that converts a graph function z = f(x, y) into the corresponding implicit function. This is particularly useful for standard benchmark functions such as Ackley and Rastrigin.
import jnlr.utils.manifolds as mfs
from jnlr.utils.function_utils import f_impl
def _project_and_report(label, graph_fn, points, n_iterations=10):
    """Project ``points`` onto the graph of ``graph_fn`` and print the residuals.

    ``graph_fn`` is turned into an implicit function with ``f_impl``. The mean
    absolute implicit residual is printed before and after projection, with
    each line prefixed by ``label``.

    Returns:
        ``(solver, projected_points)``, so the caller can keep both bindings.
    """
    # Build the implicit function and its batched version once instead of
    # calling f_impl three times per surface.
    f_implicit = f_impl(graph_fn)
    batched_residual = jax.vmap(f_implicit)
    proj_solver = make_solver(f_implicit, n_iterations=n_iterations)
    projected = proj_solver(points)
    print("{}, mean abs f before projection: {:0.2e}".format(
        label, jnp.mean(jnp.abs(batched_residual(points)))))
    print("{}, mean abs f after projection: {:0.2e}".format(
        label, jnp.mean(jnp.abs(batched_residual(projected)))))
    return proj_solver, projected


solver, X_proj = _project_and_report("Paraboloid", mfs.f_paraboloid, X)
# The label must name the surface actually used: f_rastrigin, not Ackley.
solver, X_proj = _project_and_report("Rastrigin", mfs.f_rastrigin, X)
Paraboloid, mean abs f before projection: 8.85e-01
Paraboloid, mean abs f after projection: 1.10e-09
Ackley, mean abs f before projection: 1.93e+01
Ackley, mean abs f after projection: 4.44e-10
from jnlr.utils.plot_utils import plot_3d_projection
# Keyword arguments shared by the surface plots below; only the surface and
# the number of solver iterations change from call to call.
_PLOT_KW = dict(
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
plot_3d_projection(X, f_paraboloid, n_iterations=4, **_PLOT_KW)
plot_3d_projection(X, mfs.f_abs, n_iterations=40, **_PLOT_KW)
plot_3d_projection(X, mfs.f_ackley, n_iterations=20, **_PLOT_KW)
import jax.numpy as jnp
def f_shubert(v):
    """
    Scaled 2-D Shubert function.

    Input: array of shape (2,) holding (x1, x2); use vmap externally for
    batching. Returns S(x1) * S(x2) / 100, where
    S(t) = sum over j = 1..5 of j * cos((j + 1) * t + j).
    """
    def _shubert_factor(t):
        # Accumulate the five cosine terms in order, starting from 0.0.
        acc = 0.0
        for k in range(1, 6):
            acc = acc + k * jnp.cos((k + 1) * t + k)
        return acc

    first, second = v
    return _shubert_factor(first) * _shubert_factor(second) / 100
# Keyword arguments shared by both plots; only the surface differs.
_PLOT_KW = dict(
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=20,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
plot_3d_projection(X, f_shubert, **_PLOT_KW)
# Rastrigin divided by 100 to shrink its vertical range.
plot_3d_projection(X, lambda z: mfs.f_rastrigin(z)/100, **_PLOT_KW)
# Draw n_samples points uniformly from the cube [-1, 1)^3 (uniform, not
# Gaussian) and reproject them onto the unit sphere defined below. A fixed
# seed keeps the plot reproducible from run to run.
n_samples = 100
X = np.random.default_rng(1).uniform(-1.0, 1.0, size=(n_samples, 3))
# Implicit unit sphere (m=1): zero on the surface, negative inside it and
# positive outside it.
def f_implicit_sphere(v):
    """Return ``x**2 + y**2 + z**2 - 1`` using the first three components of ``v``."""
    cx, cy, cz = v[0], v[1], v[2]
    return cx**2 + cy**2 + cz**2 -1
# The sphere is supplied only in implicit form (no graph function is passed).
plot_3d_projection(
    X,
    f_implicit=f_implicit_sphere,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=20,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
# Gaussian cloud (std 0.2 per axis) centred at (0.9, 0.9, 0); seeded so the
# KDE plot is reproducible.
X = np.random.default_rng(2).normal(loc=0.0, scale=0.2, size=(n_samples, 3)) + np.array([0.9, 0.9, 0.0])
# Project with a solver built for the paraboloid plotted below. The bare
# `solver` name from earlier cells is bound to a different surface, so reusing
# it here would project onto the wrong manifold.
paraboloid_solver = make_solver(f_implicit_paraboloid, n_iterations=30)
X_proj = paraboloid_solver(X)
fig = plot_3d_projection(
    X,
    f_paraboloid,
    show_kde=True,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=4,
    n_grid=NGRID,
    lo=-2 * np.ones(2),
    hi=+2 * np.ones(2),
    n_isolines=7,
    width=PLOT_W,
    height=PLOT_H,
)
fig