include/cuda/experimental/__hierarchy/hierarchy_dimensions.cuh

File members: include/cuda/experimental/__hierarchy/hierarchy_dimensions.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__HIERARCHY_HIERARCHY_DIMENSIONS
#define _CUDAX__HIERARCHY_HIERARCHY_DIMENSIONS

#include <cuda/std/__type_traits/fold.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__type_traits/remove_cvref.h>
#include <cuda/std/__type_traits/type_list.h>
#include <cuda/std/__utility/declval.h>
#include <cuda/std/__utility/integer_sequence.h>
#include <cuda/std/span>
#include <cuda/std/tuple>

#include <cuda/experimental/__detail/config.cuh>
#include <cuda/experimental/__hierarchy/level_dimensions.cuh>

#include <nv/target>

#if _CCCL_STD_VER >= 2017
namespace cuda::experimental
{

/* TODO: right now operator stacking can end up with a wrong unit. We could use the type below, but that would
 require an explicit thread_level inserter.
struct unknown_unit : public hierarchy_level
{
  using product_type  = unsigned int;
  using allowed_above = allowed_levels<>;
  using allowed_below = allowed_levels<>;
};
*/

namespace detail
{
// Pass-through overload of __as_level: the argument is handed back unchanged (by value).
template <typename _Level>
_CCCL_NODISCARD _CUDAX_API constexpr _Level __as_level(_Level __level) noexcept
{
  return __level;
}

// Function-pointer overload of __as_level: produces a value-initialized object of the type the pointed-to function
// returns. The function is never called; it only appears inside decltype.
template <typename _LevelFn>
_CCCL_NODISCARD _CUDAX_API constexpr auto __as_level(_LevelFn* __level_fn) noexcept -> decltype(__level_fn())
{
  using __result_t = decltype(__level_fn());
  return __result_t{};
}
} // namespace detail

// Shorthand for the nested `level_type` member type of _LevelDims.
template <typename _LevelDims>
using __level_type_of = typename _LevelDims::level_type;

// Forward declaration; the definition appears further down in this file, after the detail helpers that refer to it.
template <typename BottomUnit, typename... Levels>
struct hierarchy_dimensions_fragment;

// A fragment whose lowest unit is thread_level can be considered a full hierarchy rather than only a fragment,
// so it gets a dedicated name.
template <typename... _Levels>
using hierarchy_dimensions = hierarchy_dimensions_fragment<thread_level, _Levels...>;

namespace detail
{
// Helps the compiler treat fully static dimensions as a constant expression instead of a read from runtime storage.
// This is mostly a workaround for what P2280 (C++23) addressed: when every extent is static, a fresh
// default-constructed object is returned and `ex` is not read at all; otherwise a copy of `ex` is returned.
template <typename T, size_t... Extents>
_CCCL_NODISCARD _CUDAX_API constexpr auto fool_compiler(const dimensions<T, Extents...>& ex)
{
  using __dims_t = dimensions<T, Extents...>;
  if constexpr (__dims_t::rank_dynamic() != 0)
  {
    // At least one extent is only known at run time, so the actual values are needed.
    return ex;
  }
  else
  {
    return __dims_t();
  }
  _CCCL_UNREACHABLE();
}

// Trait answering "does this hierarchy type contain a level whose level_type is the queried one?".
// The primary template is only declared; specializations supply the answer.
template <typename _Query, typename _Hierarchy>
struct has_level_helper;

// Fragment case: true when the queried type equals the level_type of at least one stored level.
// The fragment's unit type is not part of this check.
template <typename _Query, typename _BottomUnit, typename... _Dims>
struct has_level_helper<_Query, hierarchy_dimensions_fragment<_BottomUnit, _Dims...>>
    : ::cuda::std::__fold_or<::cuda::std::is_same_v<_Query, __level_type_of<_Dims>>...>
{};

// Same check spelled for the hierarchy_dimensions alias (a fragment with thread_level as its unit).
// TODO: is this needed? The generic fragment specialization appears to cover this case already.
template <typename _Query, typename... _Dims>
struct has_level_helper<_Query, hierarchy_dimensions<_Dims...>>
    : ::cuda::std::__fold_or<::cuda::std::is_same_v<_Query, __level_type_of<_Dims>>...>
{};

// Trait answering "is the queried type the unit of this hierarchy fragment?".
// The primary template has no `value` member; only the fragment specialization provides one.
template <typename _Query, typename _Hierarchy>
struct has_unit
{};

// Fragment case: compares the queried type with the fragment's unit type.
template <typename _Query, typename _BottomUnit, typename... _Dims>
struct has_unit<_Query, hierarchy_dimensions_fragment<_BottomUnit, _Dims...>>
    : ::cuda::std::is_same<_Query, _BottomUnit>
{};

// Function object that scans a list of level descriptors from left to right and returns a reference to the first
// one whose level_type is QueryLevel. There is no overload for an empty list, so a query for an absent level does
// not compile.
template <typename QueryLevel>
struct get_level_helper
{
  template <typename _Head, typename... _Tail>
  _CCCL_NODISCARD _CUDAX_API constexpr auto& operator()(const _Head& __head, const _Tail&... __tail)
  {
    if constexpr (!::cuda::std::is_same_v<QueryLevel, __level_type_of<_Head>>)
    {
      // Not the requested level: continue with the remaining descriptors.
      return (*this)(__tail...);
    }
    else
    {
      return __head;
    }
    _CCCL_UNREACHABLE();
  }
};
} // namespace detail

// True when _Hierarchy (after stripping cv-qualifiers and references) contains a level whose level_type is _Query.
template <typename _Query, typename _Hierarchy>
_CCCL_INLINE_VAR constexpr bool has_level =
  detail::has_level_helper<_Query, ::cuda::std::remove_cvref_t<_Hierarchy>>::value;

// True when _Query is either one of the levels of _Hierarchy or its unit type
// (cv-qualifiers and references on _Hierarchy are ignored).
template <typename _Query, typename _Hierarchy>
_CCCL_INLINE_VAR constexpr bool has_level_or_unit =
  has_level<_Query, _Hierarchy> || detail::has_unit<_Query, ::cuda::std::remove_cvref_t<_Hierarchy>>::value;

namespace detail
{
// Pairs two equally long type lists element by element and requires can_stack_on_top to hold for every pair.
// The first list comes from the class template arguments, the second from the arguments of `can_stack`.
template <typename... _Upper>
struct can_stack_checker
{
  template <typename... _Lower>
  using can_stack = ::cuda::std::__fold_and<detail::can_stack_on_top<_Lower, _Upper>...>;
};

// True when the given level descriptors can be stacked in the order written with LUnit at the very bottom:
// every level is checked against its successor in the list, and the last level is checked against LUnit.
template <typename LUnit, typename _First, typename... _Rest>
_CCCL_INLINE_VAR constexpr bool __can_stack =
  can_stack_checker<__level_type_of<_First>,
                    __level_type_of<_Rest>...>::template can_stack<__level_type_of<_Rest>..., LUnit>::value;

// Maps every index i of the sequence to (N - 1 - i), where N is the sequence length. For the sequence
// 0, 1, ..., N-1 this yields the reversed sequence N-1, ..., 1, 0.
template <size_t... _Id>
_CUDAX_API constexpr auto __reverse_indices(::cuda::std::index_sequence<_Id...>) noexcept
{
  using __mirrored_t = ::cuda::std::index_sequence<(sizeof...(_Id) - 1 - _Id)...>;
  return __mirrored_t();
}

// Function object that builds a hierarchy_dimensions_fragment from a list of level descriptors.
// If the descriptors cannot be stacked in the order given, one more attempt is made with the order reversed;
// `Reversed` records that this second attempt is in progress, so that a second failure becomes a compile error.
// LUnit == void means "no unit requested": the default unit below the last listed level is used instead.
template <typename LUnit, bool Reversed = false>
struct __make_hierarchy_fragment
{
  // Calls the builder, flagged as Reversed, with the tuple elements selected in the order given by _Ids.
  template <class _LevelsTuple, size_t... _Ids>
  _CCCL_NODISCARD _CUDAX_TRIVIAL_API static constexpr auto
  __apply_reverse(const _LevelsTuple& __levels, ::cuda::std::index_sequence<_Ids...>) noexcept
  {
    return __make_hierarchy_fragment<LUnit, true>()(::cuda::std::get<_Ids>(__levels)...);
  }

  template <typename... _Dims>
  _CCCL_NODISCARD _CUDAX_API constexpr auto operator()(const _Dims&... __dims) const noexcept
  {
    // Level type of the last descriptor in the argument list.
    using __last_level_t = ::cuda::std::__type_index_c<sizeof...(_Dims) - 1, __level_type_of<_Dims>...>;
    using __unit_t =
      ::cuda::std::conditional_t<::cuda::std::is_same_v<void, LUnit>, __default_unit_below<__last_level_t>, LUnit>;

    if constexpr (__can_stack<__unit_t, _Dims...>)
    {
      // The order as given is valid; the unit tag lets the class template arguments be deduced.
      return hierarchy_dimensions_fragment(__unit_t{}, __dims...);
    }
    else if constexpr (!Reversed)
    {
      // First attempt failed: retry once with the descriptors in the opposite order.
      const auto __mirrored_ids = __reverse_indices(::cuda::std::index_sequence_for<_Dims...>());
      return __apply_reverse(::cuda::std::tie(__dims...), __mirrored_ids);
    }
    else
    {
      static_assert(__can_stack<__unit_t, _Dims...>,
                    "Provided levels can't create a valid hierarchy when stacked in the provided order or reversed");
    }
  }
};

// End of the recursion: no levels are left, so the selected range is an empty tuple.
template <typename LUnit>
_CCCL_NODISCARD _CUDAX_API constexpr auto get_levels_range_end() noexcept
{
  return ::cuda::std::tuple<>();
}

// Collects references to the leading levels up to, but not including, the first level whose level_type is LUnit;
// that level and everything after it are discarded. If no level matches, all levels are kept.
// [[maybe_unused]] is needed for MSVC.
template <typename LUnit, typename _HeadDims, typename... _TailDims>
_CCCL_NODISCARD _CUDAX_API constexpr auto
get_levels_range_end(const _HeadDims& __head, [[maybe_unused]] const _TailDims&... __tail) noexcept
{
  if constexpr (!::cuda::std::is_same_v<LUnit, __level_type_of<_HeadDims>>)
  {
    // Keep this level and continue with the ones after it.
    return ::cuda::std::tuple_cat(::cuda::std::tie(__head), get_levels_range_end<LUnit>(__tail...));
  }
  else
  {
    return ::cuda::std::make_tuple();
  }
}

// Skips the levels preceding the one whose level_type is LTop, then delegates to get_levels_range_end to cut the
// range off at LUnit. The LTop level itself is the first candidate for the resulting range.
template <typename LTop, typename LUnit, typename _HeadDims, typename... _TailDims>
_CCCL_NODISCARD _CUDAX_API constexpr auto
get_levels_range_start(const _HeadDims& __head, const _TailDims&... __tail) noexcept
{
  if constexpr (!::cuda::std::is_same_v<LTop, __level_type_of<_HeadDims>>)
  {
    // Still above the requested top level: drop this one and look further.
    return get_levels_range_start<LTop, LUnit>(__tail...);
  }
  else
  {
    return get_levels_range_end<LUnit>(__head, __tail...);
  }
}

// Selects the part of a hierarchy that starts at level LTop and stops right before LUnit.
// The result is a tuple of references to the selected level descriptors.
template <typename LTop, typename LUnit, typename... _AllDims>
_CCCL_NODISCARD _CUDAX_API constexpr auto get_levels_range(const _AllDims&... __all) noexcept
{
  return get_levels_range_start<LTop, LUnit>(__all...);
}

// Multiplies together the extents of `__dims` at the positions listed in Ids.
template <typename T, size_t... Extents, size_t... Ids>
_CCCL_NODISCARD _CUDAX_API constexpr auto
dims_to_count_helper(const dimensions<T, Extents...>& __dims, ::cuda::std::index_sequence<Ids...>)
{
  return (__dims.extent(Ids) * ...);
}

// Total number of elements described by `__dims`: the product of all of its extents.
template <typename T, size_t... Extents>
_CCCL_NODISCARD _CUDAX_API constexpr auto dims_to_count(const dimensions<T, Extents...>& __dims) noexcept
{
  using __all_positions_t = ::cuda::std::make_index_sequence<sizeof...(Extents)>;
  return dims_to_count_helper(__dims, __all_positions_t{});
}

// Returns a tuple holding, for every level descriptor passed in, the element count of its `dims` member.
template <typename... _Dims>
_CCCL_NODISCARD _CUDAX_API constexpr auto get_level_counts_helper(const _Dims&... __levels)
{
  return ::cuda::std::make_tuple(dims_to_count(__levels.dims)...);
}

// Returns the dimensions to use for `Level` counted in `Unit`.
// When both are core CUDA hierarchy levels and `dims` has at least one dynamic extent, device code ignores the
// stored values and rebuilds the dimensions from dims_helper<Unit, Level>::extents(); host code uses `dims`.
// In every other case `dims` is used as is. All results are passed through fool_compiler.
template <typename Unit, typename Level, typename Dims>
_CCCL_NODISCARD _CUDAX_API constexpr auto replace_with_intrinsics_or_constexpr(const Dims& dims)
{
  if constexpr (is_core_cuda_hierarchy_level<Level> && is_core_cuda_hierarchy_level<Unit> && Dims::rank_dynamic() != 0)
  {
    // We replace hierarchy access with CUDA intrinsic to enable compiler optimizations, its ok for the prototype,
    // but might lead to unexpected results and should be eventually addressed at the API level
    // TODO with device side launch we should have a way to disable it for the device-side created hierarchy
    NV_IF_ELSE_TARGET(NV_IS_DEVICE,
                      (dim3 intr_dims = dims_helper<Unit, Level>::extents();
                       return fool_compiler(Dims(intr_dims.x, intr_dims.y, intr_dims.z));),
                      (return fool_compiler(dims);));
  }
  else
  {
    // Fully static dimensions or non-core levels: nothing to substitute.
    return fool_compiler(dims);
  }
}

// Function object that combines the dimensions of a list of levels (top level first) into one set of extents.
// Each level's dimensions are taken relative to the level that follows it in the list, the last level's relative
// to BottomUnit, and the per-level results are combined with dims_product.
template <typename BottomUnit>
struct hierarchy_extents_helper
{
  template <typename _TopDims, typename... _Rest>
  _CCCL_NODISCARD _CUDAX_API constexpr auto operator()(const _TopDims& __top, const _Rest&... __rest) noexcept
  {
    using __top_level_t = __level_type_of<_TopDims>;
    if constexpr (sizeof...(_Rest) != 0)
    {
      // The unit of the top level is the level directly after it in the list.
      using __next_level_t = ::cuda::std::__type_index_c<0, __level_type_of<_Rest>...>;
      using __product_t    = typename __top_level_t::product_type;
      return dims_product<__product_t>(
        replace_with_intrinsics_or_constexpr<__next_level_t, __top_level_t>(__top.dims), (*this)(__rest...));
    }
    else
    {
      // Last level of the list: it is expressed directly in BottomUnit.
      return replace_with_intrinsics_or_constexpr<BottomUnit, __top_level_t>(__top.dims);
    }
  }
};

// Wraps a dim3 index in a dimensions type that carries static information derived from the level's extents:
// where the static extent is 1 the index component gets the static value 0, every other component is dynamic.
// Only the type of the first argument is used; its value is not read.
template <typename T, size_t... Extents>
_CCCL_NODISCARD _CCCL_DEVICE constexpr auto static_index_hint(const dimensions<T, Extents...>& dims, dim3 index)
{
  using __index_dims_t = dimensions<T, (Extents == 1 ? 0 : ::cuda::std::dynamic_extent)...>;
  return __index_dims_t(index.x, index.y, index.z);
}

// Device-side function object computing a multi-dimensional index across a list of levels (top level first).
// For every level but the last, the index within that level is scaled by the combined extents of all levels below
// it and added to the index computed for those lower levels. The last level contributes its own index in
// BottomUnit.
template <typename BottomUnit>
struct index_helper
{
  template <typename _TopDims, typename... _Rest>
  _CCCL_NODISCARD _CCCL_DEVICE constexpr auto operator()(const _TopDims& __top, const _Rest&... __rest) noexcept
  {
    using __top_level_t = __level_type_of<_TopDims>;
    if constexpr (sizeof...(_Rest) != 0)
    {
      using __next_level_t = ::cuda::std::__type_index_c<0, __level_type_of<_Rest>...>;
      using __product_t    = typename __top_level_t::product_type;
      // Index of the next-lower level inside the top level.
      auto __top_index = static_index_hint(__top.dims, dims_helper<__next_level_t, __top_level_t>::index());
      return dims_sum<__product_t>(
        dims_product<__product_t>(__top_index, hierarchy_extents_helper<BottomUnit>()(__rest...)),
        index_helper<BottomUnit>()(__rest...));
    }
    else
    {
      return static_index_hint(__top.dims, dims_helper<BottomUnit, __top_level_t>::index());
    }
  }
};

// Device-side function object computing a linear rank across a list of levels (top level first).
// For every level but the last, the linearized index within that level is multiplied by the element count of all
// levels below it and added to the rank computed for those lower levels. The last level contributes its own
// linearized index in BottomUnit.
template <typename BottomUnit>
struct rank_helper
{
  template <typename _TopDims, typename... _Rest>
  _CCCL_NODISCARD _CCCL_DEVICE constexpr auto operator()(const _TopDims& __top, const _Rest&... __rest) noexcept
  {
    using __top_level_t = __level_type_of<_TopDims>;
    using __product_t   = typename __top_level_t::product_type;
    if constexpr (sizeof...(_Rest) != 0)
    {
      using __next_level_t = ::cuda::std::__type_index_c<0, __level_type_of<_Rest>...>;
      auto __top_index     = static_index_hint(__top.dims, dims_helper<__next_level_t, __top_level_t>::index());
      auto __top_rank      = detail::index_to_linear<__product_t>(__top_index, __top.dims);
      // Units below one element of the top level, times the top-level rank, plus the rank inside that element.
      return __top_rank * dims_to_count(hierarchy_extents_helper<BottomUnit>()(__rest...))
           + rank_helper<BottomUnit>()(__rest...);
    }
    else
    {
      auto __top_index = static_index_hint(__top.dims, dims_helper<BottomUnit, __top_level_t>::index());
      return detail::index_to_linear<__product_t>(__top_index, __top.dims);
    }
  }
};
} // namespace detail

// Describes a (possibly partial) hierarchy: a tuple of level descriptors plus a unit type, BottomUnit.
// The queries below use the level type of the first element of Levels as their default `Level` and BottomUnit as
// their default `Unit`.
template <typename BottomUnit, typename... Levels>
struct hierarchy_dimensions_fragment
{
  // The unit has to be hierarchy_level, a type derived from it, or void.
  static_assert(::cuda::std::is_base_of_v<hierarchy_level, BottomUnit> || ::cuda::std::is_same_v<BottomUnit, void>);
  // The level descriptors, stored in the order of the Levels pack.
  ::cuda::std::tuple<Levels...> levels;

  // Builds the fragment from individual level descriptors.
  _CUDAX_API constexpr hierarchy_dimensions_fragment(const Levels&... ls) noexcept
      : levels(ls...)
  {}
  // Same as above with a leading unit tag. The tag value is ignored; only its type matters
  // (detail::__make_hierarchy_fragment relies on it for class template argument deduction).
  _CUDAX_API constexpr hierarchy_dimensions_fragment(const BottomUnit&, const Levels&... ls) noexcept
      : levels(ls...)
  {}

  // Builds the fragment from an already assembled tuple of level descriptors.
  _CUDAX_API constexpr hierarchy_dimensions_fragment(const ::cuda::std::tuple<Levels...>& ls) noexcept
      : levels(ls)
  {}

  // Tuple-based construction with a leading unit tag; the tag value is ignored.
  _CUDAX_API constexpr hierarchy_dimensions_fragment(const BottomUnit&, const ::cuda::std::tuple<Levels...>& ls) noexcept
      : levels(ls)
  {}

  // Equality: with three-way comparison support the comparison is defaulted, otherwise operator== and operator!=
  // are written out by hand and compare the `levels` tuples.
#  if defined(__cpp_three_way_comparison) && __cpp_three_way_comparison >= 201907
  _CCCL_NODISCARD _CUDAX_API constexpr bool operator==(const hierarchy_dimensions_fragment&) const noexcept = default;
#  else
  _CCCL_NODISCARD_FRIEND _CUDAX_API constexpr bool
  operator==(const hierarchy_dimensions_fragment& left, const hierarchy_dimensions_fragment& right) noexcept
  {
    return left.levels == right.levels;
  }

  _CCCL_NODISCARD_FRIEND _CUDAX_API constexpr bool
  operator!=(const hierarchy_dimensions_fragment& left, const hierarchy_dimensions_fragment& right) noexcept
  {
    return left.levels != right.levels;
  }
#  endif

private:
  // Selects the levels from `Level` down to, but not including, `Unit` and returns them as a tuple of references
  // into `levels`. Compilation fails if Level is not part of this fragment, if Unit is neither one of its levels
  // nor its unit, or if Unit is not a legal unit for Level.
  // This being static is a bit of a hack to make extents_type working without incomplete class member access
  template <typename Unit, typename Level>
  _CCCL_NODISCARD _CUDAX_API static constexpr auto
  levels_range_static(const ::cuda::std::tuple<Levels...>& levels) noexcept
  {
    static_assert(has_level<Level, hierarchy_dimensions_fragment<BottomUnit, Levels...>>);
    static_assert(has_level_or_unit<Unit, hierarchy_dimensions_fragment<BottomUnit, Levels...>>);
    static_assert(detail::legal_unit_for_level<Unit, Level>);
    return ::cuda::std::apply(detail::get_levels_range<Level, Unit, Levels...>, levels);
  }

  // Member-function convenience wrapper around levels_range_static operating on this object's levels.
  // TODO is this useful enough to expose?
  template <typename Unit, typename Level>
  _CCCL_NODISCARD _CUDAX_API constexpr auto levels_range() const noexcept
  {
    return levels_range_static<Unit, Level>(levels);
  }

  // Function object that builds a new fragment with unit `Unit` from the level descriptors passed to it.
  template <typename Unit>
  struct fragment_helper
  {
    template <typename... Selected>
    _CCCL_NODISCARD _CUDAX_API constexpr auto operator()(const Selected&... levels) const noexcept
    {
      return hierarchy_dimensions_fragment<Unit, Selected...>(levels...);
    }
  };

public:
  // Type produced by applying detail::hierarchy_extents_helper<Unit> to the levels selected for <Unit, Level>.
  // static_count inspects this type to decide whether the count is known at compile time.
  template <typename Unit, typename Level>
  using extents_type = decltype(::cuda::std::apply(
    ::cuda::std::declval<detail::hierarchy_extents_helper<Unit>>(),
    levels_range_static<Unit, Level>(::cuda::std::declval<::cuda::std::tuple<Levels...>>())));

  // Returns a new fragment made of the levels from `Level` down to, but not including, `Unit`, with `Unit` as its
  // unit. The function arguments are tags whose values are ignored; they allow calls such as fragment(unit, level).
  template <typename Unit, typename Level>
  _CUDAX_API constexpr auto fragment(const Unit& = Unit(), const Level& = Level()) const noexcept
  {
    auto selected = levels_range<Unit, Level>();
    // TODO fragment can't do constexpr queries because we use references here, can we create copies of the levels in
    // some cases and move to the constructor?
    return ::cuda::std::apply(fragment_helper<Unit>(), selected);
  }

  // Returns the extents of `Level` counted in `Unit`: the dimensions of the selected levels are combined by
  // detail::hierarchy_extents_helper and the result is passed through detail::convert_to_query_result.
  // The function arguments are tags whose values are ignored.
  template <typename Unit = BottomUnit, typename Level = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>>
  _CUDAX_API constexpr auto extents(const Unit& = Unit(), const Level& = Level()) const noexcept
  {
    auto selected = levels_range<Unit, Level>();
    return detail::convert_to_query_result(::cuda::std::apply(detail::hierarchy_extents_helper<Unit>{}, selected));
  }

  // template <typename Unit, typename Level>
  // using extents_type = ::cuda::std::invoke_result_t<
  //   decltype(&hierarchy_dimensions_fragment<BottomUnit, Levels...>::template extents<Unit, Level>),
  //   hierarchy_dimensions_fragment<BottomUnit, Levels...>,
  //   Unit(),
  //   Level()>;

  // Returns the number of `Unit`s in `Level`: the product of all extents reported by extents<Unit, Level>().
  template <typename Unit = BottomUnit, typename Level = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>>
  _CUDAX_API constexpr auto count(const Unit& = Unit(), const Level& = Level()) const noexcept
  {
    return detail::dims_to_count(extents<Unit, Level>());
  }

  // TODO static extents?

  // Compile-time variant of count(): when extents_type<Unit, Level> has no dynamic extents, the count is computed
  // from a default-constructed extents_type; otherwise ::cuda::std::dynamic_extent is returned.
  template <typename Unit = BottomUnit, typename Level = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>>
  _CUDAX_API constexpr static auto static_count(const Unit& = Unit(), const Level& = Level()) noexcept
  {
    if constexpr (extents_type<Unit, Level>::rank_dynamic() == 0)
    {
      return detail::dims_to_count(extents_type<Unit, Level>());
    }
    else
    {
      return ::cuda::std::dynamic_extent;
    }
  }

  // Device-only query: applies detail::index_helper<Unit> to the selected levels and passes the result through
  // detail::convert_to_query_result. The function arguments are tags whose values are ignored.
  template <typename Unit = BottomUnit, typename Level = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>>
  _CCCL_DEVICE constexpr auto index(const Unit& = Unit(), const Level& = Level()) const noexcept
  {
    auto selected = levels_range<Unit, Level>();
    return detail::convert_to_query_result(::cuda::std::apply(detail::index_helper<Unit>{}, selected));
  }

  // Device-only query: applies detail::rank_helper<Unit> to the selected levels and returns its result.
  // The function arguments are tags whose values are ignored.
  template <typename Unit = BottomUnit, typename Level = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>>
  _CCCL_DEVICE constexpr auto rank(const Unit& = Unit(), const Level& = Level()) const noexcept
  {
    auto selected = levels_range<Unit, Level>();
    return ::cuda::std::apply(detail::rank_helper<Unit>{}, selected);
  }

  // Returns a copy of the stored level descriptor whose level_type is `Level`.
  // Compilation fails if the fragment has no such level.
  template <typename Level>
  _CUDAX_API constexpr auto level(const Level&) const noexcept
  {
    static_assert(has_level<Level, hierarchy_dimensions_fragment<BottomUnit, Levels...>>);

    return ::cuda::std::apply(detail::get_level_helper<Level>{}, levels);
  }
};

// Host-only helper returning the extents needed for a launch as a tuple:
// (blocks in grid, threads in block) or, when the hierarchy has a cluster level,
// (blocks in grid, blocks in cluster, threads in block).
template <typename... Levels>
constexpr auto _CCCL_HOST get_launch_dimensions(const hierarchy_dimensions<Levels...>& hierarchy)
{
  if constexpr (!has_level<cluster_level, hierarchy_dimensions<Levels...>>)
  {
    return ::cuda::std::make_tuple(hierarchy.extents(block, grid), hierarchy.extents(thread, block));
  }
  else
  {
    return ::cuda::std::make_tuple(
      hierarchy.extents(block, grid), hierarchy.extents(block, cluster), hierarchy.extents(thread, block));
  }
}

/* Builds a hierarchy fragment from the given levels. Every argument is first normalized with detail::__as_level.
 LUnit selects the unit of the fragment; when it is left as void the default unit below the last listed level is
 used.
 TODO: consider taking LUnit as an optional function argument so that it can take part in template argument
 deduction. This could have been a single function, make_hierarchy, with a defaulted first template argument, but
 that would rule out the change proposed in this TODO, and the current name makes more sense. */
template <typename LUnit = void, typename L1, typename... Levels>
constexpr auto make_hierarchy_fragment(L1 l1, Levels... ls) noexcept
{
  const auto __builder = detail::__make_hierarchy_fragment<LUnit>();
  return __builder(detail::__as_level(l1), detail::__as_level(ls)...);
}

// Builds a hierarchy whose unit is thread_level from the given levels.
// Every argument is first normalized with detail::__as_level.
template <typename L1, typename... Levels>
constexpr auto make_hierarchy(L1 l1, Levels... ls) noexcept
{
  const auto __builder = detail::__make_hierarchy_fragment<thread_level>();
  return __builder(detail::__as_level(l1), detail::__as_level(ls)...);
}

// Adds one level to an existing fragment. The new level is placed in front of the current first level when
// can_stack_on_top allows it. Otherwise it is appended after the current last level, and the unit of the result
// becomes the default unit below the new level. Any other combination is rejected at compile time.
// We can consider removing the operator&, but it is convenient for in-line construction.
// TODO accept forwarding references
template <typename LUnit, typename LNew, typename... Levels>
_CUDAX_API constexpr auto operator&(const hierarchy_dimensions_fragment<LUnit, Levels...>& ls, LNew lnew) noexcept
{
  auto __added          = detail::__as_level(lnew);
  using __added_dims_t  = decltype(__added);
  using __added_level_t = __level_type_of<__added_dims_t>;
  using __first_level_t = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>;

  if constexpr (detail::can_stack_on_top<__first_level_t, __added_level_t>)
  {
    // The new level goes in front; the unit stays the same.
    return hierarchy_dimensions_fragment<LUnit, __added_dims_t, Levels...>(
      ::cuda::std::tuple_cat(::cuda::std::make_tuple(__added), ls.levels));
  }
  else
  {
    using __last_level_t = __level_type_of<::cuda::std::__type_index_c<sizeof...(Levels) - 1, Levels...>>;
    static_assert(detail::can_stack_on_top<__added_level_t, __last_level_t>,
                  "Not supported order of levels in hierarchy");
    // The new level goes to the back, so the unit has to be re-derived from it.
    using __new_unit_t = detail::__default_unit_below<__added_level_t>;
    return hierarchy_dimensions_fragment<__new_unit_t, Levels..., __added_dims_t>(
      ::cuda::std::tuple_cat(ls.levels, ::cuda::std::make_tuple(__added)));
  }
}

// Mirror image of `fragment & level`: lets the level be written on the left-hand side.
// The work is done by the overload that takes the fragment first.
template <typename L1, typename LUnit, typename... Levels>
_CUDAX_API constexpr auto operator&(L1 __level, const hierarchy_dimensions_fragment<LUnit, Levels...>& __fragment) noexcept
{
  return __fragment & __level;
}

// Combines two level descriptors: the left one is wrapped in a single-level hierarchy with thread_level as its
// unit, and the right one is then added to it with the fragment & level operator.
template <typename L1, typename Dims1, typename L2, typename Dims2>
_CUDAX_API constexpr auto
operator&(const level_dimensions<L1, Dims1>& __lhs, const level_dimensions<L2, Dims2>& __rhs) noexcept
{
  using __single_level_t = hierarchy_dimensions<level_dimensions<L1, Dims1>>;
  return __single_level_t(__lhs) & __rhs;
}

template <typename NewLevel, typename Unit, typename... Levels>
constexpr auto hierarchy_add_level(const hierarchy_dimensions_fragment<Unit, Levels...>& hierarchy, NewLevel level)
{
  return hierarchy & level;
}

// Builds a grid/block hierarchy with _ThreadsPerBlock threads per block (a static extent) and just enough blocks
// to cover numElements, i.e. numElements divided by _ThreadsPerBlock, rounded up.
// The rounding is done in a wider type: `numElements + _ThreadsPerBlock - 1` evaluated in int overflows (undefined
// behaviour) once numElements gets close to INT_MAX. For all inputs where the int computation did not overflow
// the result is unchanged.
template <int _ThreadsPerBlock>
constexpr auto distribute(int numElements) noexcept
{
  static_assert(_ThreadsPerBlock > 0, "distribute requires a positive number of threads per block");
  // With a positive divisor the quotient always fits back into int.
  const long long __rounded_up = static_cast<long long>(numElements) + _ThreadsPerBlock - 1;
  int blocksPerGrid            = static_cast<int>(__rounded_up / _ThreadsPerBlock);
  return ::cuda::experimental::make_hierarchy(
    ::cuda::experimental::grid_dims(blocksPerGrid), ::cuda::experimental::block_dims<_ThreadsPerBlock>());
}

} // namespace cuda::experimental
#endif // _CCCL_STD_VER >= 2017
#endif // _CUDAX__HIERARCHY_HIERARCHY_DIMENSIONS