cub/device/device_merge.cuh

File members: cub/device/device_merge.cuh

// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/detail/nvtx.cuh>
#include <cub/device/dispatch/dispatch_merge.cuh>
#include <cub/util_namespace.cuh>

#include <cuda/std/functional>

CUB_NAMESPACE_BEGIN

//! Groups the static, device-wide merge entry points `MergeKeys` and `MergePairs`.
//!
//! Both members are thin wrappers: each one opens an NVTX range and then forwards its arguments,
//! unchanged and in the same order, to `detail::merge::dispatch_t<...>::dispatch`. The dispatcher
//! is always instantiated with `int` as its offset type, matching the `int` item counts taken here.
//!
//! NOTE(review): presumably both input ranges must already be ordered with respect to
//! `compare_op`, and the temporary-storage arguments follow the usual CUB two-phase protocol
//! (size query when `d_temp_storage` is null). Neither is visible in this file - confirm against
//! `dispatch_merge.cuh`.
//!
//! NOTE(review): the item counts are plain `int`. Whether the *sum* of both counts must also be
//! representable as `int` depends on `dispatch_t` - confirm before using very large inputs.
struct DeviceMerge
{
  //! Keys-only merge of two input ranges into `keys_out`.
  //!
  //! @param d_temp_storage      Temporary storage pointer; forwarded as-is to the dispatcher and
  //!                            also used as the condition argument of the NVTX range macro.
  //! @param temp_storage_bytes  Size of the temporary storage, taken by reference and forwarded.
  //! @param keys_in1            Iterator to the first input key range.
  //! @param num_keys1           Number of keys in the first range.
  //! @param keys_in2            Iterator to the second input key range.
  //! @param num_keys2           Number of keys in the second range.
  //! @param keys_out            Iterator to the output key range.
  //! @param compare_op          Comparison functor; value-initialized `CompareOp` by default
  //!                            (`::cuda::std::less<>` unless another type is supplied).
  //! @param stream              Stream handed to the dispatcher; defaults to `nullptr`.
  //! @return                    Whatever `cudaError_t` the dispatcher returns.
  template <typename KeyIteratorIn1,
            typename KeyIteratorIn2,
            typename KeyIteratorOut,
            typename CompareOp = ::cuda::std::less<>>
  CUB_RUNTIME_FUNCTION static cudaError_t MergeKeys(
    void* d_temp_storage,
    std::size_t& temp_storage_bytes,
    KeyIteratorIn1 keys_in1,
    int num_keys1,
    KeyIteratorIn2 keys_in2,
    int num_keys2,
    KeyIteratorOut keys_out,
    CompareOp compare_op = {},
    cudaStream_t stream  = nullptr)
  {
    CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys");

    // There are no values in the keys-only flavour: all three value-iterator slots of the
    // dispatcher are typed `NullType*` and receive `nullptr` below.
    using keys_only_dispatch_t = detail::merge::dispatch_t<
      KeyIteratorIn1,
      NullType*,
      KeyIteratorIn2,
      NullType*,
      KeyIteratorOut,
      NullType*,
      int,
      CompareOp>;

    return keys_only_dispatch_t::dispatch(
      d_temp_storage,
      temp_storage_bytes,
      keys_in1,
      nullptr,
      num_keys1,
      keys_in2,
      nullptr,
      num_keys2,
      keys_out,
      nullptr,
      compare_op,
      stream);
  }

  //! Key-value merge of two input ranges into `keys_out` / `values_out`.
  //!
  //! @param d_temp_storage      Temporary storage pointer; forwarded as-is to the dispatcher and
  //!                            also used as the condition argument of the NVTX range macro.
  //! @param temp_storage_bytes  Size of the temporary storage, taken by reference and forwarded.
  //! @param keys_in1            Iterator to the keys of the first input range.
  //! @param values_in1          Iterator to the values of the first input range.
  //! @param num_pairs1          Number of key-value pairs in the first range.
  //! @param keys_in2            Iterator to the keys of the second input range.
  //! @param values_in2          Iterator to the values of the second input range.
  //! @param num_pairs2          Number of key-value pairs in the second range.
  //! @param keys_out            Iterator to the output keys.
  //! @param values_out          Iterator to the output values.
  //! @param compare_op          Comparison functor; value-initialized `CompareOp` by default
  //!                            (`::cuda::std::less<>` unless another type is supplied).
  //! @param stream              Stream handed to the dispatcher; defaults to `nullptr`.
  //! @return                    Whatever `cudaError_t` the dispatcher returns.
  template <typename KeyIteratorIn1,
            typename ValueIteratorIn1,
            typename KeyIteratorIn2,
            typename ValueIteratorIn2,
            typename KeyIteratorOut,
            typename ValueIteratorOut,
            typename CompareOp = ::cuda::std::less<>>
  CUB_RUNTIME_FUNCTION static cudaError_t MergePairs(
    void* d_temp_storage,
    std::size_t& temp_storage_bytes,
    KeyIteratorIn1 keys_in1,
    ValueIteratorIn1 values_in1,
    int num_pairs1,
    KeyIteratorIn2 keys_in2,
    ValueIteratorIn2 values_in2,
    int num_pairs2,
    KeyIteratorOut keys_out,
    ValueIteratorOut values_out,
    CompareOp compare_op = {},
    cudaStream_t stream  = nullptr)
  {
    CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs");

    // Same dispatcher as the keys-only flavour, but with the caller's value iterator types.
    using pairs_dispatch_t = detail::merge::dispatch_t<
      KeyIteratorIn1,
      ValueIteratorIn1,
      KeyIteratorIn2,
      ValueIteratorIn2,
      KeyIteratorOut,
      ValueIteratorOut,
      int,
      CompareOp>;

    return pairs_dispatch_t::dispatch(
      d_temp_storage,
      temp_storage_bytes,
      keys_in1,
      values_in1,
      num_pairs1,
      keys_in2,
      values_in2,
      num_pairs2,
      keys_out,
      values_out,
      compare_op,
      stream);
  }
};

CUB_NAMESPACE_END