
Should

jnlr.should

constant_sign_curvature

constant_sign_curvature(f, grad_f, hessian_f, vmapped_solver, z_hat: ndarray) -> jnp.ndarray

Check whether the curvature condition guarantees that projecting forecasts onto a hypersurface f(z) = 0 with constant-sign curvature reduces forecasting RMSE.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | | Constraint function. | required |
| grad_f | | Function returning the gradient vector of shape (n,). | required |
| hessian_f | | Function returning the Hessian matrix of shape (n, n). | required |
| vmapped_solver | | Function to project points onto the constraint surface. | required |
| z_hat | ndarray | Forecasted points of shape (batch_size, n). | required |
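The vmapped_solver argument is described only as a function that projects points onto the constraint surface, mapping (batch_size, n) to (batch_size, n). As a sketch of what such a callable could look like, assuming a scalar constraint whose gradient does not vanish near the solution, the code below applies Newton's method to the KKT conditions of minimizing \(\lVert z - \hat{z} \rVert^2 / 2\) subject to \(f(z) = 0\). The names project_onto_surface and make_vmapped_solver are hypothetical and not part of jnlr, and Newton's method is not guaranteed to reach the nearest point from every starting forecast.

```python
import jax
import jax.numpy as jnp


def project_onto_surface(f, z_hat, num_iters=20):
    """Hypothetical helper: Newton's method on the KKT system of
    min_z 0.5 * ||z - z_hat||^2  subject to  f(z) = 0  (scalar-valued f)."""
    n = z_hat.shape[0]

    def kkt_residual(x):
        z, lam = x[:n], x[n]
        # Stationarity: z - z_hat + lam * grad f(z) = 0; feasibility: f(z) = 0
        stationarity = z - z_hat + lam * jax.grad(f)(z)
        return jnp.concatenate([stationarity, jnp.atleast_1d(f(z))])

    # Start from the forecast itself with a zero Lagrange multiplier
    x = jnp.concatenate([z_hat, jnp.zeros(1, dtype=z_hat.dtype)])
    for _ in range(num_iters):
        jac = jax.jacfwd(kkt_residual)(x)
        x = x - jnp.linalg.solve(jac, kkt_residual(x))
    return x[:n]


def make_vmapped_solver(f, num_iters=20):
    """Hypothetical factory: batched projector (batch_size, n) -> (batch_size, n)."""
    return jax.vmap(lambda z: project_onto_surface(f, z, num_iters))


# Example: the unit sphere in R^3
sphere = lambda z: jnp.dot(z, z) - 1.0
solver = make_vmapped_solver(sphere)
z_hat = jnp.array([[2.0, 0.0, 0.0],
                   [0.0, 0.5, 0.0]])
z_tilde = solver(z_hat)  # approximately [[1, 0, 0], [0, 1, 0]]
```

For the unit sphere, both forecasts land on the radially nearest point of the surface, one from outside and one from inside.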

Returns:

| Type | Description |
|---|---|
| ndarray | Boolean array of shape (batch_size,) indicating whether the curvature condition is satisfied for each point in z_hat. |

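Notes

Theorem 1: For a hypersurface defined by f(z) = 0, forecasting RMSE is guaranteed to reduce if

\[\lambda_{\min}\big(H_{restricted}(\tilde{z})\big) \cdot f(\hat{z}) > 0,\]

where \(\tilde{z}\) is the projection of the forecasted point \(\hat{z}\) onto the surface defined by \(f(z) = 0\), and \(\lambda_{\min}(H_{restricted}(\tilde{z}))\) is the smallest eigenvalue of the Hessian of \(f\) restricted to the tangent space of the surface at \(\tilde{z}\).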
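The helper min_tangent_eigenvalue used by constant_sign_curvature is not shown on this page. Below is a minimal sketch of the quantity in Theorem 1, assuming the restricted Hessian is the Hessian of \(f\) expressed in an orthonormal basis of the tangent space (the orthogonal complement of \(\nabla f(\tilde{z})\)). The names tangent_min_eigenvalue and curvature_condition are hypothetical, and the closed-form projection z / ‖z‖ is specific to the unit sphere used in the example.

```python
import jax
import jax.numpy as jnp


def tangent_min_eigenvalue(grad_f, hessian_f, z):
    """Hypothetical helper: smallest eigenvalue of the Hessian of f restricted
    to the tangent space {v : grad_f(z) . v = 0} at z."""
    g = grad_f(z)
    # Complete QR of the gradient column: q[:, 0] is parallel to g and the
    # remaining n - 1 columns form an orthonormal basis E of the tangent space.
    q, _ = jnp.linalg.qr(g[:, None], mode="complete")
    E = q[:, 1:]
    H_tan = E.T @ hessian_f(z) @ E
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order
    return jnp.linalg.eigvalsh(H_tan)[0]


def curvature_condition(f, grad_f, hessian_f, vmapped_solver, z_hat):
    """Sketch of Theorem 1: lambda_min(H_restricted(z_tilde)) * f(z_hat) > 0 per row."""
    z_tilde = vmapped_solver(z_hat)
    lam = jax.vmap(tangent_min_eigenvalue, in_axes=(None, None, 0))(grad_f, hessian_f, z_tilde)
    return lam * jax.vmap(f)(z_hat) > 0


# Unit sphere: f(z) = ||z||^2 - 1, gradient 2z, Hessian 2I; projection is z / ||z||
f = lambda z: jnp.dot(z, z) - 1.0
grad_f = jax.grad(f)
hessian_f = jax.hessian(f)
project = jax.vmap(lambda z: z / jnp.linalg.norm(z))

z_hat = jnp.array([[2.0, 0.0, 0.0],   # outside the sphere: f(z_hat) = 3
                   [0.5, 0.0, 0.0]])  # inside the sphere:  f(z_hat) = -0.75
result = curvature_condition(f, grad_f, hessian_f, project, z_hat)  # [True, False]
```

On the unit sphere the restricted Hessian is 2I at every point, so the sign of \(f(\hat{z})\) decides the outcome: the sufficient condition holds for the forecast outside the sphere and is not met for the one inside.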

Source code in src/jnlr/should.py
# Imports this excerpt relies on
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=('f', 'grad_f', 'hessian_f', 'vmapped_solver'))
def constant_sign_curvature(f, grad_f, hessian_f, vmapped_solver, z_hat: jnp.ndarray) -> jnp.ndarray:
    r"""
    Check whether the curvature condition guarantees that projecting forecasts onto a hypersurface
    f(z) = 0 with constant-sign curvature reduces forecasting RMSE.

    Args:
        f: Constraint function.
        grad_f: Function returning the gradient vector of shape (n,).
        hessian_f: Function returning the Hessian matrix of shape (n, n).
        vmapped_solver: Function to project points onto the constraint surface.
        z_hat: Forecasted points of shape (batch_size, n).

    Returns:
        jnp.ndarray: Boolean array of shape (batch_size,) indicating whether the curvature condition is satisfied for each point in z_hat.

    Notes:
        Theorem 1: For a hypersurface defined by f(z) = 0, forecasting RMSE is guaranteed to reduce if:

        $$\lambda_{\min}(H_{restricted}(\tilde{z})) \cdot f(\hat{z}) > 0$$

        where $\tilde{z}$ is the projection of the forecasted point $\hat{z}$ onto the surface defined by $f(z) = 0$.
    """

    # Project every forecast onto the constraint surface: shape (batch_size, n)
    z_tilde = vmapped_solver(z_hat)
    # Smallest eigenvalue of the Hessian restricted to the tangent space at each projected point;
    # min_tangent_eigenvalue is defined elsewhere in jnlr (not shown on this page)
    lambda_min = jax.vmap(min_tangent_eigenvalue, in_axes=(None, None, 0))(grad_f, hessian_f, z_tilde)
    # Constraint value at the original (unprojected) forecasts
    f_val = jax.vmap(f)(z_hat)
    return lambda_min * f_val > 0

vector_valued_convex

vector_valued_convex(f, jacobian_F, hessians_F, vmapped_solver, z_hat: ndarray) -> tuple[jnp.ndarray, jnp.ndarray]

Generalization of the scalar curvature condition to the vector-valued case, where the constraint F(z) = 0 has m components.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | | Constraint function. | required |
| jacobian_F | | Function returning the \((m, n)\) Jacobian matrix \(DF(z)\). | required |
| hessians_F | | Function returning the \((m, n, n)\) Hessians of each of the \(m\) components of \(F\). | required |
| vmapped_solver | | Function to project points onto the constraint surface. | required |
| z_hat | ndarray | Forecasted points of shape (batch_size, n). | required |

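Returns:

| Type | Description |
|---|---|
| tuple[jnp.ndarray, jnp.ndarray] | A tuple (accepted, ratios): a Boolean array indicating whether the curvature condition is satisfied for each point in \(\hat{z}\), and the second value returned by min_tangent_eigenvalue_vv for each point. |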

Notes

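Check whether \(\lambda_{\min}(H_{combined}) > 0\), where

\[H_{combined} = \sum_i \frac{\delta_\pi[i]}{\lVert \nabla f_i \rVert} H_{tan_i}\]

and

  • \(\delta_\pi = \tilde{z} - \hat{z}\);
  • \(H_{tan_i} = E^\top H_i E\) is the Hessian \(H_i\) of \(f_i\) projected onto the tangent space, with the columns of \(E\) spanning that space.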
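The helper min_tangent_eigenvalue_vv is not shown on this page. A minimal sketch of the combined condition for one projected point follows, assuming \(DF(z)\) has full row rank and taking \(E\) as an orthonormal basis of the null space of \(DF(z)\). The weights are left as an input rather than computed from \(\delta_\pi\), because the Notes do not spell out how \(\delta_\pi[i]\) is indexed; combined_tangent_min_eigenvalue is a hypothetical name and is not the library's helper.

```python
import jax.numpy as jnp


def combined_tangent_min_eigenvalue(jacobian, hessians, weights):
    """Hypothetical helper: lambda_min of sum_i weights[i] * E^T H_i E.

    jacobian: (m, n) Jacobian DF(z), assumed to have full row rank m.
    hessians: (m, n, n) Hessians H_i of each component of F.
    weights:  (m,) scalar weights, e.g. delta_pi[i] / ||grad f_i|| as in the Notes.
    """
    m = jacobian.shape[0]
    # Complete QR of DF(z)^T: the first m columns of q span the row space of DF(z);
    # the remaining n - m columns form an orthonormal basis E of its null space.
    q, _ = jnp.linalg.qr(jacobian.T, mode="complete")
    E = q[:, m:]
    # Project every Hessian onto the tangent space: shape (m, n - m, n - m)
    H_tan = jnp.einsum("ja,ijk,kb->iab", E, hessians, E)
    # Weighted sum over the m constraints: shape (n - m, n - m)
    H_combined = jnp.tensordot(weights, H_tan, axes=1)
    return jnp.linalg.eigvalsh(H_combined)[0]


# Two constraints in R^3 at z = (1, 0, 0): F_1(z) = ||z||^2 - 1 and F_2(z) = z[2]
jacobian = jnp.array([[2.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0]])
hessians = jnp.stack([2.0 * jnp.eye(3), jnp.zeros((3, 3))])
weights = jnp.array([0.5, 3.0])
lam = combined_tangent_min_eigenvalue(jacobian, hessians, weights)
```

In the example the tangent space is the second coordinate axis, where the projected Hessians are 2 and 0, so \(H_{combined} = 0.5 \cdot 2 = 1\) and the condition \(\lambda_{\min} > 0\) holds.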

Source code in src/jnlr/should.py
@partial(jax.jit, static_argnames=('f', 'jacobian_F', 'hessians_F', 'vmapped_solver'))
def vector_valued_convex(f, jacobian_F, hessians_F, vmapped_solver, z_hat: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    r"""
    Generalization of the scalar curvature condition to the vector-valued case.

    Args:
        f: Constraint function.
        jacobian_F: Function returning the $(m, n)$ Jacobian matrix $DF(z)$.
        hessians_F: Function returning the $(m, n, n)$ Hessians of each of the $m$ components of $F$.
        vmapped_solver: Function to project points onto the constraint surface.
        z_hat: Forecasted points of shape `(batch_size, n)`.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: A tuple containing:
            - Boolean array indicating whether the curvature condition is satisfied for each point in $\hat{z}$.
            - The second value returned by `min_tangent_eigenvalue_vv` for each point (`ratios`).

    Notes:
        Check whether $\lambda_{\min}(H_{combined}) > 0$, where
        $H_{combined} = \sum_i \frac{\delta_\pi[i]}{\lVert \nabla f_i \rVert} H_{tan_i}$
        and:

        - $\delta_\pi = \tilde{z} - \hat{z}$
        - $H_{tan_i} = E^\top H_i E$ (Hessian projected to the tangent space)
    """

    # Project every forecast onto the constraint surface: shape (batch_size, n)
    z_tilde = vmapped_solver(z_hat)
    # min_tangent_eigenvalue_vv is defined elsewhere in jnlr (not shown on this page)
    lambdas = jax.vmap(min_tangent_eigenvalue_vv, in_axes=(None, None, None, 0, 0))(f, jacobian_F, hessians_F, z_tilde, z_hat)
    lambda_min = lambdas[:, 0]  # Take the minimum eigenvalue from the pair (min, max)
    ratios = lambdas[:, 1]
    accepted = lambda_min > 0
    return accepted, ratios

p_reduction

p_reduction(vmapped_solver, z_hat: ndarray, z_hat_samples: ndarray, alpha=0.05)

Estimate the probability of RMSE reduction using projected bootstrap samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vmapped_solver | | Function to project points onto the constraint surface. | required |
| z_hat | ndarray | Forecasted points of shape (batch_size, n). | required |
| z_hat_samples | ndarray | Bootstrap samples of shape (batch_size, num_samples, n). | required |
| alpha | | Significance level for confidence intervals (not used by this function). | 0.05 |

Returns:

| Type | Description |
|---|---|
| ndarray | Estimated probability of RMSE reduction for each point in z_hat, shape (batch_size,). |

Notes

Theorem 3. For a hypersurface defined by \(f(z) = 0\), the probability of forecasting RMSE reduction can be estimated as

\[\frac{1}{N} \sum_{i=1}^{N} \mathbf{1}\!\left\{\,\tilde{\delta}_i^\top \delta_{\pi} > - \frac{\lVert \delta_{\pi} \rVert^2}{2} \right\},\]

where

  • \( \mathbf{1} \) is the indicator function;
  • \( \tilde{\delta}_i = \tilde{z}_i - \tilde{z} \);
  • \( \tilde{z}_i \) is the projection of the \(i\)-th bootstrap sample onto the surface \( f(z) = 0 \);
  • \( \tilde{z} \) is the projection of \( \hat{z} \) onto the surface \( f(z) = 0 \);
  • \( \delta_{\pi} = \tilde{z} - \hat{z} \).

Before averaging, samples whose projection lies at least as far from \(\hat{z}\) as the farthest unprojected bootstrap sample are discarded (set to NaN), so \(N\) counts only the retained samples.
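The indicator in Theorem 3 can be read as follows: treating a projected sample \(\tilde{z}_i\) as the truth, projection lowers the squared error when \(\lVert \tilde{z} - \tilde{z}_i \rVert^2 < \lVert \hat{z} - \tilde{z}_i \rVert^2\). Since \(\hat{z} - \tilde{z}_i = -(\delta_\pi + \tilde{\delta}_i)\), the right-hand side expands to \(\lVert \delta_\pi \rVert^2 + 2\tilde{\delta}_i^\top \delta_\pi + \lVert \tilde{\delta}_i \rVert^2\), and the inequality reduces to \(\tilde{\delta}_i^\top \delta_\pi > -\lVert \delta_\pi \rVert^2 / 2\). A single-point sketch of the estimator follows, assuming a projection function that maps an (N, n) array onto the surface; estimate_p_reduction is a hypothetical name and, unlike p_reduction, it omits the cutoff.

```python
import jax.numpy as jnp


def estimate_p_reduction(project, z_hat, samples):
    """Hypothetical single-point version of the Theorem 3 estimator (no cutoff).

    project: maps an (N, n) array of points onto the surface f(z) = 0.
    z_hat:   forecast, shape (n,).
    samples: bootstrap samples around z_hat, shape (N, n).
    """
    z_tilde = project(z_hat[None, :])[0]
    z_tilde_samples = project(samples)
    delta_pi = z_tilde - z_hat               # displacement caused by the projection
    delta_tilde = z_tilde_samples - z_tilde  # projected samples relative to z_tilde
    accepted = delta_tilde @ delta_pi > -jnp.dot(delta_pi, delta_pi) / 2
    return jnp.mean(accepted)


# Unit circle in R^2; the forecast lies inside it
project = lambda z: z / jnp.linalg.norm(z, axis=-1, keepdims=True)
z_hat = jnp.array([0.5, 0.0])
samples = jnp.array([[1.0, 0.1], [0.0, 1.0], [2.0, 0.0], [0.5, 0.5]])
p = estimate_p_reduction(project, z_hat, samples)  # 0.5
```

Here \(\delta_\pi = (0.5, 0)\), so a sample is accepted when its projection has first coordinate above 0.75; two of the four projected samples satisfy this.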
Source code in src/jnlr/should.py
@partial(jax.jit, static_argnames=('vmapped_solver',))
def p_reduction(vmapped_solver, z_hat: jnp.ndarray, z_hat_samples: jnp.ndarray, alpha=0.05):
    r"""
    Estimate the probability of RMSE reduction using projected bootstrap samples.

    Args:
        vmapped_solver: Function to project points onto the constraint surface.
        z_hat: Forecasted points of shape `(batch_size, n)`.
        z_hat_samples: Bootstrap samples of shape `(batch_size, num_samples, n)`.
        alpha: Significance level for confidence intervals (not used by this function).

    Returns:
        jnp.ndarray: Estimated probability of RMSE reduction for each point in `z_hat`, shape `(batch_size,)`.

    Notes:
        **Theorem 3.** For a hypersurface defined by $f(z) = 0$,
        the probability of forecasting RMSE reduction can be estimated as

        $$\frac{1}{N} \sum_{i=1}^{N}
        \mathbf{1}\!\left\{\,\tilde{\delta}_i^\top \delta_{\pi}
        > - \frac{\lVert \delta_{\pi} \rVert^2}{2} \right\},$$

        where

        * \( \mathbf{1} \) is the indicator function;
        * \( \tilde{\delta}_i = \tilde{z}_i - \tilde{z} \);
        * \( \tilde{z}_i \) is the projection of the \(i\)-th bootstrap sample onto the surface \( f(z) = 0 \);
        * \( \tilde{z} \) is the projection of \( \hat{z} \) onto the surface \( f(z) = 0 \);
        * \( \delta_{\pi} = \tilde{z} - \hat{z} \).

        Samples whose projection lies at least as far from $\hat{z}$ as the farthest
        unprojected bootstrap sample are discarded before averaging, so $N$ counts
        only the retained samples.
    """

    # Project the forecasts and all bootstrap samples onto the constraint surface
    z_tilde = vmapped_solver(z_hat)
    z_tilde_samples = vmapped_solver(z_hat_samples.reshape(-1, z_hat.shape[-1])).reshape(z_hat_samples.shape)
    delta_pi = z_tilde - z_hat

    # Theorem 3 condition for every sample: delta_tilde_i . delta_pi > -||delta_pi||^2 / 2
    delta_tilde_samples = z_tilde_samples - z_tilde[:, None, :]
    norm_delta_pi = jnp.linalg.norm(delta_pi, axis=-1, keepdims=True)
    scalar_product = jnp.sum(delta_tilde_samples * delta_pi[:, None, :], axis=-1)
    threshold = norm_delta_pi ** 2 / 2
    condition = scalar_product > - threshold

    # Cutoff: set to NaN the samples whose projection is at least as far from z_hat
    # as the farthest unprojected bootstrap sample
    delta_tilde_hat_samples = z_tilde_samples - z_hat[:, None, :]
    post_rec_dist = jnp.sum(delta_tilde_hat_samples ** 2, axis=-1)
    bootstrap_errs = z_hat_samples - z_hat[:, None, :]
    pre_rec_dist = jnp.sum(bootstrap_errs ** 2, axis=-1)
    max_pre_rec_dist = jnp.max(pre_rec_dist, axis=-1, keepdims=True)
    condition = jnp.where(post_rec_dist < max_pre_rec_dist, condition, jnp.nan)

    # Fraction of retained (non-NaN) samples that satisfy the condition
    return jnp.nanmean(condition, axis=-1)

p_reduction_and_intervals

p_reduction_and_intervals(vmapped_solver, z_hat: ndarray, z_hat_samples: ndarray, alpha=0.02)

Estimate the probability of RMSE reduction using projected bootstrap samples, along with Clopper-Pearson confidence intervals.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vmapped_solver | | Function to project points onto the constraint surface. | required |
| z_hat | ndarray | Forecasted points of shape (batch_size, n). | required |
| z_hat_samples | ndarray | Bootstrap samples of shape (batch_size, num_samples, n). | required |
| alpha | | Significance level for confidence intervals. | 0.02 |

Returns:

| Type | Description |
|---|---|
| tuple | A tuple containing the estimated probability of RMSE reduction for each point in z_hat, the Clopper-Pearson confidence interval for each estimate, and delta_pi = z_tilde - z_hat for each point, of shape (batch_size, n). |
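The helper clopper_pearson_intervals is not shown on this page. A minimal stand-alone sketch of the standard two-sided Clopper-Pearson interval follows, written with SciPy's beta distribution rather than JAX; the function name clopper_pearson is hypothetical and this is not the library's implementation. Note that p_reduction_and_intervals passes condition.shape[-1], the total number of bootstrap samples, as the number of trials, whereas the sketch uses the non-NaN count that nanmean averages over; which count is appropriate is a modeling choice left to the reader.

```python
import numpy as np
from scipy.stats import beta


def clopper_pearson(k, n, alpha):
    """Hypothetical helper: two-sided (1 - alpha) Clopper-Pearson interval
    for k successes in n trials."""
    lower = 0.0 if k == 0 else beta.ppf(alpha / 2, k, n - k + 1)
    upper = 1.0 if k == n else beta.ppf(1 - alpha / 2, k + 1, n - k)
    return float(lower), float(upper)


# Per-sample acceptance for one forecast; NaN marks a sample removed by the cutoff
condition = np.array([1.0, 0.0, 1.0, np.nan, 1.0, 1.0])
k = int(np.nansum(condition))                          # accepted samples: 4
n_valid = int(np.count_nonzero(~np.isnan(condition)))  # retained samples: 5
p_hat = np.nanmean(condition)                          # 0.8
interval = clopper_pearson(k, n_valid, alpha=0.02)
```

The interval always contains the point estimate k / n, and it widens as alpha decreases or as fewer samples are retained.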

Source code in src/jnlr/should.py
def p_reduction_and_intervals(vmapped_solver, z_hat: jnp.ndarray, z_hat_samples: jnp.ndarray, alpha=0.02):
    r"""
    Estimate the probability of RMSE reduction using projected bootstrap samples,
    along with Clopper-Pearson confidence intervals.

    Args:
        vmapped_solver: Function to project points onto the constraint surface.
        z_hat: Forecasted points of shape `(batch_size, n)`.
        z_hat_samples: Bootstrap samples of shape `(batch_size, num_samples, n)`.
        alpha: Significance level for confidence intervals.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: A tuple containing:
            - Estimated probability of RMSE reduction for each point in `z_hat`.
            - Clopper-Pearson confidence intervals for each estimate.
            - `delta_pi = z_tilde - z_hat` for each point, of shape `(batch_size, n)`.
    """

    # Project the forecasts and all bootstrap samples onto the constraint surface
    z_tilde = vmapped_solver(z_hat)
    z_tilde_samples = vmapped_solver(z_hat_samples.reshape(-1, z_hat.shape[-1])).reshape(z_hat_samples.shape)
    delta_pi = z_tilde - z_hat

    # Theorem 3 condition for every sample: delta_tilde_i . delta_pi > -||delta_pi||^2 / 2
    delta_tilde_samples = z_tilde_samples - z_tilde[:, None, :]
    norm_delta_pi = jnp.linalg.norm(delta_pi, axis=-1, keepdims=True)
    scalar_product = jnp.sum(delta_tilde_samples * delta_pi[:, None, :], axis=-1)
    threshold = norm_delta_pi ** 2 / 2
    condition = scalar_product > - threshold

    # Cutoff: set to NaN the samples whose projection is at least as far from z_hat
    # as the farthest unprojected bootstrap sample
    delta_tilde_hat_samples = z_tilde_samples - z_hat[:, None, :]
    post_rec_dist = jnp.sum(delta_tilde_hat_samples ** 2, axis=-1)
    bootstrap_errs = z_hat_samples - z_hat[:, None, :]
    pre_rec_dist = jnp.sum(bootstrap_errs ** 2, axis=-1)
    max_pre_rec_dist = jnp.max(pre_rec_dist, axis=-1, keepdims=True)
    condition = jnp.where(post_rec_dist < max_pre_rec_dist, condition, jnp.nan)

    # Clopper-Pearson intervals (clopper_pearson_intervals is defined elsewhere in jnlr).
    # Successes are the accepted non-NaN samples; the number of trials passed here is
    # condition.shape[-1], i.e. all bootstrap samples, including those set to NaN by the
    # cutoff, while nanmean below averages over the retained samples only.
    intervals = jax.vmap(clopper_pearson_intervals, in_axes=(0, None, None))(jnp.nansum(condition, axis=-1),
                                                                             condition.shape[-1], alpha)
    return jnp.nanmean(condition, axis=-1), intervals, delta_pi