prtools.jax.lbfgs
- lbfgs(fn, x0, gtol=None, maxiter=None, callback=None, fn_args=None, fn_kwargs=None)
Minimize a scalar function of one or more variables using the limited-memory BFGS (L-BFGS) algorithm.
- Parameters:
fn (callable) –
The objective function to be minimized:
fn(x, *fn_args, **fn_kwargs)
where x is a 1-D array with shape (n,) and fn_args and fn_kwargs are optional positional and keyword arguments. fn must return a scalar.
x0 (jax.Array) – Initial guess for x.
gtol (float) – Iteration stops when l2_norm(grad) <= gtol, where grad is the gradient of fn at the current iterate.
maxiter (int) – Maximum number of iterations.
callback (callable, optional) –
A callable invoked after each iteration with the signature
callback(intermediate_result: JaxOptimizeResult)
where intermediate_result is a JaxOptimizeResult.
fn_args (iterable or None) – Extra positional arguments passed to fn().
fn_kwargs (dict or None) – Extra keyword arguments passed to fn().
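The forwarding rule fn(x, *fn_args, **fn_kwargs) and the gtol stopping test can be examined in isolation. This sketch assumes the solver forwards extras exactly as that signature states and treats None as "no extra arguments". The helpers objective, bind_objective and grad_converged are hypothetical and not part of prtools. The gradient is taken with jax.grad.

```python
import jax
import jax.numpy as jnp


def objective(x, A, b, alpha=0.0):
    """Hypothetical objective: ||A x - b||^2 + alpha * ||x||^2 (returns a scalar)."""
    r = A @ x - b
    return jnp.dot(r, r) + alpha * jnp.dot(x, x)


def bind_objective(fn, fn_args=None, fn_kwargs=None):
    """Return a one-argument function x -> fn(x, *fn_args, **fn_kwargs); None means no extras."""
    fn_args = tuple(fn_args) if fn_args is not None else ()
    fn_kwargs = dict(fn_kwargs) if fn_kwargs is not None else {}
    return lambda x: fn(x, *fn_args, **fn_kwargs)


def grad_converged(fn, x, gtol):
    """Evaluate the stopping test l2_norm(grad) <= gtol for a one-argument fn at x."""
    g = jax.grad(fn)(x)
    return bool(jnp.linalg.norm(g) <= gtol)
```

Under the documented signature, the same objective would be passed to the solver as lbfgs(objective, x0, fn_args=(A, b), fn_kwargs={"alpha": 0.1}, gtol=1e-6), leaving fn itself free of closures over A and b.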
- Returns:
res – The optimization result. See JaxOptimizeResult for a description of attributes.
- Return type: