warp.optim.linear.cg
- warp.optim.linear.cg(A, b, x, tol=None, atol=None, maxiter=0, M=None, callback=None, check_every=10, use_cuda_graph=True)
Computes an approximate solution to a symmetric, positive-definite linear system using the Conjugate Gradient algorithm.
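For orientation, the textbook Conjugate Gradient iteration can be sketched in plain Python. This is a reference sketch on dense Python lists, not Warp's implementation: the names `cg_reference`, `dot`, and `matvec` are hypothetical, preconditioning is omitted, and falling back to the system size when `maxiter` is unset is an assumption modeled on the documented default.

```python
def dot(u, v):
    """Euclidean inner product of two equal-length sequences."""
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(A, v):
    """Dense matrix-vector product, with A given as a list of rows."""
    return [dot(row, v) for row in A]


def cg_reference(A, b, x, atol=1e-10, maxiter=None):
    """Textbook (unpreconditioned) conjugate gradient; updates x in place.

    A must be symmetric positive-definite. Returns (iterations, residual_norm).
    """
    n = len(b)
    maxiter = n if not maxiter else maxiter  # assumption: default to system size
    r = [bi - ai for bi, ai in zip(b, matvec(A, x))]  # residual b - A x
    p = list(r)  # first search direction is the residual
    rr = dot(r, r)
    iterations = 0
    while iterations < maxiter and rr ** 0.5 > atol:
        Ap = matvec(A, p)
        alpha = rr / dot(p, Ap)  # step length along p
        for i in range(n):
            x[i] += alpha * p[i]
            r[i] -= alpha * Ap[i]
        rr_new = dot(r, r)
        beta = rr_new / rr  # makes the next direction A-conjugate to p
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rr = rr_new
        iterations += 1
    return iterations, rr ** 0.5


# Illustrative 2x2 symmetric positive-definite system.
A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x = [0.0, 0.0]  # initial guess, overwritten with the solution
iterations, residual = cg_reference(A, b, x)
```

As in the Warp function, `x` serves as both the initial guess and the output vector.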
- Parameters:
A (array | BsrMatrix | LinearOperator) – the linear system’s left-hand-side
b (array) – the linear system’s right-hand-side
x (array) – initial guess and solution vector
tol (float | None) – relative tolerance for the residual, as a ratio of the right-hand-side norm
atol (float | None) – absolute tolerance for the residual
maxiter (float | None) – maximum number of iterations to perform before aborting. Defaults to the system size.
M (array | BsrMatrix | LinearOperator | None) – optional left-preconditioner, ideally chosen such that M A is close to identity.
callback (Callable | None) – function to be called every check_every iterations with the current iteration number, residual and tolerance. If check_every is 0, the callback should be a Warp kernel.
check_every – number of iterations between successive calls to callback; at the same interval, the residual is checked against the tolerance and the algorithm may terminate. Setting check_every to 0 disables host-side residual checks, making the solver fully CUDA-graph capturable. If conditional CUDA graphs are supported, convergence checks are performed device-side; otherwise, the solver will always run to the maximum number of iterations.
use_cuda_graph – If true and when run on a CUDA device, capture the solver iteration as a CUDA graph for reduced launch overhead. The linear operator and preconditioner must only perform graph-friendly operations.
- Returns:
- If check_every > 0: Tuple (final_iteration, residual_norm, absolute_tolerance)
final_iteration: The number of iterations performed before convergence or reaching maxiter
residual_norm: The final residual norm ||b - Ax||
absolute_tolerance: The absolute tolerance used for convergence checking
- If check_every is 0: Tuple (final_iteration_array, residual_norm_squared_array, absolute_tolerance_squared_array)
final_iteration_array: Device array containing the number of iterations performed
residual_norm_squared_array: Device array containing the squared residual norm ||b - Ax||²
absolute_tolerance_squared_array: Device array containing the squared absolute tolerance
If both tol and atol are provided, the absolute tolerance used as the termination criterion for the residual norm is max(atol, tol * norm(b)).
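The combined criterion can be checked by hand with a few lines of plain Python. This sketch assumes `norm` denotes the Euclidean norm, and the helper name `combined_tolerance` is hypothetical; it only covers the case stated above, where both `tol` and `atol` are given.

```python
import math


def combined_tolerance(b, tol, atol):
    """Absolute tolerance when both tol and atol are given: max(atol, tol * norm(b))."""
    norm_b = math.sqrt(sum(v * v for v in b))  # assumption: Euclidean norm
    return max(atol, tol * norm_b)


b = [3.0, 4.0]  # norm(b) = 5

# Relative term dominates: tol * norm(b) = 5e-3 exceeds atol = 1e-4.
loose = combined_tolerance(b, tol=1e-3, atol=1e-4)

# Absolute term dominates: tol * norm(b) = 5e-6 is below atol = 1e-4.
tight = combined_tolerance(b, tol=1e-6, atol=1e-4)
```

In other words, `atol` acts as a floor: a very small right-hand side cannot push the stopping threshold below it.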