include/cuda/experimental/__hierarchy/dimensions.cuh
File members: include/cuda/experimental/__hierarchy/dimensions.cuh
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#ifndef _CUDAX__HIERARCHY_DIMENSIONS
#define _CUDAX__HIERARCHY_DIMENSIONS
#include <cuda/std/mdspan>
#if _CCCL_STD_VER >= 2017
namespace cuda::experimental
{
template <typename T, size_t... Extents>
using dimensions = ::cuda::std::extents<T, Extents...>;
// not unsigned because of a bug in ::cuda::std::extents
using dimensions_index_type = int;
template <typename T, size_t... Extents>
struct hierarchy_query_result : public dimensions<T, Extents...>
{
using Dims = dimensions<T, Extents...>;
using Dims::Dims;
_CCCL_HOST_DEVICE constexpr hierarchy_query_result()
: Dims()
, x(Dims::extent(0))
, y(Dims::rank() > 1 ? Dims::extent(1) : 1)
, z(Dims::rank() > 2 ? Dims::extent(2) : 1)
{}
_CCCL_HOST_DEVICE explicit constexpr hierarchy_query_result(const Dims& dims)
: Dims(dims)
, x(Dims::extent(0))
, y(Dims::rank() > 1 ? Dims::extent(1) : 1)
, z(Dims::rank() > 2 ? Dims::extent(2) : 1)
{}
static_assert(Dims::rank() > 0 && Dims::rank() <= 3);
const T x;
const T y;
const T z;
_CCCL_HOST_DEVICE constexpr operator dim3() const
{
return dim3(static_cast<uint32_t>(x), static_cast<uint32_t>(y), static_cast<uint32_t>(z));
}
};
namespace detail
{
template <typename OpType>
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr size_t merge_extents(size_t e1, size_t e2)
{
if (e1 == ::cuda::std::dynamic_extent || e2 == ::cuda::std::dynamic_extent)
{
return ::cuda::std::dynamic_extent;
}
else
{
OpType op;
return op(e1, e2);
}
}
template <typename DstType, typename OpType, typename T1, size_t... Extents1, typename T2, size_t... Extents2>
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto
dims_op(const OpType& op, const dimensions<T1, Extents1...>& h1, const dimensions<T2, Extents2...>& h2) noexcept
{
// For now target only 3 dim extents
static_assert(sizeof...(Extents1) == sizeof...(Extents2));
static_assert(sizeof...(Extents1) == 3);
return dimensions<DstType, merge_extents<OpType>(Extents1, Extents2)...>(
op(static_cast<DstType>(h1.extent(0)), h2.extent(0)),
op(static_cast<DstType>(h1.extent(1)), h2.extent(1)),
op(static_cast<DstType>(h1.extent(2)), h2.extent(2)));
}
template <typename DstType, typename T1, size_t... Extents1, typename T2, size_t... Extents2>
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto
dims_product(const dimensions<T1, Extents1...>& h1, const dimensions<T2, Extents2...>& h2) noexcept
{
return dims_op<DstType>(::cuda::std::multiplies(), h1, h2);
}
template <typename DstType, typename T1, size_t... Extents1, typename T2, size_t... Extents2>
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto
dims_sum(const dimensions<T1, Extents1...>& h1, const dimensions<T2, Extents2...>& h2) noexcept
{
return dims_op<DstType>(::cuda::std::plus(), h1, h2);
}
template <typename T, size_t... Extents>
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto convert_to_query_result(const dimensions<T, Extents...>& result)
{
return hierarchy_query_result<T, Extents...>(result);
}
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto dim3_to_dims(const dim3& dims)
{
return dimensions<dimensions_index_type,
::cuda::std::dynamic_extent,
::cuda::std::dynamic_extent,
::cuda::std::dynamic_extent>(dims.x, dims.y, dims.z);
}
template <typename TyTrunc, typename Index, typename Dims>
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto index_to_linear(const Index& index, const Dims& dims)
{
static_assert(Dims::rank() == 3);
return (static_cast<TyTrunc>(index.extent(2)) * dims.extent(1) + index.extent(1)) * dims.extent(0) + index.extent(0);
}
} // namespace detail
} // namespace cuda::experimental
#endif // _CCCL_STD_VER >= 2017
#endif // _CUDAX__HIERARCHY_DIMENSIONS