
// File: include/cuda/experimental/__stf/internal/task.cuh

// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.

#pragma once

#include <cuda/__cccl_config>

#  pragma GCC system_header
#  pragma clang system_header
#  pragma system_header
#endif // no system header

/**
 * This is a generic class of "tasks" that are synchronized according to
 * their accesses on "data", based on their in/out dependencies.
 */

#include <cuda/experimental/__stf/internal/msir.cuh>
#include <cuda/experimental/__stf/internal/task_dep.cuh> // task has-a task_dep_vector_untyped

#include <optional>

namespace cuda::experimental::stf

namespace reserved

// Tag type used only to instantiate reserved::unique_id for task mapping ids;
// presumably this keeps mapping ids in a sequence separate from other
// unique_id instantiations (e.g. reserved::unique_id_t) — confirm.
class mapping_id_tag

// Type of task::impl::mapping_id, exposed through task::get_mapping_id().
using mapping_id_t = reserved::unique_id<reserved::mapping_id_tag>;

} // end namespace reserved

// Forward declarations of classes referred to in the declarations below.
class backend_ctx_untyped;
class logical_data_untyped;
class exec_place;

// Declared here so that reserved::dep_allocate() can call it; defined elsewhere.
// dep_allocate() calls it after a failed allocation on `place`, passing the
// number of bytes it still needs as `requested_s`. `reclaimed_s` is an output
// parameter (presumably the number of bytes actually reclaimed — confirm), and
// `prereqs` is passed by non-const reference so the callee may extend it.
void reclaim_memory(
  backend_ctx_untyped& ctx, const data_place& place, size_t requested_s, size_t& reclaimed_s, event_list& prereqs);

// A task is a unit of work that is ordered with respect to other tasks through
// the dependencies (task_dep_untyped) it declares on logical data.
//
// A task object is a lightweight handle. All of its state lives in a shared
// `impl` (see `pimpl` below), so copying a task yields another handle to the
// same state. operator== and hash() both compare that shared identity.
//
// NOTE(review): in this copy of the header, the braces and several statements
// of this class are missing (the affected places are marked inline). Restore
// them from the authoritative source before building.
class task
  // Lifecycle of a task, as reported by get_task_phase(). set_symbol,
  // add_dep(s), on() and set_input_events() require the `setup` phase.
  enum class phase
    setup, // the task has not started yet
    running, // between acquire and release
    finished, // we have released

  // pimpl
  // Shared implementation state. It is non-copyable; task handles share a
  // single instance through a std::shared_ptr instead of duplicating it.
  class impl
    impl(const impl&)            = delete;
    impl& operator=(const impl&) = delete;

    // Creates the state of a task executed on `where` (the current device by
    // default). The default affine data place is taken from that execution
    // place. Reading `e_place` here is valid because it is declared before
    // `affine_data_place` and is therefore initialized first.
    impl(exec_place where = exec_place::current_device())
        : e_place(mv(where))
        , affine_data_place(e_place.affine_data_place())

    // Vector of user provided deps
    task_dep_vector_untyped deps;

    // Fills `reordered_indexes`, which maps the index a dependency had when it
    // was submitted (its `dependency_index`) to its position in `deps`
    // (presumably after `deps` has been sorted — confirm). The mapping is only
    // needed by task::get(), so it is built lazily to reduce overheads.
    // NOTE(review): as shown, `reordered_indexes` is never resized before it
    // is indexed, and `sorted_index` is never incremented. Those statements
    // appear to be missing from this copy — confirm against the original.
    void initialize_reordered_indexes()
      // This list converts the original index to the sorted index.
      // For example, the first entry of the list before being ordered has
      // order reordered_indexes[0].

      int sorted_index = 0;
      for (auto& it : deps)
        reordered_indexes[it.dependency_index] = sorted_index;

    // Maps a dependency's original (submission) index to its index in `deps`
    // after reordering. For example, deps[reordered_indexes[0]] is the first
    // piece of data that was submitted.
    ::std::vector<size_t> reordered_indexes;

    // Indices of logical data which were locked (non skipped), each paired
    // with an access mode. Indexes are those obtained after sorting.
    ::std::vector<::std::pair<size_t, access_mode>> unskipped_indexes;

    // Extra events that need to be done before the task starts. They are
    // "extra" because they come in addition to the events required to
    // acquire the logical_data_untypeds accessed by the task.
    // Set through task::set_input_events().
    event_list input_events;

    // A string useful for debugging purposes. It is `mutable` so that the
    // const task::get_symbol() can fill in a default value lazily.
    mutable ::std::string symbol;

    // The prerequisites for this task's termination (read through
    // task::get_done_prereqs()).
    event_list done_prereqs;

    // Used to uniquely identify the task. It also builds the default symbol
    // returned by task::get_symbol().
    reserved::unique_id_t unique_id;

    // Used to uniquely identify the task for mapping purposes
    reserved::mapping_id_t mapping_id;

    // Execution place saved so that the previous context can be restored.
    // The historical comment says "unset_place" uses it for this purpose.
    // Despite older wording, this member is an exec_place value, not a pointer.
    exec_place saved_place_ctx;

    // Indicate the status of the task. The type is written as `task::phase`
    // because this member's own name is also `phase`.
    task::phase phase = task::phase::setup;

    // This is where the task is executed
    exec_place e_place;

    // This is the default data place for the task. In general this is the
    // affine data place of the execution place, but this can be a
    // composite data place when using a grid of places for example.
    data_place affine_data_place;

    // Callbacks presumably filled by task::add_post_submission_hook(), whose
    // body is incomplete in this copy. They are not invoked in this header.
    ::std::vector<::std::function<void()>> post_submission_hooks;

  // This is the only state: every task handle points to a shared impl.
  ::std::shared_ptr<impl> pimpl;

  // NOTE(review): the declaration this initializer belongs to is missing from
  // this copy. It is presumably the default constructor `task()`, which would
  // run on the current device because of impl's default argument.
      : pimpl(::std::make_shared<impl>())

  // Creates a task that will execute on `ep`.
  task(exec_place ep)
      : pimpl(::std::make_shared<impl>(mv(ep)))

  // Copies share state: the new handle refers to the same impl as `rhs`.
  task(const task& rhs)
      : pimpl(rhs.pimpl)
  task(task&&) = default;

  task& operator=(const task& rhs) = default;
  task& operator=(task&& rhs)      = default;

  // True while the handle refers to state. False after clear(), or after the
  // handle has been moved from (a moved-from shared_ptr is empty).
  explicit operator bool() const
    return pimpl != nullptr;

  // Identity comparison: two handles are equal when they share the same
  // impl, not when they describe equivalent work.
  bool operator==(const task& rhs) const
    return pimpl == rhs.pimpl;

  // Returns the task's debugging name. If none was set with set_symbol(), it
  // defaults lazily to "task <unique id>".
  const ::std::string& get_symbol() const
    if (pimpl->symbol.empty())
      pimpl->symbol = "task " + ::std::to_string(pimpl->unique_id);
    return pimpl->symbol;

  // Sets the debugging name. Only allowed while the task is in the setup phase.
  void set_symbol(::std::string new_symbol)
    EXPECT(get_task_phase() == phase::setup);
    pimpl->symbol = mv(new_symbol);

  // Adds one dependency. Only allowed while the task is in the setup phase.
  // NOTE(review): the statement that stores `d` into pimpl->deps is missing
  // from this copy.
  void add_dep(task_dep_untyped d)
    EXPECT(get_task_phase() == phase::setup);

  // Adds a vector of dependencies. Only allowed while the task is in the
  // setup phase. If no dependency has been recorded yet, the vector is moved
  // in wholesale. Otherwise its elements are moved and appended to the end
  // of pimpl->deps.
  void add_deps(task_dep_vector_untyped input_deps)
    EXPECT(get_task_phase() == phase::setup);
    if (pimpl->deps.empty())
      // Frequent case
      pimpl->deps = mv(input_deps);
      // NOTE(review): the `else` branch and the start of the call these
      // arguments belong to (presumably `pimpl->deps.insert(`) are missing
      // from this copy.
        pimpl->deps.end(), ::std::make_move_iterator(input_deps.begin()), ::std::make_move_iterator(input_deps.end()));

  // Adds one or more dependencies passed as separate arguments. Only allowed
  // while the task is in the setup phase.
  // NOTE(review): the statements that add `first` and forward the remaining
  // `pack` (under the `if constexpr`) are missing from this copy.
  template <typename... Pack>
  void add_deps(task_dep_untyped first, Pack&&... pack)
    EXPECT(get_task_phase() == phase::setup);
    if constexpr (sizeof...(Pack) > 0)

  // Adds every dependency stored in a std::tuple by calling add_deps on each
  // element.
  // NOTE(review): the call that expands `deps_tuple` into this lambda
  // (presumably ::std::apply) and the lambda's closing are missing from this
  // copy.
  template <typename... Args>
  void add_deps(::std::tuple<Args...>& deps_tuple)
      [this](const auto&... deps) {
        // Call add_deps on each dep using a fold expression
        // Note that we use this-> while it seems unnecessary to work-around
        // some compiler issue which otherwise believe the "this" captured
        // value is unused.
        (this->add_deps(deps), ...);

  // Returns the dependencies added so far.
  const task_dep_vector_untyped& get_task_deps() const
    return pimpl->deps;

  // Selects where the task executes. Only allowed while the task is in the
  // setup phase. Returns *this so that calls can be chained.
  // NOTE(review): the comment below says this also defines the affine data
  // place, but no visible statement updates pimpl->affine_data_place. Unless
  // a line is missing here, it keeps the value derived from the previous
  // execution place. Confirm.
  task& on(exec_place p)
    EXPECT(get_task_phase() == phase::setup);
    // This defines an affine data place too
    pimpl->e_place = mv(p);
    return *this;

  // Accessors for the execution place. Unlike on(), set_exec_place() does not
  // check the task phase and does not touch the affine data place.
  const exec_place& get_exec_place() const
    return pimpl->e_place;
  exec_place& get_exec_place()
    return pimpl->e_place;
  void set_exec_place(const exec_place& place)
    pimpl->e_place = place;

  // Default data place of the task (see impl::affine_data_place).
  const data_place& get_affine_data_place() const
    return pimpl->affine_data_place;

  // Overrides the default data place. No phase check is performed.
  void set_affine_data_place(data_place affine_data_place)
    pimpl->affine_data_place = mv(affine_data_place);

  // Grid dimensions of the task's execution place.
  dim4 grid_dims() const
    return get_exec_place().grid_dims();

  // Prerequisites for this task's termination.
  const event_list& get_done_prereqs() const
    return pimpl->done_prereqs;

  // NOTE(review): the body of this member is missing from this copy.
  template <typename T>
  void merge_event_list(T&& tail)

  // Declared here, defined elsewhere.
  instance_id_t find_data_instance_id(const logical_data_untyped& d) const;

  // Declared here, defined elsewhere. `submitted_index` is presumably the
  // index of the dependency in submission order (see
  // impl::reordered_indexes) — confirm.
  template <typename T, typename logical_data_untyped = logical_data_untyped>
  decltype(auto) get(size_t submitted_index) const;

  // If there are extra input dependencies in addition to STF-induced events
  // (replaces any previously set input events; setup phase only).
  void set_input_events(event_list _input_events)
    EXPECT(get_task_phase() == phase::setup);
    pimpl->input_events = mv(_input_events);

  const event_list& get_input_events() const
    return pimpl->input_events;

  // Get the unique task identifier.
  // NOTE(review): returns `int`, relying on a conversion from
  // reserved::unique_id_t that is defined elsewhere. Check that ids cannot
  // exceed the range of int.
  int get_unique_id() const
    return pimpl->unique_id;

  // Get the unique task mapping identifier (same `int` caveat as above).
  int get_mapping_id() const
    return pimpl->mapping_id;

  // Hash of the handle's identity (the impl pointer). It is consistent with
  // operator==: handles that compare equal hash alike.
  size_t hash() const
    return ::std::hash<impl*>()(pimpl.get());

  // NOTE(review): the loop body (presumably appending each hook to
  // pimpl->post_submission_hooks) is missing from this copy.
  void add_post_submission_hook(::std::vector<::std::function<void()>>& hooks)
    for (auto& h : hooks)

  // Resolve all dependencies at the specified execution place
  // Returns execution prereqs
  event_list acquire(backend_ctx_untyped& ctx);

  // Counterpart of acquire(); defined elsewhere.
  void release(backend_ctx_untyped& ctx, event_list& done_prereqs);

  // Returns the current state of the task
  phase get_task_phase() const
    return pimpl->phase;

  /* When the task has ended, we cannot do anything with it. It is possible
   * that the user-facing task object is not destroyed when the context is
   * synchronized, so we clear it: this handle drops its reference to the
   * shared state, and operator bool then returns false.
   * This for example happens when doing :
   *   auto t = ctx.task(...);
   *   t->*[](auto A){...};
   *   ctx.finalize();
   *
   * NOTE(review): this comment's closing delimiter is missing in this copy,
   * so clear() below is currently inside the comment; restore the delimiter.
   * Also note that resetting from a typed null pointer leaves the shared_ptr
   * owning a null pointer (with a control block) rather than empty;
   * pimpl.reset() would leave it empty. Both make operator bool false.
  void clear()
    pimpl.reset((cuda::experimental::stf::task::impl*) nullptr);

namespace reserved

/* This method lazily allocates data (possibly reclaiming memory) and copies data if needed.
 *
 * If instance `instance_id` of `d` is not allocated yet, it is allocated on
 * `dplace` through d.allocate(). A negative size reported by allocate() means
 * the allocation failed and that -size bytes are still needed. In that case
 * reclaim_memory() is asked for that amount and the allocation is retried.
 * At most five reclaim-and-retry rounds are allowed before EXPECT fails.
 * No data copy is visible in this function itself.
 *
 * When a new instance is allocated for access_mode::relaxed (reductions), it
 * is then initialized with the reduction operator attached to the instance.
 *
 * - ctx:         context forwarded to reclaim_memory()
 * - d:           logical data owning the instance (any type that provides
 *                get_data_instance() and allocate())
 * - mode:        access mode of the dependency being prepared
 * - dplace:      data place where the instance is allocated
 * - eplace:      execution place. It is only read (via value()) in relaxed
 *                mode, where it must hold a value, because
 *                std::optional::value() throws otherwise. NOTE(review): it is
 *                taken as a const value, so each call copies the optional;
 *                a const reference would avoid that.
 * - instance_id: which instance of `d` to prepare
 * - prereqs:     in/out event list. It is extended with the instance's read
 *                and write prerequisites and passed to allocate(),
 *                reclaim_memory() and the reduction initializer, any of which
 *                may update it.
 */
template <typename Data>
void dep_allocate(
  backend_ctx_untyped& ctx,
  Data& d,
  access_mode mode,
  const data_place& dplace,
  const ::std::optional<exec_place> eplace,
  instance_id_t instance_id,
  event_list& prereqs)
  auto& inst = d.get_data_instance(instance_id);

  bool already_allocated = inst.is_allocated();
  if (!already_allocated)
    // nvtx_range r("acquire::allocate");
    /* Try to allocate memory : if we fail to do so, we must try to
     * free other instances first, and retry */
    int alloc_attempts = 0;
    while (true)
      // Set by allocate(): a size >= 0 on success, or minus the missing
      // number of bytes on failure.
      ::std::ptrdiff_t s = 0;

      // Presumably orders the allocation after any pending reads and writes
      // of this instance — confirm.
      prereqs.merge(inst.get_read_prereq(), inst.get_write_prereq());

      // The allocation routine may decide to store some extra information
      void* extra_args = nullptr;

      d.allocate(dplace, instance_id, s, &extra_args, prereqs);

      // Save extra_args
      // NOTE(review): no statement follows. Presumably
      // `inst.set_extra_args(extra_args)` is missing from this copy; without
      // it, the pointer returned by allocate() would be lost.

      if (s >= 0)
        // This allocation was successful
        inst.allocated_size = s;
        inst.reclaimable = true;
        // NOTE(review): the statement leaving the loop on success (and
        // possibly marking the instance allocated) is not visible in this
        // copy. As shown, control would reach assert(s < 0) after a
        // successful allocation.

      assert(s < 0);

      // Limit the number of attempts if it's simply not possible
      EXPECT(alloc_attempts++ < 5);

      // We failed to allocate so we try to reclaim at least the missing bytes
      size_t reclaimed_s = 0;
      size_t needed      = -s;
      reclaim_memory(ctx, dplace, needed, reclaimed_s, prereqs);

    // After allocating a reduction instance, we need to initialize it
    if (mode == access_mode::relaxed)
      // We have just allocated a new piece of data to perform
      // reductions, so we need to initialize this with an
      // appropriate user-provided operator
      // First get the data instance and then its reduction operator
      // (assumed non-null here; this is not checked).
      ::std::shared_ptr<reduction_operator_base> ops = inst.get_redux_op();
      ops->init_op_untyped(d, dplace, instance_id, eplace.value(), prereqs);

} // end namespace reserved

// inline size_t task_state::hash() const {
//     size_t h = 0;
//     for (auto& e: logical_data_ids) {
//         int id = e.first;
//         auto handle = e.second.lock();
//         // ignore expired handles
//         if (handle) {
//             hash_combine(h, ::std::hash<int> {}(id));
//             hash_combine(h, handle->hash());
//         }
//     }

//     return h;
// }

// State of one instance (copy) of a logical data: whether the slot is in use,
// where it lives, its MSI coherence state and pending prerequisites, its
// reduction operator, and allocation bookkeeping.
//
// NOTE(review): in this copy, the braces and several one-line forwarding
// bodies (set_allocated, set_msir, set/add/clear_*_prereq, clear) are missing.
// The `#if 0` in the second constructor also has no matching `#endif`, so
// everything after it is currently excluded by the preprocessor.
class data_instance
  data_instance() {}
  // Creates an instance description with the given usage flag and location.
  data_instance(bool used, data_place dplace)
      : used(used)
      , dplace(mv(dplace))
#if 0
        // Since this will default construct a task, we need to decrement the id

  // Marks the slot as used or unused. The assert checks that the flag
  // actually changes, which catches redundant updates.
  void set_used(bool flag)
    assert(flag != used);
    used = flag;

  bool get_used() const
    return used;

  // Location of the instance (meaningful when `used` is set).
  void set_dplace(data_place _dplace)
    dplace = mv(_dplace);
  const data_place& get_dplace() const
    return dplace;

  // Returns what is the reduction operator associated to this data instance
  ::std::shared_ptr<reduction_operator_base> get_redux_op() const
    return redux_op;
  // Sets the reduction operator associated to that data instance
  void set_redux_op(::std::shared_ptr<reduction_operator_base> op)
    redux_op = op;

  // Indicates if the data instance is allocated (ie. if it needs to be
  // allocated prior to use). Note that we may have allocated instances that
  // are out of sync too.
  bool is_allocated() const
    return state.is_allocated();

  // NOTE(review): body missing in this copy (presumably forwards to `state`).
  void set_allocated(bool b)

  // MSI state of this instance, as tracked by `state`.
  reserved::msir_state_id get_msir() const
    return state.get_msir();

  // NOTE(review): body missing in this copy (presumably forwards to `state`).
  void set_msir(reserved::msir_state_id st)

  // Events that must complete before the instance can be read / written.
  const event_list& get_read_prereq() const
    return state.get_read_prereq();
  const event_list& get_write_prereq() const
    return state.get_write_prereq();

  // NOTE(review): the bodies of the following setters / adders / clearers are
  // missing in this copy (presumably they forward to `state`).
  void set_read_prereq(event_list prereq)
  void set_write_prereq(event_list prereq)

  void add_read_prereq(const event_list& _prereq)
  void add_write_prereq(const event_list& _prereq)

  void clear_read_prereq()
  void clear_write_prereq()

  // Last task that used this instance in relaxed (reduction) mode, if any.
  // Only call get_last_task_relaxed() when has_last_task_relaxed() is true,
  // because std::optional::value() throws std::bad_optional_access otherwise.
  bool has_last_task_relaxed() const
    return last_task_relaxed.has_value();
  void set_last_task_relaxed(task t)
    last_task_relaxed = mv(t);
  const task& get_last_task_relaxed() const
    return last_task_relaxed.value();

  int max_prereq_id() const
    return state.max_prereq_id();

  // Compute a hash of the MSI/Alloc state (only `state` contributes; the
  // other fields are ignored).
  size_t state_hash() const
    return hash<reserved::per_data_instance_msi_state>{}(state);

  // Opaque allocator data; see the `extra_args` field below.
  void set_extra_args(void* args)
    extra_args = args;

  void* get_extra_args() const
    return extra_args;

  // NOTE(review): body missing in this copy.
  void clear()

  // Is this instance available or not ? If not we can reuse this data
  // instance when looking for an available slot in the vector of data
  // instances attached to the logical data
  bool used = false;

  // If the used flag is set, this tells where this instance is located
  data_place dplace;

  // Reduction operator attached to the data instance
  ::std::shared_ptr<reduction_operator_base> redux_op;

  // @@@@TODO@@@@ There are a lot of unchecked forwarding with this variable,
  // which is public in practice ...
  // This structure contains everything to implement the MSI protocol,
  // including asynchronous prereqs so that we only use a data instance once
  // it's ready to do so
  reserved::per_data_instance_msi_state state;

  // This stores the last task which used this instance with a relaxed coherence mode (redux)
  ::std::optional<task> last_task_relaxed;

  // This generic pointer can be used to store some information in the
  // allocator which is passed to the deallocation routine.
  void* extra_args = nullptr;

  // Size of the memory allocation (bytes). Only valid for allocated instances.
  // reserved::dep_allocate() sets it after a successful allocation.
  size_t allocated_size = 0;

  // A false value indicates that this instance cannot be a candidate for
  // memory reclaiming (e.g. because this corresponds to memory allocated by
  // the user). reserved::dep_allocate() sets it to true for instances it
  // allocates.
  bool reclaimable = false;

  // No use of this flag is visible in this header.
  bool automatically_pinned = false;

template <>
struct hash<task>
  ::std::size_t operator()(const task& t) const
    return t.hash();

} // namespace cuda::experimental::stf