prtools.jax.lbfgs

lbfgs(fn, x0, gtol=None, maxiter=None, callback=None, fn_args=None, fn_kwargs=None)

Minimize a scalar function of one or more variables using the L-BFGS (limited-memory BFGS) algorithm.
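L-BFGS avoids storing a full inverse Hessian. Instead, it rebuilds the search direction from the last few curvature pairs (step s and gradient change y) using the standard two-loop recursion. The NumPy sketch below shows that textbook recursion for illustration only. It is not taken from prtools, and the function name two_loop_direction is hypothetical.

```python
import numpy as np


def two_loop_direction(grad, s_hist, y_hist):
    """Sketch of the L-BFGS two-loop recursion: returns d = -H @ grad.

    s_hist[i] = x_{i+1} - x_i and y_hist[i] = g_{i+1} - g_i, ordered oldest
    first. Each pair is assumed to satisfy the curvature condition s @ y > 0.
    """
    q = np.array(grad, dtype=float)
    rhos = [1.0 / np.dot(y, s) for s, y in zip(s_hist, y_hist)]
    alphas = []
    # First loop: walk the stored pairs from newest to oldest.
    for s, y, rho in reversed(list(zip(s_hist, y_hist, rhos))):
        alpha = rho * np.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial inverse-Hessian guess H0 = gamma * I (a common scaling choice).
    if s_hist:
        s, y = s_hist[-1], y_hist[-1]
        gamma = np.dot(s, y) / np.dot(y, y)
    else:
        gamma = 1.0
    r = gamma * q
    # Second loop: oldest to newest; alphas were stored newest first.
    for s, y, rho, alpha in zip(s_hist, y_hist, rhos, reversed(alphas)):
        beta = rho * np.dot(y, r)
        r = r + s * (alpha - beta)
    return -r
```

With an empty history, the recursion reduces to steepest descent, d = -grad. In one dimension with a single pair, it reproduces the secant step -grad * s / y.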

Parameters:
  • fn (callable) –

    The objective function to be minimized:

    fn(x, *fn_args, **fn_kwargs)

    where x is a 1-D array with shape (n,), and fn_args and fn_kwargs are the optional extra positional and keyword arguments described below. fn must return a scalar.

  • x0 (jax.Array) – Initial guess, a 1-D array with shape (n,).

  • gtol (float, optional) – Iteration stops when the L2 norm of the gradient satisfies l2_norm(grad) <= gtol.

  • maxiter (int, optional) – Maximum number of iterations.

  • callback (callable, optional) –

    A callable invoked after each iteration with the signature

    callback(intermediate_result: JaxOptimizeResult)

    where intermediate_result holds the optimization result as of that iteration.

  • fn_args (iterable or None) – Extra positional arguments passed to fn.

  • fn_kwargs (dict or None) – Extra keyword arguments passed to fn.
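The parameters above define a calling convention: the objective is evaluated as fn(x, *fn_args, **fn_kwargs), iteration stops once l2_norm(grad) <= gtol, and callback is invoked after each iteration. The NumPy sketch below makes that convention concrete with a hypothetical driver, toy_minimize. It is not prtools code. Its update is plain fixed-step gradient descent rather than L-BFGS, its gradient comes from central finite differences standing in for JAX differentiation, and a dict stands in for JaxOptimizeResult. The names quadratic, numerical_grad, toy_minimize, and step are all illustrative assumptions.

```python
import numpy as np


def quadratic(x, target, scale=1.0):
    """Example objective with the documented form fn(x, *fn_args, **fn_kwargs)."""
    return scale * np.sum((x - target) ** 2)


def numerical_grad(fn, x, args, kwargs, h=1e-6):
    """Central-difference gradient; stands in for JAX autodiff in this sketch."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (fn(x + e, *args, **kwargs) - fn(x - e, *args, **kwargs)) / (2 * h)
    return g


def toy_minimize(fn, x0, gtol=1e-6, maxiter=1000, callback=None,
                 fn_args=None, fn_kwargs=None, step=0.1):
    """Hypothetical driver that mimics how lbfgs documents its arguments.

    The update is fixed-step gradient descent, NOT L-BFGS; only the argument
    forwarding, the gtol stopping rule and the callback hook are the point.
    """
    args = tuple(fn_args) if fn_args is not None else ()
    kwargs = dict(fn_kwargs) if fn_kwargs is not None else {}
    x = np.asarray(x0, dtype=float)
    for it in range(1, maxiter + 1):
        g = numerical_grad(fn, x, args, kwargs)
        if np.linalg.norm(g) <= gtol:  # documented rule: l2_norm(grad) <= gtol
            break
        x = x - step * g  # stand-in update; lbfgs takes an L-BFGS step instead
        if callback is not None:
            # A plain dict stands in for JaxOptimizeResult here.
            callback({"x": x, "fun": fn(x, *args, **kwargs), "nit": it})
    return x


target = np.array([1.0, 2.0, 3.0])
history = []
x_opt = toy_minimize(quadratic, np.zeros(3), gtol=1e-6, maxiter=500,
                     callback=history.append,
                     fn_args=(target,), fn_kwargs={"scale": 2.0})
print(x_opt)
```

With the documented signature, the same objective would be passed as lbfgs(quadratic, x0, gtol=1e-6, maxiter=500, fn_args=(target,), fn_kwargs={"scale": 2.0}). Because lbfgs takes no gradient argument, the objective would presumably need to use jax.numpy rather than NumPy.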

Returns:

res – The optimization result. See JaxOptimizeResult for a description of its attributes.

Return type:

JaxOptimizeResult