Source code for nvalchemiops.math.math

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import warp as wp


@wp.func
def wp_safe_divide(x: wp.float64, y: wp.float64) -> wp.float64:
    """Divide ``x`` by ``y``, returning zero when ``y`` is numerically zero.

    The denominator is treated as zero when its magnitude is below ``1e-8``.
    The magnitude is compared (not the signed value) so that negative
    denominators take the regular ``x / y`` path instead of being collapsed
    to ``0.0``.

    Parameters
    ----------
    x : wp.float64
        Numerator.
    y : wp.float64
        Denominator.

    Returns
    -------
    wp.float64
        ``0.0`` when ``abs(y) < 1e-8``, otherwise ``x / y``.
    """
    # wp.where selects between the two values; the guard only has to decide
    # whether the quotient is trustworthy, independent of the sign of y.
    return wp.where(wp.abs(y) < wp.float64(1e-8), wp.float64(0.0), x / y)


@wp.func
def wp_exp_kernel(x: wp.float64, factor: wp.float64) -> wp.float64:
    """Evaluate ``exp(-x * factor) / x`` with a guarded division.

    The quotient is delegated to ``wp_safe_divide``, which returns zero
    instead of dividing when it deems the denominator ``x`` too small.
    """
    damped = wp.exp(-x * factor)
    return wp_safe_divide(damped, x)


@wp.func
def wpdivmod(a: int, b: int):  # type: ignore
    """Warp implementation of the divmod utility.

    Returns ``(div, mod)`` such that ``div * b + mod == a`` and ``mod`` is
    either zero or has the same sign as ``b`` (floor-division convention, as
    with Python's built-in ``divmod``).

    Parameters
    ----------
    a : int
        Dividend.
    b : int
        Divisor. No zero check is performed here.

    Returns
    -------
    tuple
        The pair ``(div, mod)``.

    Notes
    -----
    The correction step assumes Warp's integer ``/`` and ``%`` truncate
    toward zero (C semantics) -- TODO confirm against the Warp version in use.
    """
    div = int(a / b)
    mod = a % b
    # A truncated remainder carries the sign of the dividend. When it is
    # non-zero and disagrees in sign with the divisor, step one unit down to
    # the floor quotient. Testing only ``mod < 0`` would be wrong for b < 0.
    if (mod < 0 and b > 0) or (mod > 0 and b < 0):
        div -= 1
        mod = b + mod
    return div, mod
@wp.func
def wp_erfc(x: Any) -> Any:
    """Approximate the complementary error function ``erfc(x)``.

    Uses the Abramowitz and Stegun polynomial approximation (maximum error
    ~1.5e-7). The approximation is evaluated on ``|x|`` and negative inputs
    are handled through the reflection ``erfc(-x) = 2 - erfc(x)``.

    All constants are constructed with ``type(x)``, so the computation is
    carried out in the value type of the argument.

    Parameters
    ----------
    x : Any
        Input value.

    Returns
    -------
    Any
        Approximation of ``erfc(x)``.
    """
    one = type(x)(1.0)
    two = type(x)(2.0)
    zero = type(x)(0.0)

    # Abramowitz and Stegun coefficients for the erfc approximation.
    scale = type(x)(0.3275911)
    c1 = type(x)(0.254829592)
    c2 = type(x)(-0.284496736)
    c3 = type(x)(1.421413741)
    c4 = type(x)(-1.453152027)
    c5 = type(x)(1.061405429)

    magnitude = wp.abs(x)

    # Auxiliary variable u = 1 / (1 + p * |x|) and its successive powers.
    u = one / (one + scale * magnitude)
    u2 = u * u
    u3 = u2 * u
    u4 = u3 * u
    u5 = u4 * u

    # Degree-5 polynomial in u, damped by the Gaussian factor exp(-x^2).
    series = c1 * u + c2 * u2 + c3 * u3 + c4 * u4 + c5 * u5
    gaussian = wp.exp(-magnitude * magnitude)
    erfc_of_magnitude = series * gaussian

    # Reflect for negative arguments: erfc(-x) = 2 - erfc(x).
    return wp.where(x >= zero, erfc_of_magnitude, two - erfc_of_magnitude)