include/cuda/experimental/__stf/internal/task.cuh

File members: include/cuda/experimental/__stf/internal/task.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

/*
 * This file defines a generic class of "tasks" which are synchronized
 * according to their accesses on "data", based on in/out dependencies.
 */

#include <cuda/experimental/__stf/internal/msir.cuh>
#include <cuda/experimental/__stf/internal/task_dep.cuh> // task has-a task_dep_vector_untyped

#include <optional>

namespace cuda::experimental::stf
{

namespace reserved
{

class mapping_id_tag
{};

using mapping_id_t = reserved::unique_id<reserved::mapping_id_tag>;

} // end namespace reserved

// Forward declarations of types that are defined in other CUDASTF headers.
class backend_ctx_untyped;
class logical_data_untyped;
class exec_place;

// Declared here, defined elsewhere. Used by reserved::dep_allocate to try to
// free memory at `place` when an allocation fails; `requested_s` is the number
// of bytes the failed allocation was missing. Presumably reports the amount
// actually freed through `reclaimed_s` and appends the events the freed memory
// depends on to `prereqs` -- NOTE(review): confirm against the definition.
void reclaim_memory(
  backend_ctx_untyped& ctx, const data_place& place, size_t requested_s, size_t& reclaimed_s, event_list& prereqs);

/**
 * @brief Generic task handle, synchronized according to its data dependencies.
 *
 * A `task` is a lightweight handle on shared state (pimpl idiom): copying a
 * `task` yields another handle on the same underlying task. Equality and
 * hashing are based on the identity of that shared state.
 */
class task
{
public:
  /// Lifecycle of a task, as returned by get_task_phase()
  enum class phase
  {
    setup, // the task has not started yet
    running, // between acquire and release
    finished, // we have released
  };

private:
  // pimpl: all the state of a task, shared by every copy of the handle
  class impl
  {
  public:
    impl(const impl&)            = delete;
    impl& operator=(const impl&) = delete;

    // The default data place is the affine data place of the execution place.
    // (e_place is declared before affine_data_place, so it is already
    // initialized when affine_data_place is.)
    impl(exec_place where = exec_place::current_device())
        : e_place(mv(where))
        , affine_data_place(e_place.affine_data_place())
    {}

    // Vector of user provided deps
    task_dep_vector_untyped deps;

    // Builds reordered_indexes from the current content of deps. This is only
    // useful when calling the get() method of a task, so to reduce overheads
    // this vector is initialized lazily.
    void initialize_reordered_indexes()
    {
      // After this loop, reordered_indexes[i] is the position in deps of the
      // dependency whose dependency_index is i.
      reordered_indexes.resize(deps.size());

      int sorted_index = 0;
      for (auto& it : deps)
      {
        reordered_indexes[it.dependency_index] = sorted_index;
        sorted_index++;
      }
    }

    // Maps the original (submission) index of a dependency to its index in
    // deps, e.g. deps[reordered_indexes[0]] is the first dependency submitted.
    ::std::vector<size_t> reordered_indexes;

    // Indices of logical data which were locked (non skipped). Indexes are
    // those obtained after sorting.
    ::std::vector<::std::pair<size_t, access_mode>> unskipped_indexes;

    // Extra events that need to be done before the task starts. These are
    // "extra" as these are in addition to the events that will be required to
    // acquire the logical_data_untypeds accessed by the task
    event_list input_events;

    // A string useful for debugging purpose (lazily generated by get_symbol)
    mutable ::std::string symbol;

    // This points to the prerequisites for this task's termination
    event_list done_prereqs;

    // Used to uniquely identify the task
    reserved::unique_id_t unique_id;

    // Used to uniquely identify the task for mapping purposes
    reserved::mapping_id_t mapping_id;

    // Execution place saved so that "unset_place" can restore the previous
    // context
    exec_place saved_place_ctx;

    // Indicate the status of the task
    task::phase phase = task::phase::setup;

    // This is where the task is executed
    exec_place e_place;

    // This is the default data place for the task. In general this is the
    // affine data place of the execution place, but this can be a
    // composite data place when using a grid of places for example.
    data_place affine_data_place;

    // Callbacks registered through add_post_submission_hook
    ::std::vector<::std::function<void()>> post_submission_hooks;
  };

protected:
  // This is the only state: every copy of a task shares the same impl
  ::std::shared_ptr<impl> pimpl;

public:
  // Creates a task on the current device
  task()
      : pimpl(::std::make_shared<impl>())
  {}

  // Creates a task executed on `ep`
  task(exec_place ep)
      : pimpl(::std::make_shared<impl>(mv(ep)))
  {}

  // Copies are shallow: both handles refer to the same underlying task
  task(const task&) = default;
  task(task&&)      = default;

  task& operator=(const task& rhs) = default;
  task& operator=(task&& rhs)      = default;

  // False once the task has been cleared (see clear())
  explicit operator bool() const
  {
    return pimpl != nullptr;
  }

  // Two handles are equal if they refer to the same underlying task
  bool operator==(const task& rhs) const
  {
    return pimpl == rhs.pimpl;
  }

  bool operator!=(const task& rhs) const
  {
    return !(*this == rhs);
  }

  // Returns the symbol of the task, defaulting to "task <unique id>"
  const ::std::string& get_symbol() const
  {
    if (pimpl->symbol.empty())
    {
      pimpl->symbol = "task " + ::std::to_string(pimpl->unique_id);
    }
    return pimpl->symbol;
  }

  // Only allowed before the task has started
  void set_symbol(::std::string new_symbol)
  {
    EXPECT(get_task_phase() == phase::setup);
    pimpl->symbol = mv(new_symbol);
  }

  // Appends one dependency (only allowed before the task has started)
  void add_dep(task_dep_untyped d)
  {
    EXPECT(get_task_phase() == phase::setup);
    pimpl->deps.push_back(mv(d));
  }

  // Appends a vector of dependencies (only allowed before the task has started)
  void add_deps(task_dep_vector_untyped input_deps)
  {
    EXPECT(get_task_phase() == phase::setup);
    if (pimpl->deps.empty())
    {
      // Frequent case: steal the input vector instead of copying elements
      pimpl->deps = mv(input_deps);
    }
    else
    {
      pimpl->deps.insert(
        pimpl->deps.end(), ::std::make_move_iterator(input_deps.begin()), ::std::make_move_iterator(input_deps.end()));
    }
  }

  // Appends any number of dependencies, in order
  template <typename... Pack>
  void add_deps(task_dep_untyped first, Pack&&... pack)
  {
    EXPECT(get_task_phase() == phase::setup);
    pimpl->deps.push_back(mv(first));
    if constexpr (sizeof...(Pack) > 0)
    {
      add_deps(::std::forward<Pack>(pack)...);
    }
  }

  // Appends every element of a tuple of dependencies, in order
  template <typename... Args>
  void add_deps(::std::tuple<Args...>& deps_tuple)
  {
    ::std::apply(
      [this](const auto&... deps) {
        // Call add_deps on each dep using a fold expression
        //
        // Note that we use this-> while it seems unnecessary to work-around
        // some compiler issue which otherwise believe the "this" captured
        // value is unused.
        (this->add_deps(deps), ...);
      },
      deps_tuple);
  }

  const task_dep_vector_untyped& get_task_deps() const
  {
    return pimpl->deps;
  }

  // Sets the execution place, and the affine data place to the one of `p`
  task& on(exec_place p)
  {
    EXPECT(get_task_phase() == phase::setup);
    // This defines an affine data place too
    set_affine_data_place(p.affine_data_place());
    pimpl->e_place = mv(p);
    return *this;
  }

  const exec_place& get_exec_place() const
  {
    return pimpl->e_place;
  }
  exec_place& get_exec_place()
  {
    return pimpl->e_place;
  }
  // Unlike on(), this does not modify the affine data place
  void set_exec_place(const exec_place& place)
  {
    pimpl->e_place = place;
  }

  const data_place& get_affine_data_place() const
  {
    return pimpl->affine_data_place;
  }

  void set_affine_data_place(data_place affine_data_place)
  {
    pimpl->affine_data_place = mv(affine_data_place);
  }

  dim4 grid_dims() const
  {
    return get_exec_place().grid_dims();
  }

  const event_list& get_done_prereqs() const
  {
    return pimpl->done_prereqs;
  }

  // Merges `tail` into the prerequisites for this task's termination
  template <typename T>
  void merge_event_list(T&& tail)
  {
    pimpl->done_prereqs.merge(::std::forward<T>(tail));
  }

  instance_id_t find_data_instance_id(const logical_data_untyped& d) const;

  template <typename T, typename logical_data_untyped = logical_data_untyped>
  decltype(auto) get(size_t submitted_index) const;

  // If there are extra input dependencies in addition to STF-induced events
  void set_input_events(event_list _input_events)
  {
    EXPECT(get_task_phase() == phase::setup);
    pimpl->input_events = mv(_input_events);
  }

  const event_list& get_input_events() const
  {
    return pimpl->input_events;
  }

  // Get the unique task identifier
  int get_unique_id() const
  {
    return pimpl->unique_id;
  }

  // Get the unique task mapping identifier
  int get_mapping_id() const
  {
    return pimpl->mapping_id;
  }

  // Hash of the identity of the shared state: copies hash equally
  size_t hash() const
  {
    return ::std::hash<impl*>()(pimpl.get());
  }

  // Appends (copies of) all the callbacks in `hooks` to the post-submission hooks
  void add_post_submission_hook(const ::std::vector<::std::function<void()>>& hooks)
  {
    pimpl->post_submission_hooks.insert(pimpl->post_submission_hooks.end(), hooks.begin(), hooks.end());
  }

  // Resolve all dependencies at the specified execution place
  // Returns execution prereqs
  event_list acquire(backend_ctx_untyped& ctx);

  void release(backend_ctx_untyped& ctx, event_list& done_prereqs);

  // Returns the current state of the task (the task must not be cleared)
  phase get_task_phase() const
  {
    EXPECT(pimpl);
    return pimpl->phase;
  }

  /* When the task has ended, we cannot do anything with it. It is possible
   * that the user-facing task object is not destroyed when the context is
   * synchronized, so we clear it. After this call, the handle converts to
   * false and must not be used anymore.
   *
   * This for example happens when doing :
   *   auto t = ctx.task(A.rw());
   *   t->*[](auto A){...};
   *   ctx.finalize();
   */
  void clear()
  {
    pimpl.reset();
  }
};

namespace reserved
{

/**
 * @brief Lazily allocates data instance `instance_id` of `d` at `dplace`,
 * reclaiming memory and retrying when the allocation fails.
 *
 * This is a no-op if the instance is already allocated. Otherwise
 * `d.allocate` is called until it reports a non-negative size; after each
 * failure, `reclaim_memory` is asked to free the missing amount of bytes.
 * The number of failed attempts is bounded (checked with EXPECT).
 *
 * If `mode` is access_mode::relaxed, a freshly allocated instance is then
 * initialized with the reduction operator attached to that instance.
 *
 * @param ctx         Context used to reclaim memory
 * @param d           Data owning the instance
 * @param mode        Access mode; relaxed triggers the reduction initialization
 * @param dplace      Data place of the allocation
 * @param eplace      Execution place used to initialize reduction instances
 *                    (must have a value when `mode` is relaxed)
 * @param instance_id Identifier of the instance within `d`
 * @param prereqs     Events accumulated by the allocation (and initialization)
 */
template <typename Data>
void dep_allocate(
  backend_ctx_untyped& ctx,
  Data& d,
  access_mode mode,
  const data_place& dplace,
  const ::std::optional<exec_place>& eplace,
  instance_id_t instance_id,
  event_list& prereqs)
{
  auto& inst = d.get_data_instance(instance_id);

  /*
   * DATA LAZY ALLOCATION
   */
  if (inst.is_allocated())
  {
    return;
  }

  // nvtx_range r("acquire::allocate");
  /* Try to allocate memory : if we fail to do so, we must try to
   * free other instances first, and retry */
  constexpr int max_failed_attempts = 5;
  int alloc_attempts                = 0;
  while (true)
  {
    ::std::ptrdiff_t s = 0;

    prereqs.merge(inst.get_read_prereq(), inst.get_write_prereq());

    // The allocation routine may decide to store some extra information
    void* extra_args = nullptr;

    d.allocate(dplace, instance_id, s, &extra_args, prereqs);

    // Save extra_args
    inst.set_extra_args(extra_args);

    if (s >= 0)
    {
      // This allocation was successful
      inst.allocated_size = static_cast<size_t>(s);
      inst.set_allocated(true);
      inst.reclaimable = true;
      break;
    }

    // A negative size is the opposite of the number of missing bytes

    // Limit the number of attempts if it's simply not possible. The counter
    // is incremented outside of EXPECT to keep the assertion side-effect free.
    EXPECT(alloc_attempts < max_failed_attempts);
    ++alloc_attempts;

    // We failed to allocate so we try to reclaim
    size_t reclaimed_s  = 0;
    const size_t needed = static_cast<size_t>(-s);
    reclaim_memory(ctx, dplace, needed, reclaimed_s, prereqs);
  }

  // After allocating a reduction instance, we need to initialize it
  if (mode == access_mode::relaxed)
  {
    assert(eplace.has_value());
    // We have just allocated a new piece of data to perform
    // reductions, so we need to initialize this with an
    // appropriate user-provided operator
    // First get the data instance and then its reduction operator
    ::std::shared_ptr<reduction_operator_base> ops = inst.get_redux_op();
    assert(ops);
    ops->init_op_untyped(d, dplace, instance_id, eplace.value(), prereqs);
  }
}

} // end namespace reserved

// inline size_t task_state::hash() const {
//     size_t h = 0;
//     for (auto& e: logical_data_ids) {
//         int id = e.first;
//         auto handle = e.second.lock();
//         // ignore expired handles
//         if (handle) {
//             hash_combine(h, ::std::hash<int> {}(id));
//             hash_combine(h, handle->hash());
//         }
//     }

//     return h;
// }

/**
 * @brief One instance of a piece of data: its data place, its allocation and
 * MSI coherence state, and reduction-related information.
 *
 * Instances are stored in a vector attached to the logical data; unused
 * entries (get_used() == false) may be reused for a new instance.
 */
class data_instance
{
public:
  data_instance() = default;
  data_instance(bool used, data_place dplace)
      : used(used)
      , dplace(mv(dplace))
  {}

  // Marks the instance as used or unused; the flag must actually change
  void set_used(bool flag)
  {
    assert(flag != used);
    used = flag;
  }

  bool get_used() const
  {
    return used;
  }

  void set_dplace(data_place _dplace)
  {
    dplace = mv(_dplace);
  }
  const data_place& get_dplace() const
  {
    return dplace;
  }

  // Returns the reduction operator associated to this data instance
  ::std::shared_ptr<reduction_operator_base> get_redux_op() const
  {
    return redux_op;
  }
  // Sets the reduction operator associated to that data instance
  void set_redux_op(::std::shared_ptr<reduction_operator_base> op)
  {
    redux_op = mv(op);
  }

  // Indicates whether the data instance is allocated (if not, it must be
  // allocated prior to use). Note that an allocated instance may also be out
  // of sync.
  bool is_allocated() const
  {
    return state.is_allocated();
  }

  void set_allocated(bool b)
  {
    state.set_allocated(b);
  }

  reserved::msir_state_id get_msir() const
  {
    return state.get_msir();
  }

  void set_msir(reserved::msir_state_id st)
  {
    state.set_msir(st);
  }

  const event_list& get_read_prereq() const
  {
    return state.get_read_prereq();
  }
  const event_list& get_write_prereq() const
  {
    return state.get_write_prereq();
  }

  void set_read_prereq(event_list prereq)
  {
    state.set_read_prereq(mv(prereq));
  }
  void set_write_prereq(event_list prereq)
  {
    state.set_write_prereq(mv(prereq));
  }

  void add_read_prereq(const event_list& _prereq)
  {
    state.add_read_prereq(_prereq);
  }
  void add_write_prereq(const event_list& _prereq)
  {
    state.add_write_prereq(_prereq);
  }

  void clear_read_prereq()
  {
    state.clear_read_prereq();
  }
  void clear_write_prereq()
  {
    state.clear_write_prereq();
  }

  bool has_last_task_relaxed() const
  {
    return last_task_relaxed.has_value();
  }
  void set_last_task_relaxed(task t)
  {
    last_task_relaxed = mv(t);
  }
  // Only valid if has_last_task_relaxed() is true
  const task& get_last_task_relaxed() const
  {
    assert(last_task_relaxed.has_value());
    return last_task_relaxed.value();
  }

  int max_prereq_id() const
  {
    return state.max_prereq_id();
  }

  // Compute a hash of the MSI/Alloc state
  size_t state_hash() const
  {
    return hash<reserved::per_data_instance_msi_state>{}(state);
  }

  void set_extra_args(void* args)
  {
    extra_args = args;
  }

  void* get_extra_args() const
  {
    return extra_args;
  }

  // Drops the read/write prerequisites and the last relaxed task. The
  // allocation state, reduction operator and extra_args are left untouched.
  void clear()
  {
    clear_read_prereq();
    clear_write_prereq();
    last_task_relaxed.reset();
  }

private:
  // Is this instance in use? If not, this data instance can be reused when
  // looking for an available slot in the vector of data instances attached
  // to the logical data
  bool used = false;

  // If the used flag is set, this tells where this instance is located
  data_place dplace;

  // Reduction operator attached to the data instance
  ::std::shared_ptr<reduction_operator_base> redux_op;

  // This structure contains everything to implement the MSI protocol,
  // including asynchronous prereqs so that we only use a data instance once
  // it's ready to do so.
  //
  // TODO: most accessors above merely forward to this state without any check.
  reserved::per_data_instance_msi_state state;

  // This stores the last task which used this instance with a relaxed coherence mode (redux)
  ::std::optional<task> last_task_relaxed;

  // This generic pointer can be used to store some information in the
  // allocator which is passed to the deallocation routine.
  void* extra_args = nullptr;

public:
  // Size of the memory allocation (bytes). Only valid for allocated instances.
  size_t allocated_size = 0;

  // A false value indicates that this instance cannot be a candidate for
  // memory reclaiming (e.g. because this corresponds to memory allocated by
  // the user)
  bool reclaimable = false;

  bool automatically_pinned = false;
};

// Hash functor for task handles: delegates to task::hash(), so all copies of
// the same task produce the same value.
template <>
struct hash<task>
{
  ::std::size_t operator()(const task& t) const
  {
    const ::std::size_t h = t.hash();
    return h;
  }
};

} // namespace cuda::experimental::stf