include/cuda/experimental/__stf/internal/repeat.cuh

File members: include/cuda/experimental/__stf/internal/repeat.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/experimental/__stf/utility/core.cuh>

#include <cassert>
#include <variant>

namespace cuda::experimental::stf::reserved
{

template <typename context_t>
class repeat_scope
{
public:
  static constexpr size_t tasks_per_epoch = 200;

  repeat_scope(context_t& ctx, size_t count)
      : condition(count)
      , ctx(ctx)
  {}
  repeat_scope(context_t& ctx, ::std::function<bool()> condition)
      : condition(mv(condition))
      , ctx(ctx)
  {}

  template <typename Fun>
  void operator->*(Fun&& f)
  {
    size_t task_cnt = 0;
    for (size_t iter = 0; next(); ++iter)
    {
      size_t before_cnt = ctx.task_count();

      static_assert(::std::is_invocable_v<Fun, context_t, size_t>, "Incorrect lambda function signature.");

      f(ctx, iter);

      size_t after_cnt = ctx.task_count();
      assert(after_cnt >= before_cnt);

      // If there is more than a specific number of tasks, fire a new epoch !
      task_cnt += after_cnt - before_cnt;
      if (task_cnt > tasks_per_epoch)
      {
        ctx.change_epoch();
        task_cnt = 0;
      }
    }
  }

private:
  bool next()
  {
    return condition.index() == 0
           ? ::std::get<size_t>(condition)-- > 0
           : ::std::get<::std::function<bool()>>(condition)();
  }

  // Number of iterations, or a function which evaluates if we continue
  ::std::variant<size_t, ::std::function<bool()>> condition;
  // The supporting context for this construct
  context_t& ctx;
};

} // end namespace cuda::experimental::stf::reserved