include/cuda/experimental/__stf/utility/traits.cuh

File members: include/cuda/experimental/__stf/utility/traits.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/mdspan>

#include <cuda/experimental/__stf/utility/core.cuh>

#include <array>
#include <cassert>
#include <cstddef>
#include <new>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>

namespace cuda::experimental::stf
{

namespace reserved
{

// We use this function as a detector for what __PRETTY_FUNCTION__ looks like.
// `type_name_affixes` instantiates it with `double` to measure how many characters of the signature string come
// before and after the spelled type. `type_name_impl` then trims its *own* signature with those lengths, so the two
// functions must keep signatures of identical shape (same return type and qualifiers, names of equal length):
// do not rename or re-declare one without the other.
template <typename T>
constexpr ::std::string_view type_name_IMPL()
{
#if _CCCL_COMPILER(MSVC)
  return __FUNCSIG__;
#else // ^^^ _CCCL_COMPILER(MSVC) ^^^ / vvv !_CCCL_COMPILER(MSVC) vvv
  return __PRETTY_FUNCTION__;
#endif // !_CCCL_COMPILER(MSVC)
}

// Length of prefix and suffix in __PRETTY_FUNCTION__ when used with `type_name`.
inline constexpr ::std::pair<size_t, size_t> type_name_affixes = [] {
  const auto p      = type_name_IMPL<double>();
  const auto target = ::std::string_view("double");
  const auto len    = target.size();
  // Simulate p.find() by hand because clang can't do it.
  size_t i = target.npos;
  for (std::size_t start = 0; start <= p.size() - len; ++start)
  {
    if (p.substr(start, len) == target)
    {
      i = start; // Found the substring, set i to the starting position
      break; // Exit loop after finding the first match
    }
  }
  auto j = p.size() - i - len;
  return ::std::pair{i, j};
}();

// Returns the spelling of `T` by cutting it out of this function's own signature string, using the prefix/suffix
// lengths measured on `type_name_IMPL<double>` (see `type_name_affixes`). That only works while this function's
// signature keeps exactly the same shape as `type_name_IMPL`'s, so the declaration line must not change.
template <class T>
constexpr ::std::string_view type_name_impl()
{
#if _CCCL_COMPILER(MSVC)
  // No trimming on MSVC: the complete __FUNCSIG__ string is handed back unchanged.
  constexpr ::std::string_view signature = __FUNCSIG__;
  return signature;
#else // ^^^ _CCCL_COMPILER(MSVC) ^^^ / vvv !_CCCL_COMPILER(MSVC) vvv
  const ::std::string_view signature = __PRETTY_FUNCTION__;
  const size_t head                  = type_name_affixes.first;
  const size_t tail                  = type_name_affixes.second;
  const size_t kept                  = signature.size() - head - tail;
  return signature.substr(head, kept);
#endif // !_CCCL_COMPILER(MSVC)
}

} // namespace reserved

template <class T>
inline constexpr ::std::string_view type_name = reserved::type_name_impl<T>();

/**
 * @brief Applies `f` to every element of the tuple-like object `t` and returns a `::std::tuple` of the results.
 *
 * The element types of the result are `::std::decay_t` of each invocation of `f`. This is what class template
 * argument deduction would pick, but the type is spelled out explicitly for two reasons:
 *  - the braced initializer guarantees that `f` is invoked on the elements from left to right, whereas the
 *    evaluation order of function-call arguments is unspecified;
 *  - deducing from a single argument that is itself a `::std::tuple` would copy that tuple instead of wrapping it
 *    (giving `::std::tuple<A, B>` instead of `::std::tuple<::std::tuple<A, B>>`).
 *
 * @param t Tuple-like object accepted by `::std::apply` (e.g. `::std::tuple`, `::std::pair`, `::std::array`)
 * @param f Callable invoked once per element; elements are passed as obtained from a `const` view of `t`
 * @return `::std::tuple` of the decayed results of `f`, in element order
 */
template <typename Tuple, typename Fun>
constexpr auto tuple2tuple(const Tuple& t, Fun&& f)
{
  return ::std::apply(
    [&](auto&&... x) {
      return ::std::tuple<::std::decay_t<decltype(f(::std::forward<decltype(x)>(x)))>...>{
        f(::std::forward<decltype(x)>(x))...};
    },
    t);
}

/**
 * @brief A class template that fails to compile as soon as it is instantiated, producing an error message that
 * involves the type T. Used internally for debugging.
 *
 * Because the static_assert depends on T, it is only evaluated when `print_type_name_and_fail<X>` is instantiated
 * (for example by creating an object of that type). Once that happens it breaks compilation unconditionally, even
 * if the instantiating code sits on a path that is supposed to be unreachable at run time!
 *
 * @tparam T A type which we want to display.
 */
template <typename T>
class print_type_name_and_fail
{
  // The condition is a null `T*` (i.e. false) whose type mentions T, so compilers typically print T when
  // reporting the failed assertion.
  static_assert(::std::integral_constant<T*, nullptr>::value, "Type name is: ");
};

namespace reserved
{

/**
 * @brief CRTP base providing a single, lazily constructed instance of `T` (Meyers singleton).
 *
 * Usage: derive `T` from `meyers_singleton<T>` and declare `T`'s default constructor and destructor `protected`.
 * `instance()` verifies at compile time that `T` cannot be default-constructed, destroyed, copied or moved by
 * outside code. It then returns a reference to a function-local static object.
 */
template <class T>
class meyers_singleton
{
protected:
  // Spelling the friend through a dependent member type lets us befriend the template parameter `T`.
  template <class U>
  struct wrapper
  {
    using type = U;
  };
  friend typename wrapper<T>::type;

  meyers_singleton()  = default;
  ~meyers_singleton() = default;

  // Neither copyable nor movable.
  meyers_singleton(meyers_singleton&&)      = delete;
  meyers_singleton(const meyers_singleton&) = delete;

public:
  static T& instance()
  {
    // The traits below are evaluated from outside T, so they only hold if T hides these special members.
    static_assert(!::std::is_default_constructible_v<T>,
                  "Make the default constructor of your Meyers singleton protected.");
    static_assert(!::std::is_destructible_v<T>, "Make the destructor of your Meyers singleton protected.");
    static_assert(!::std::is_copy_constructible_v<T>, "Disable the copy constructor of your Meyers singleton.");
    static_assert(!::std::is_move_constructible_v<T>, "Disable the move constructor of your Meyers singleton.");

    // A local subclass can reach T's protected constructor and destructor.
    struct accessible_t : T
    {};
    static accessible_t the_one;
    return the_one;
  }
};

} // end namespace reserved

/**
 * @brief Converts a tuple-like object (for example a `::std::array`) into a `::std::tuple`.
 *
 * Each element of the result is a decayed copy of the corresponding element of `array`.
 */
template <typename Array>
auto to_tuple(Array&& array)
{
  auto copy_element = [](auto&& element) {
    return ::std::forward<decltype(element)>(element);
  };
  return tuple2tuple(::std::forward<Array>(array), copy_element);
}

/**
 * @brief The `::std::tuple` type that `to_tuple` produces for a `::std::array<T, n>`, i.e. `n` copies of `T`.
 *
 * The array is named through `::std::declval` rather than a value-initialized `::std::array<T, n>{}`. This keeps
 * the alias usable for element types that are not default constructible.
 */
template <typename T, size_t n>
using array_tuple = decltype(to_tuple(::std::declval<::std::array<T, n>>()));

// Mini-unittest
static_assert(::std::is_same_v<array_tuple<size_t, 3>, ::std::tuple<size_t, size_t, size_t>>);

namespace reserved
{

/**
 * @brief Copies the elements of a non-empty `::std::tuple` into a `::cuda::std::array` whose element type is the
 * tuple's first element type.
 *
 * The array is default-initialized, then each tuple element is assigned into its slot, so every element type must be
 * assignable to `T0`.
 */
template <typename T0, typename... Ts>
::cuda::std::array<T0, 1 + sizeof...(Ts)> to_cuda_array(const ::std::tuple<T0, Ts...>& obj)
{
  constexpr size_t count = 1 + sizeof...(Ts);
  ::cuda::std::array<T0, count> out;
  ::std::apply(
    [&out](const auto&... elems) {
      size_t pos = 0;
      // Comma fold over void expressions: slots are filled in tuple order.
      ((void) (out[pos++] = elems), ...);
    },
    obj);
  return out;
}

/**
 * @brief Copies a `::std::array` element by element into a `::cuda::std::array` of the same element type and size.
 */
template <typename T, size_t N>
::cuda::std::array<T, N> convert_to_cuda_array(const ::std::array<T, N>& std_array)
{
  ::cuda::std::array<T, N> converted;
  size_t slot = 0;
  for (const T& element : std_array)
  {
    converted[slot++] = element;
  }
  return converted;
}

} // end namespace reserved

/**
 * @brief Returns the single argument among `p0, p...` whose type is convertible to `T`, converted to `T`.
 *
 * Arguments are scanned left to right, and the first one convertible to `T` is forwarded and converted on return.
 * The following are compile-time errors:
 *  - another argument after it is also convertible to `T` ("Duplicate argument type found");
 *  - no argument is convertible to `T`.
 *
 * @tparam T Requested result type
 * @param p0 First candidate argument
 * @param p Remaining candidate arguments
 * @return T The unique argument convertible to `T`
 */
template <typename T, typename P0, typename... P>
T only_convertible(P0&& p0, P&&... p)
{
  if constexpr (::std::is_convertible_v<P0, T>)
  {
    // The remaining arguments are only inspected for ambiguity.
    ((void) p, ...);
    static_assert(!(::std::is_convertible_v<P, T> || ...), "Duplicate argument type found");
    return ::std::forward<P0>(p0);
  }
  else
  {
    // `p0` is not a match; reference it to avoid spurious unused-parameter warnings.
    (void) p0;
    // Without this, exhausting the arguments surfaces as an opaque "no matching function" error.
    static_assert(sizeof...(P) > 0, "No argument is convertible to the requested type");
    // Ignore current head and recurse to tail
    return only_convertible<T>(::std::forward<P>(p)...);
  }
}

namespace reserved
{
/**
 * @brief Move-constructs a `::std::array<T, sizeof...(I)>` from the objects stored contiguously at `slots`.
 *
 * The source objects are left in their moved-from state; destroying them is the caller's responsibility.
 */
template <typename T, size_t... I>
::std::array<T, sizeof...(I)> all_convertible_move_out(T* slots, ::std::index_sequence<I...>)
{
  return ::std::array<T, sizeof...(I)>{{::std::move(slots[I])...}};
}
} // namespace reserved

/**
 * @brief Collects, in argument order, every argument convertible to `T` into a `::std::array<T, N>`.
 *
 * `N` is the number of arguments whose type is convertible to `T` (possibly 0). Each selected argument is forwarded
 * to `T`'s constructor, so rvalue arguments are moved from; the other arguments are ignored.
 *
 * The elements are first constructed in raw, suitably aligned storage, so `T` does not need to be default
 * constructible. They are then moved into the returned array, and every object built in the raw storage is
 * destroyed on all exit paths, including when a constructor or a move throws.
 *
 * @tparam T Element type of the result
 * @param p Candidate arguments
 * @return ::std::array<T, N> holding the converted arguments
 */
template <typename T, typename... P>
auto all_convertible(P&&... p)
{
  constexpr size_t size = (size_t{0} + ... + static_cast<size_t>(::std::is_convertible_v<P, T>));
  using result_t        = ::std::array<T, size>;

  if constexpr (size == 0)
  {
    // Nothing to collect (a zero-length raw buffer would be ill-formed).
    ((void) p, ...);
    return result_t{};
  }
  else
  {
    alignas(T) unsigned char storage[size * sizeof(T)];
    T* const slots     = reinterpret_cast<T*>(storage);
    size_t constructed = 0; // number of leading slots currently holding a live T

    // Destroys the live slots when leaving this scope: after their contents were moved into the result, or when
    // an exception propagates out of a constructor or of the final move.
    struct slot_guard
    {
      T* first;
      const size_t& count;
      ~slot_guard()
      {
        for (size_t j = 0; j < count; ++j)
        {
          first[j].~T();
        }
      }
    } guard{slots, constructed};

    auto emplace = [&](auto&& e) {
      if constexpr (::std::is_convertible_v<decltype(e), T>)
      {
        ::new (static_cast<void*>(slots + constructed)) T(::std::forward<decltype(e)>(e));
        ++constructed;
      }
    };
    // Comma fold: arguments are visited strictly from left to right.
    (emplace(::std::forward<P>(p)), ...);

    return reserved::all_convertible_move_out(slots, ::std::make_index_sequence<size>{});
  }
}

/*
 * @brief Chooses a parameter from `P...` of a type convertible to `T`. If found, it is returned. If no such parameter
 * is found, returns `default_v`.
 *
 * For now only value semantics are supported.
 *
 * @tparam T Result type
 * @tparam P Variadic parameter types
 * @param default_v Default value
 * @param p Variadic parameter values
 * @return T Either the first convertible parameter, or `default_v` if no such parameter is found
 */
template <typename T, typename... P>
T only_convertible_or([[maybe_unused]] T default_v, P&&... p)
{
  if constexpr (!(::std::is_convertible_v<P, T> || ...))
  {
    ((void) p, ...);
    return default_v;
  }
  else
  {
    return only_convertible<T>(::std::forward<P>(p)...);
  }
}

namespace reserved
{
/**
 * @brief Compile-time check that objects of types `ArgTypes...` can unambiguously initialize (in any order) some
 * of the members whose types are `DataTypes...`.
 *
 * Every argument type must be implicitly convertible to exactly one of `DataTypes`. Not every member needs a
 * matching argument; e.g. `check_initialization<int, int*>::from<int>()` compiles.
 */
template <typename... DataTypes>
struct check_initialization
{
  /* Number of types in `DataTypes` to which `T` is implicitly convertible. */
  template <typename T>
  static constexpr int count_convertibilty = (0 + ... + ::std::is_convertible_v<T, DataTypes>);

  template <typename... ArgTypes>
  static constexpr void from()
  {
    // One compile-time check per argument type.
    (check_argument<ArgTypes>(), ...);
  }

private:
  template <typename Arg>
  static constexpr void check_argument()
  {
    static_assert(count_convertibilty<Arg> > 0,
                  "Incompatible argument: argument type doesn't match any member type.");
    static_assert(count_convertibilty<Arg> == 1,
                  "Ambiguous argument: argument type converts to more than one member type.");
  }
};
} // namespace reserved

template <typename... ArgTypes, typename... DataTypes>
void shuffled_args_check(const DataTypes&...)
{
  reserved::check_initialization<DataTypes...>::template from<ArgTypes...>();
}

template <typename... DataTypes, typename... ArgTypes>
::std::tuple<DataTypes...> shuffled_tuple(ArgTypes... args)
{
  reserved::check_initialization<DataTypes...>::template from<ArgTypes...>();
  return ::std::tuple<DataTypes...>{only_convertible_or(DataTypes(), mv(args)...)...};
}

template <typename... DataTypes, typename... ArgTypes>
auto shuffled_array_tuple(ArgTypes... args)
{
  reserved::check_initialization<DataTypes...>::template from<ArgTypes...>();
  return ::std::tuple{all_convertible<DataTypes>(mv(args)...)...};
}

namespace reserved
{

/**
 * @brief `is_tuple_invocable<Callable, Packed>` tells whether `Callable` can be invoked with the element types of
 * `Packed`, which must be a `::std::tuple`. Any other `Packed` type yields `false`.
 */
template <typename Callable, typename Packed>
struct is_tuple_invocable : ::std::false_type
{};

// `::std::tuple<Elems...>`: forward the question to `::std::is_invocable` with the unpacked element types.
template <typename Callable, typename... Elems>
struct is_tuple_invocable<Callable, ::std::tuple<Elems...>> : ::std::is_invocable<Callable, Elems...>
{};

// Shorthand for `is_tuple_invocable<Callable, Packed>::value`.
template <typename Callable, typename Packed>
inline constexpr bool is_tuple_invocable_v = is_tuple_invocable<Callable, Packed>::value;

/**
 * @brief `has_ostream_operator<T>::value` is true iff a `const T&` can be streamed into a `::std::ostream&`, i.e.
 * `::std::declval<::std::ostream&>() << ::std::declval<const T&>()` is a well-formed expression.
 */
template <typename T, typename = void>
struct has_ostream_operator : ::std::false_type
{};

template <typename T>
struct has_ostream_operator<T, ::std::void_t<decltype(::std::declval<::std::ostream&>() << ::std::declval<const T&>())>>
    : ::std::true_type
{};

} // end namespace reserved

} // namespace cuda::experimental::stf