cub/device/device_merge.cuh
File members: cub/device/device_merge.cuh
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#pragma once
#include <cub/config.cuh>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <cub/detail/nvtx.cuh>
#include <cub/device/dispatch/dispatch_merge.cuh>
#include <cub/util_namespace.cuh>
#include <cuda/std/functional>
CUB_NAMESPACE_BEGIN
//! Device-wide entry points for merging two input sequences into a single
//! output sequence, either keys only (MergeKeys) or keys with associated
//! values (MergePairs).
//!
//! Both entry points are thin forwarders: all work is performed by
//! `detail::merge::dispatch_t<...>::dispatch`, whose result is returned as-is.
//! NOTE(review): the inputs are presumably required to be ordered by
//! `compare_op` (this is not checked here) — confirm against the CUB docs.
struct DeviceMerge
{
  //! Merges the keys of two input ranges into `keys_out`.
  //!
  //! @param d_temp_storage      Device temporary storage, forwarded to the dispatcher.
  //!                            Also gates the NVTX range (see below).
  //!                            NOTE(review): presumably follows CUB's two-phase
  //!                            convention (pass nullptr to query
  //!                            `temp_storage_bytes`, then call again) — confirm
  //!                            against dispatch_t.
  //! @param temp_storage_bytes  Size of `d_temp_storage`, forwarded by reference.
  //! @param keys_in1            First key input range.
  //! @param num_keys1           Number of keys in the first range (`int`, which is
  //!                            also the dispatcher's offset type).
  //! @param keys_in2            Second key input range.
  //! @param num_keys2           Number of keys in the second range.
  //! @param keys_out            Output key range.
  //! @param compare_op          Key comparator; defaults to `::cuda::std::less<>{}`.
  //! @param stream              Stream passed to the dispatcher; defaults to nullptr.
  //! @return The `cudaError_t` produced by the dispatcher.
  template <typename KeyIteratorIn1,
            typename KeyIteratorIn2,
            typename KeyIteratorOut,
            typename CompareOp = ::cuda::std::less<>>
  CUB_RUNTIME_FUNCTION static cudaError_t MergeKeys(
    void* d_temp_storage,
    std::size_t& temp_storage_bytes,
    KeyIteratorIn1 keys_in1,
    int num_keys1,
    KeyIteratorIn2 keys_in2,
    int num_keys2,
    KeyIteratorOut keys_out,
    CompareOp compare_op = {},
    cudaStream_t stream  = nullptr)
  {
    // The range is conditioned on d_temp_storage (presumably so that a
    // size-only query is not annotated — see the macro's definition).
    CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys");

    // Keys-only merge: every value-iterator slot of the shared dispatcher is
    // instantiated as NullType*, and a null pointer is passed for each.
    using no_values_t      = NullType*;
    using offset_t         = int;
    using merge_dispatch_t = detail::merge::dispatch_t<
      KeyIteratorIn1,
      no_values_t,
      KeyIteratorIn2,
      no_values_t,
      KeyIteratorOut,
      no_values_t,
      offset_t,
      CompareOp>;

    return merge_dispatch_t::dispatch(
      d_temp_storage, temp_storage_bytes,
      keys_in1, nullptr, num_keys1,
      keys_in2, nullptr, num_keys2,
      keys_out, nullptr,
      compare_op, stream);
  }

  //! Merges two ranges of (key, value) pairs into `keys_out` / `values_out`.
  //!
  //! @param d_temp_storage      Device temporary storage, forwarded to the dispatcher.
  //!                            Also gates the NVTX range (see below).
  //!                            NOTE(review): presumably follows CUB's two-phase
  //!                            convention (pass nullptr to query
  //!                            `temp_storage_bytes`, then call again) — confirm
  //!                            against dispatch_t.
  //! @param temp_storage_bytes  Size of `d_temp_storage`, forwarded by reference.
  //! @param keys_in1            First key input range.
  //! @param values_in1          Values associated with `keys_in1`.
  //! @param num_pairs1          Number of pairs in the first range (`int`, which is
  //!                            also the dispatcher's offset type).
  //! @param keys_in2            Second key input range.
  //! @param values_in2          Values associated with `keys_in2`.
  //! @param num_pairs2          Number of pairs in the second range.
  //! @param keys_out            Output key range.
  //! @param values_out          Output value range.
  //! @param compare_op          Key comparator; defaults to `::cuda::std::less<>{}`.
  //! @param stream              Stream passed to the dispatcher; defaults to nullptr.
  //! @return The `cudaError_t` produced by the dispatcher.
  template <typename KeyIteratorIn1,
            typename ValueIteratorIn1,
            typename KeyIteratorIn2,
            typename ValueIteratorIn2,
            typename KeyIteratorOut,
            typename ValueIteratorOut,
            typename CompareOp = ::cuda::std::less<>>
  CUB_RUNTIME_FUNCTION static cudaError_t MergePairs(
    void* d_temp_storage,
    std::size_t& temp_storage_bytes,
    KeyIteratorIn1 keys_in1,
    ValueIteratorIn1 values_in1,
    int num_pairs1,
    KeyIteratorIn2 keys_in2,
    ValueIteratorIn2 values_in2,
    int num_pairs2,
    KeyIteratorOut keys_out,
    ValueIteratorOut values_out,
    CompareOp compare_op = {},
    cudaStream_t stream  = nullptr)
  {
    // Conditioned on d_temp_storage, exactly as in MergeKeys.
    CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs");

    using offset_t         = int;
    using merge_dispatch_t = detail::merge::dispatch_t<
      KeyIteratorIn1,
      ValueIteratorIn1,
      KeyIteratorIn2,
      ValueIteratorIn2,
      KeyIteratorOut,
      ValueIteratorOut,
      offset_t,
      CompareOp>;

    return merge_dispatch_t::dispatch(
      d_temp_storage, temp_storage_bytes,
      keys_in1, values_in1, num_pairs1,
      keys_in2, values_in2, num_pairs2,
      keys_out, values_out,
      compare_op, stream);
  }
};
CUB_NAMESPACE_END