include/cuda/experimental/__stf/utility/threads.cuh

File members: include/cuda/experimental/__stf/utility/threads.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/source_location>

#include <atomic>
#include <cstdio>
#include <cstdlib>
#include <mutex>
#include <type_traits>

namespace cuda::experimental::stf::reserved
{

/**
 * @brief Guard object that detects concurrent entry into a section meant to be single-threaded.
 *
 * The checks are only compiled when `NDEBUG` is not defined. In that configuration the constructor probes
 * the mutex with `try_lock()`; if the mutex cannot be acquired, an error message carrying the caller's
 * source location is written to `stderr` and the program is aborted. When `NDEBUG` is defined, all
 * members are empty and the mutex is never touched.
 *
 * @tparam Mutex Type of the mutex used for the probe. `::std::recursive_mutex` is kept locked until the
 *               guard is destroyed; any other type is unlocked again right after a successful probe.
 */
template <typename Mutex>
class single_threaded_section
{
public:
  using mutex_type = Mutex;

#ifndef NDEBUG

  /**
   * @brief Probe `m` and abort the program if it is currently held.
   *
   * @param m   Mutex associated with the section
   * @param loc Source location printed in the diagnostic (defaults to `source_location::current()`)
   */
  explicit single_threaded_section(mutex_type& m,
                                   const _CUDA_VSTD::source_location loc = _CUDA_VSTD::source_location::current())
      : mutex(m)
  {
    if (!mutex.try_lock())
    {
      // Convert the line number explicitly so the argument type matches the "%u" conversion.
      ::std::fprintf(stderr,
                     "%s(%u) Error: contested single-threaded section.\n",
                     loc.file_name(),
                     static_cast<unsigned>(loc.line()));
      ::std::abort();
    }

    if constexpr (!::std::is_same_v<mutex_type, ::std::recursive_mutex>)
    {
      // Would not be able to reenter this mutex, so release immediately.
      mutex.unlock();
    }
  }

  /**
   * @brief Take over a mutex that the calling thread already owns; no probe is performed.
   *
   * NOTE(review): the destructor only unlocks `::std::recursive_mutex`, and the `NDEBUG` variant of this
   * constructor never unlocks anything, so an adopted mutex is released only in debug builds and only when
   * it is recursive - confirm that callers expect this.
   */
  single_threaded_section(mutex_type& m, ::std::adopt_lock_t) noexcept
      : mutex(m)
  {} // calling thread owns mutex

  single_threaded_section(const single_threaded_section&)            = delete;
  single_threaded_section& operator=(const single_threaded_section&) = delete;

  /// Releases the lock that was kept for the guard's lifetime (recursive mutexes only).
  ~single_threaded_section()
  {
    if constexpr (::std::is_same_v<mutex_type, ::std::recursive_mutex>)
    {
      // Keep the recursive mutex up until destruction.
      mutex.unlock();
    }
  }

private:
  // Mutex used for the contention probe; not owned by the guard.
  mutex_type& mutex;

#else

  // Release builds: same constructors minus the source-location parameter, all of them no-ops.
  explicit single_threaded_section(mutex_type&) {}
  single_threaded_section(mutex_type&, ::std::adopt_lock_t) noexcept {}
  single_threaded_section(const single_threaded_section&)            = delete;
  single_threaded_section& operator=(const single_threaded_section&) = delete;

#endif
};

/**
 * @brief Atomic counter with static storage, one independent counter per template argument.
 *
 * All members are static, so every user of `counter<T>` for a given `T` observes the same value.
 *
 * @tparam T Tag type that selects which counter is used
 */
template <typename T>
class counter
{
private:
  // Backing storage shared by every use of this specialization; starts at zero.
  static inline ::std::atomic<unsigned long> count{0};

public:
  counter() = default;

  /// Returns the current value of the counter.
  static auto load()
  {
    const unsigned long current = count.load();
    return current;
  }

  /// Atomically adds one to the counter.
  static auto increment()
  {
    count.fetch_add(1);
  }

  /// Atomically subtracts one from the counter.
  static auto decrement()
  {
    count.fetch_sub(1);
  }
};

/**
 * @brief Atomic tracker of the largest value recorded so far, one independent tracker per template argument.
 *
 * All members are static, so every user of `high_water_mark<T>` for a given `T` shares the same maximum.
 *
 * @tparam T Tag type that selects which tracker is used
 */
template <typename T>
class high_water_mark
{
private:
  // Largest value passed to record() so far; starts at zero.
  static inline ::std::atomic<unsigned long> tracker{0};

public:
  high_water_mark() = default;

  /**
   * @brief Raise the stored maximum to `v` if `v` exceeds it; otherwise leave it unchanged.
   *
   * @param v Candidate value
   */
  static void record(unsigned long v)
  {
    unsigned long seen = tracker.load();
    // A failed exchange stores the value currently held by the tracker into `seen`, so the loop
    // condition is re-evaluated against fresh data until either the stored value is already at
    // least `v` or the exchange goes through.
    while (seen < v && !tracker.compare_exchange_weak(seen, v))
    {
    }
  }

  /// Returns the largest value recorded so far.
  static unsigned long load()
  {
    return tracker.load();
  }
};

} // namespace cuda::experimental::stf::reserved