include/cuda/experimental/__hierarchy/level_dimensions.cuh

File members: include/cuda/experimental/__hierarchy/level_dimensions.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__HIERARCHY_LEVEL_DIMENSIONS
#define _CUDAX__HIERARCHY_LEVEL_DIMENSIONS

#include <cuda/std/span>
#include <cuda/std/type_traits>

#include <cuda/experimental/__detail/config.cuh>
#include <cuda/experimental/__hierarchy/hierarchy_levels.cuh>

#if _CCCL_STD_VER >= 2017
namespace cuda::experimental
{

namespace detail
{

/* Keeping it around in case issues like https://github.com/NVIDIA/cccl/issues/522 resurface.
template <typename T, size_t... Extents>
struct extents_corrected : public ::cuda::std::extents<T, Extents...> {
    using ::cuda::std::extents<T, Extents...>::extents;

    template <typename ::cuda::std::extents<T, Extents...>::rank_type Id>
    _CCCL_HOST_DEVICE constexpr auto extent_corrected() const {
        if constexpr (::cuda::std::extents<T, Extents...>::static_extent(Id) != ::cuda::std::dynamic_extent) {
            return this->static_extent(Id);
        }
        else {
            return this->extent(Id);
        }
    }
};
*/

// Primary handler used to turn a user-supplied dimension description into a
// `dimensions` object. This generic version reports integral types as
// supported; other types are covered (or not) by the specializations.
template <typename Dims>
struct dimensions_handler
{
  // True only for integral `Dims`; callers static_assert on this flag.
  static constexpr bool is_type_supported = ::cuda::std::is_integral_v<Dims>;

  // Converts `d` to `unsigned int` and passes it as the sole constructor
  // argument of a `dimensions` type whose first extent is `dynamic_extent`
  // and whose two remaining extent arguments are the constant 1.
  _CCCL_NODISCARD _CCCL_HOST_DEVICE static constexpr auto translate(const Dims& d) noexcept
  {
    using result_type = dimensions<dimensions_index_type, ::cuda::std::dynamic_extent, 1, 1>;
    return result_type(static_cast<unsigned int>(d));
  }
};

// Handler specialization for CUDA's `dim3`: always supported, and every one
// of the three extents is carried as a run-time (dynamic) value.
template <>
struct dimensions_handler<dim3>
{
  static constexpr bool is_type_supported = true;

  // Forwards the x, y and z members of `d`, in that order, to a `dimensions`
  // type declared with three `dynamic_extent` arguments.
  _CCCL_NODISCARD _CCCL_HOST_DEVICE static constexpr auto translate(const dim3& d) noexcept
  {
    using result_type =
      dimensions<dimensions_index_type,
                 ::cuda::std::dynamic_extent,
                 ::cuda::std::dynamic_extent,
                 ::cuda::std::dynamic_extent>;
    return result_type(d.x, d.y, d.z);
  }
};

// Handler specialization for `integral_constant<Dims, Val>`: the value is
// known at compile time, so it is encoded as a static first extent of the
// returned `dimensions` type (the two remaining extent arguments are 1).
template <typename Dims, Dims Val>
struct dimensions_handler<::cuda::std::integral_constant<Dims, Val>>
{
  static constexpr bool is_type_supported = true;

  // The parameter is kept so the call shape matches the other handlers (an
  // `integral_constant` argument converts to `Dims` implicitly), but it is
  // deliberately unnamed and unused: a function parameter is not a constant
  // expression and therefore cannot appear in a template argument. The extent
  // is taken from the non-type template parameter `Val` instead.
  _CCCL_NODISCARD _CCCL_HOST_DEVICE static constexpr auto translate(const Dims&) noexcept
  {
    return dimensions<dimensions_index_type, size_t(Val), 1, 1>();
  }
};
} // namespace detail

// Associates a set of dimensions with a hierarchy level type.
//
// `Level` must derive from `hierarchy_level` (enforced by the static_assert);
// `Dimensions` is the type of the stored extents. The stored value is const
// and publicly readable through `dims`.
template <typename Level, typename Dimensions>
struct level_dimensions
{
  static_assert(::cuda::std::is_base_of_v<hierarchy_level, Level>);
  using level_type = Level;

  // Needs alignas to work around an issue with tuple
  alignas(16) const Dimensions dims; // Unit for dimensions is implicit

  // Copies the supplied dimensions.
  _CCCL_HOST_DEVICE constexpr level_dimensions(const Dimensions& d)
      : dims(d)
  {}
  // Takes ownership of the supplied dimensions. A named rvalue reference is an
  // lvalue, so it has to be cast back to an rvalue for the move constructor of
  // `Dimensions` to be selected (equivalent to std::move, without needing an
  // extra header).
  _CCCL_HOST_DEVICE constexpr level_dimensions(Dimensions&& d)
      : dims(static_cast<Dimensions&&>(d))
  {}
  // Value-initializes the dimensions.
  _CCCL_HOST_DEVICE constexpr level_dimensions()
      : dims()
  {}

  // Two level_dimensions compare equal when their `dims` compare equal.
  // NOTE(review): the C++20 feature-test macro for defaulted comparisons is
  // spelled `__cpp_impl_three_way_comparison`; confirm that
  // `__cpp_three_way_comparison` is the intended spelling here, otherwise the
  // hand-written fallback below is the branch that is normally compiled.
#  if defined(__cpp_three_way_comparison) && __cpp_three_way_comparison >= 201907
  _CCCL_NODISCARD _CUDAX_API constexpr bool operator==(const level_dimensions&) const noexcept = default;
#  else
  _CCCL_NODISCARD_FRIEND _CUDAX_API constexpr bool
  operator==(const level_dimensions& left, const level_dimensions& right) noexcept
  {
    return left.dims == right.dims;
  }

  _CCCL_NODISCARD_FRIEND _CUDAX_API constexpr bool
  operator!=(const level_dimensions& left, const level_dimensions& right) noexcept
  {
    return left.dims != right.dims;
  }
#  endif
};

// Creates `grid_level` dimensions whose three extents are the template
// arguments X, Y and Z (Y and Z default to 1); the result is
// default-constructed because nothing has to be supplied at run time.
template <size_t X, size_t Y = 1, size_t Z = 1>
_CCCL_HOST_DEVICE constexpr auto grid_dims() noexcept
{
  using static_dims = dimensions<dimensions_index_type, X, Y, Z>;
  return level_dimensions<grid_level, static_dims>();
}

// Creates `grid_level` dimensions from a run-time description `t`.
// `T` must be a type that `detail::dimensions_handler` reports as supported;
// the handler's `translate` decides the resulting dimensions type.
template <typename T>
_CCCL_HOST_DEVICE constexpr auto grid_dims(T t) noexcept
{
  using handler = detail::dimensions_handler<T>;
  static_assert(handler::is_type_supported);
  auto translated = handler::translate(t);
  using result_type = level_dimensions<grid_level, decltype(translated)>;
  return result_type(translated);
}

// Creates `cluster_level` dimensions whose three extents are the template
// arguments X, Y and Z (Y and Z default to 1); the result is
// default-constructed because nothing has to be supplied at run time.
template <size_t X, size_t Y = 1, size_t Z = 1>
_CCCL_HOST_DEVICE constexpr auto cluster_dims() noexcept
{
  using static_dims = dimensions<dimensions_index_type, X, Y, Z>;
  return level_dimensions<cluster_level, static_dims>();
}

// Creates `cluster_level` dimensions from a run-time description `t`.
// `T` must be a type that `detail::dimensions_handler` reports as supported;
// the handler's `translate` decides the resulting dimensions type.
template <typename T>
_CCCL_HOST_DEVICE constexpr auto cluster_dims(T t) noexcept
{
  using handler = detail::dimensions_handler<T>;
  static_assert(handler::is_type_supported);
  auto translated = handler::translate(t);
  using result_type = level_dimensions<cluster_level, decltype(translated)>;
  return result_type(translated);
}

// Creates `block_level` dimensions whose three extents are the template
// arguments X, Y and Z (Y and Z default to 1); the result is
// default-constructed because nothing has to be supplied at run time.
template <size_t X, size_t Y = 1, size_t Z = 1>
_CCCL_HOST_DEVICE constexpr auto block_dims() noexcept
{
  using static_dims = dimensions<dimensions_index_type, X, Y, Z>;
  return level_dimensions<block_level, static_dims>();
}

// Creates `block_level` dimensions from a run-time description `t`.
// `T` must be a type that `detail::dimensions_handler` reports as supported;
// the handler's `translate` decides the resulting dimensions type.
template <typename T>
_CCCL_HOST_DEVICE constexpr auto block_dims(T t) noexcept
{
  using handler = detail::dimensions_handler<T>;
  static_assert(handler::is_type_supported);
  auto translated = handler::translate(t);
  using result_type = level_dimensions<block_level, decltype(translated)>;
  return result_type(translated);
}

} // namespace cuda::experimental
#endif // _CCCL_STD_VER >= 2017
#endif // _CUDAX__HIERARCHY_LEVEL_DIMENSIONS