Skip to content

Basic projection examples

The following cell defines a hypersurface starting from a graph (f_paraboloid) and then defines an implicit function (f_implicit_paraboloid) whose value is zero exactly on the surface and nonzero off it (it is a signed residual, not a Euclidean distance to the surface). The solver is then used to project points onto the surface defined by the implicit function. The graph is defined as

\[f:\mathbb{R}^2\rightarrow\mathbb{R},\qquad z = f(x,y) = x^2 + y^2\]

and the implicit function is defined as

\[f_{\text{implicit}}:\mathbb{R}^3\rightarrow\mathbb{R},\qquad f_{\text{implicit}}(x,y,z) = f(x,y) - z = x^2 + y^2 - z\]

When f_implicit(x, y, z) = 0, the point (x, y, z) lies on the surface of the paraboloid. The solver uses the f_implicit formulation to project points onto the surface.

# JAX reads these environment variables when it is first imported, so they must be
# set before any `import jax` below.
import os

os.environ.update(
    {
        "JAX_PLATFORMS": "cpu",  # use "cuda" if you have a GPU
        "JAX_ENABLE_X64": "1",  # 64-bit floats, so residuals can be driven well below float32 eps
    }
)

# Grid resolution and figure size shared by every plot in this notebook.
NGRID = 50
PLOT_W = 400
PLOT_H = 400

# NumPy for sampling, JAX for vectorized residual evaluation, jnlr for the projection solver
import numpy as np
import jax
from jnlr.reconcile import make_solver_alm_optax as make_solver
import jax.numpy as jnp

# Draw samples uniformly from the cube [-1, 1)^3 (not Gaussian); they are projected
# onto the paraboloid below. The fixed seed makes the printed residuals reproducible.
np.random.seed(0)
n_samples = 100
X = np.random.uniform(-1.0, 1.0, size=(n_samples, 3))

# The paraboloid as a graph z = f(x, y), plus the matching implicit function. The
# implicit function has a single component (m = 1) that vanishes exactly on the surface.
def f_paraboloid(v):
    """Return x**2 + y**2 for a 2-element point ``v = (x, y)``."""
    x_coord, y_coord = v
    return x_coord**2 + y_coord**2


def f_implicit_paraboloid(v):
    """Return the signed residual f(x, y) - z for a point ``v = (x, y, z)``.

    The value is zero on the paraboloid, positive below it (z < f(x, y)) and
    negative above it.
    """
    xy, z = v[:2], v[2]
    return f_paraboloid(xy) - z

# Build a solver for the paraboloid and project the whole batch of samples onto it.
solver = make_solver(f_implicit_paraboloid, n_iterations=30)
X_proj = solver(X)

# Mean absolute residual |f_implicit| over the batch, before and after projection.
residual_fn = jax.vmap(f_implicit_paraboloid)
mean_abs_before = jnp.mean(jnp.abs(residual_fn(X)))
mean_abs_after = jnp.mean(jnp.abs(residual_fn(X_proj)))
print(f"mean abs f before projection: {mean_abs_before:0.2e}")
print(f"mean abs f after projection: {mean_abs_after:0.2e}")
mean abs f before projection: 9.34e-01
mean abs f after projection: 8.36e-10

The module jnlr.utils.manifolds contains other example surfaces, given as graph functions (e.g. f_paraboloid, f_rastrigin, f_ackley, f_abs), that can be used to define hypersurfaces. Additionally, the jnlr.utils.function_utils module provides a utility function f_impl that converts a graph function into an implicit function. This is particularly useful for standard benchmark functions like Ackley and Rastrigin.

import jnlr.utils.manifolds as mfs
from jnlr.utils.function_utils import f_impl


# Project the same samples onto two benchmark surfaces. f_impl turns each graph function
# into an implicit residual; report the mean |residual| before and after projection.
# Driving both cases from one table keeps each label tied to the surface actually used
# (the previous copy-pasted version printed "Ackley" while evaluating Rastrigin).
for surface_label, graph_fn in (
    ("Paraboloid", mfs.f_paraboloid),
    ("Rastrigin", mfs.f_rastrigin),
):
    implicit_fn = f_impl(graph_fn)
    residual_fn = jax.vmap(implicit_fn)
    solver = make_solver(implicit_fn, n_iterations=10)
    X_proj = solver(X)
    print(f"{surface_label}, mean abs f before projection: {jnp.mean(jnp.abs(residual_fn(X))):0.2e}")
    print(f"{surface_label}, mean abs f after projection: {jnp.mean(jnp.abs(residual_fn(X_proj))):0.2e}")
# As before, `solver` and `X_proj` are left bound to the Rastrigin case.
Paraboloid, mean abs f before projection: 8.85e-01
Paraboloid, mean abs f after projection: 1.10e-09
Ackley, mean abs f before projection: 1.93e+01
Ackley, mean abs f after projection: 4.44e-10
from jnlr.utils.plot_utils import plot_3d_projection
# Projection plots for three graph surfaces: the paraboloid, mfs.f_abs and the Ackley
# function. NOTE(review): plot_history=True presumably draws the solver's intermediate
# iterates; n_iterations differs per surface.
plot_3d_projection(
    X,
    f_paraboloid,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=4,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
plot_3d_projection(
    X,
    mfs.f_abs,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=40,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
plot_3d_projection(
    X,
    mfs.f_ackley,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=20,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
import jax.numpy as jnp

def f_shubert(v):
    """
    Shubert function, scaled by 1/100. Input: array of shape (2,)
    Use vmap externally for batching.

    Returns (S(x1) * S(x2)) / 100 with S(t) = sum_{j=1..5} j * cos((j + 1) * t + j).
    """
    first, second = v

    def _factor(t):
        # One 1-D Shubert sum, accumulated left to right starting from 0.0.
        return sum((j * jnp.cos((j + 1) * t + j) for j in range(1, 6)), 0.0)

    return _factor(first) * _factor(second) / 100

# Projection plots for the (scaled) Shubert surface and for Rastrigin scaled by 1/100.
plot_3d_projection(
    X,
    f_shubert,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=20,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
plot_3d_projection(
    X,
    lambda z: mfs.f_rastrigin(z)/100,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=20,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
# Draw a fresh batch of samples uniformly from the cube [-1, 1)^3 (not Gaussian);
# these are projected onto the unit sphere defined below.
n_samples = 100
X = np.random.uniform(-1.0, 1.0, size=(n_samples, 3))

# Implicit function for the unit sphere: a single component (m = 1) that is zero on the
# surface, negative inside the ball and positive outside it.

def f_implicit_sphere(v):
    """Return x**2 + y**2 + z**2 - 1 for a point ``v = (x, y, z)``."""
    x, y, z = v[0], v[1], v[2]
    return x**2 + y**2 + z**2 - 1

# The sphere is not a graph z = f(x, y), so it is passed through the f_implicit keyword
# instead of as a graph function.
plot_3d_projection(
    X,
    f_implicit=f_implicit_sphere,
    round_cutoff=None,
    solver_builder=make_solver,
    plot_history=True,
    n_iterations=20,
    n_grid=NGRID,
    width=PLOT_W,
    height=PLOT_H,
)
# Samples drawn from a normal distribution (std 0.2) centred at (0.9, 0.9, 0), i.e.
# clustered off to one side of the paraboloid's vertex.
X = np.random.normal(size=(n_samples, 3), scale=0.2) + np.array([0.9, 0.9, 0])[None, :]
# Project onto the paraboloid that is plotted below. Calling `solver` here would apply
# whatever solver was built last (the Rastrigin one), i.e. the wrong surface.
paraboloid_solver = make_solver(f_implicit_paraboloid, n_iterations=30)
X_proj = paraboloid_solver(X)
# NOTE(review): lo/hi presumably set the plotted (x, y) window to [-2, 2]^2 so the shifted
# cloud stays in view, and show_kde/n_isolines presumably add a sample-density overlay
# with 7 contour levels. Confirm against plot_3d_projection.
fig = plot_3d_projection(X, f_paraboloid, show_kde=True, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=4, n_grid=NGRID, lo=-2*np.ones(2), hi=+2*np.ones(2), n_isolines=7, width=PLOT_W, height=PLOT_H)
fig