include/cuda/experimental/__stf/utility/hash.cuh

File members: include/cuda/experimental/__stf/utility/hash.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/experimental/__stf/utility/unittest.cuh>

namespace cuda::experimental::stf
{

/**
 * @brief CUDASTF hashing customization point.
 *
 * Only declared here. Specializations supply the actual functor, for example
 * for ::std::pair and ::std::tuple, which ::std::hash does not cover.
 */
template <class T>
struct hash;

namespace reserved
{

template <typename E, typename = void>
struct has_std_hash : ::std::false_type
{};

template <typename E>
struct has_std_hash<E, ::std::void_t<decltype(::std::declval<::std::hash<E>>()(::std::declval<E>()))>>
    : ::std::true_type
{};

template <typename E>
inline constexpr bool has_std_hash_v = has_std_hash<E>::value;

} // end namespace reserved

/**
 * @brief Mixes the hash of @p val into @p seed.
 *
 * The per-value hash comes from ::std::hash<T> when that specialization is
 * usable for T. Otherwise it comes from ::cuda::experimental::stf::hash<T>.
 * The mixing step is seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2).
 *
 * @param seed Running hash value, updated in place.
 * @param val  Value whose hash is folded into @p seed.
 */
template <typename T>
void hash_combine(size_t& seed, const T& val)
{
  // The mixing formula is written once. The generic parameter keeps the exact
  // type returned by whichever hasher is selected below.
  const auto mix_into_seed = [&seed](auto element_hash) {
    seed ^= element_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  };

  if constexpr (!reserved::has_std_hash_v<T>)
  {
    // No usable ::std::hash: rely on the CUDASTF customization point.
    mix_into_seed(::cuda::experimental::stf::hash<T>()(val));
  }
  else
  {
    mix_into_seed(::std::hash<T>()(val));
  }
}

template <typename... Ts>
size_t hash_all(const Ts&... vals)
{
  if constexpr (sizeof...(Ts) == 1)
  {
    // Special case: single value, use std::hash if possible
    if constexpr (reserved::has_std_hash_v<Ts...>)
    {
      return ::std::hash<Ts...>()(vals...);
    }
    else
    {
      return ::cuda::experimental::stf::hash<Ts...>()(vals...);
    }
  }
  else
  {
    static_assert(sizeof...(Ts) != 0);
    size_t seed = 0;
    each_in_pack(
      [&](auto& val) {
        hash_combine(seed, val);
      },
      vals...);
    return seed;
  }
}

/// CUDASTF hash for ::std::pair: combines the hashes of both members, first then second.
template <class T1, class T2>
struct hash<::std::pair<T1, T2>>
{
  size_t operator()(const ::std::pair<T1, T2>& p) const
  {
    const auto& [lhs, rhs] = p;
    return cuda::experimental::stf::hash_all(lhs, rhs);
  }
};

template <typename... Ts>
struct hash<::std::tuple<Ts...>>
{
  size_t operator()(const ::std::tuple<Ts...>& p) const
  {
    return ::std::apply(cuda::experimental::stf::hash_all<Ts...>, p);
  }
};

namespace reserved
{
template <typename E, typename = void>
struct has_cudastf_hash : ::std::false_type
{};

template <typename E>
struct has_cudastf_hash<
  E,
  ::std::void_t<decltype(::std::declval<::cuda::experimental::stf::hash<E>>()(::std::declval<E>()))>> : ::std::true_type
{};

template <typename E>
inline constexpr bool has_cudastf_hash_v = has_cudastf_hash<E>::value;

} // end namespace reserved

UNITTEST("hash for tuples")
{
  // A tuple key type must be usable with the CUDASTF hash in an unordered_map.
  using key_type = ::std::tuple<int, int>;
  ::std::unordered_map<key_type, int, ::cuda::experimental::stf::hash<key_type>> table;
  table.emplace(key_type(1, 2), 42);
};

} // end namespace cuda::experimental::stf