include/cuda/experimental/__stf/internal/msir.cuh
File members: include/cuda/experimental/__stf/internal/msir.cuh
//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <cuda/__cccl_config>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
/*
* This file defines the protocol used to keep data copies up to date.
* Task dependencies are expected to be enforced by the STF model, so this
* protocol is only intended to implement the required data movements/allocations.
*
* M : msir_state_id::modified
* S : msir_state_id::shared
* I : msir_state_id::invalid
* R : msir_state_id::reduction
*/
#include <cuda/experimental/__stf/internal/async_prereq.cuh>
namespace cuda::experimental::stf
{
namespace reserved
{
// States of the protocol used to keep data copies up to date.
// The letter next to each enumerator is the code returned by status_to_char.
// NOTE: per_data_instance_msi_state::hash() casts these enumerators to int,
// so reordering them changes the computed hash values.
enum class msir_state_id
{
invalid, // 'I' (default state of a per_data_instance_msi_state)
modified, // 'M'
shared, // 'S'
reduction, // 'R'
};
/**
 * @brief Returns a human-readable name for an MSIR state.
 *
 * Every enumerator maps to its qualified name (e.g. "msir_state_id::shared").
 * Any value that is not one of the enumerators maps to "UNKNOWN".
 *
 * @param status State to describe
 * @return The textual name of `status`
 */
inline ::std::string status_to_string(msir_state_id status)
{
  switch (status)
  {
    case msir_state_id::modified:
      return "msir_state_id::modified";
    case msir_state_id::shared:
      return "msir_state_id::shared";
    case msir_state_id::invalid:
      return "msir_state_id::invalid";
    case msir_state_id::reduction:
      // Spelled like the three other states so that the output is uniform
      return "msir_state_id::reduction";
  }

  // Only reached for a value outside of the enumerators handled above
  return "UNKNOWN";
}
/**
 * @brief Returns the one-letter code of an MSIR state.
 *
 * @param status State to encode
 * @return 'M', 'S', 'I' or 'R' for modified, shared, invalid and reduction
 *         respectively, and 'U' for any other value
 */
inline char status_to_char(msir_state_id status)
{
  if (status == msir_state_id::modified)
  {
    return 'M';
  }

  if (status == msir_state_id::shared)
  {
    return 'S';
  }

  if (status == msir_state_id::invalid)
  {
    return 'I';
  }

  if (status == msir_state_id::reduction)
  {
    return 'R';
  }

  // Not one of the known enumerators
  return 'U';
}
/*
* Generic interface that makes no assumption about how memory places are organized.
* It could become a concept in the future; for now it is passive documentation only.
*/
/*
template <typename memory_place_interface> class msi_state {
public:
// Returns the state of a piece of data at a specific place
int shape(memory_place_interface place, class event_list **msi_prereq);
// Sets the status for a specific place
void set_state(memory_place_interface place, int state, class event_list *new_msi_prereq);
// Update data status when accessing data at the specified place with the specified access type
// Returns prereq_out
class event_list *update_state(memory_place_interface place, int access_mode, class event_list *prereq_in);
// Find a valid copy to move a piece of data to dst_place. Returned value is a possible source place
// TODO perhaps we should return a list
// TODO perhaps this should return a pair of (place+prereq)
memory_place_interface find_source_place(memory_place_interface dst_place);
private:
};
*/
class per_data_instance_msi_state
{
public:
per_data_instance_msi_state() {}
~per_data_instance_msi_state() {}
msir_state_id get_msir() const
{
return msir;
}
void set_msir(msir_state_id _msir)
{
msir = _msir;
}
bool is_allocated() const
{
return allocated;
}
void set_allocated(bool _allocated)
{
allocated = _allocated;
}
const event_list& get_read_prereq() const
{
return read_prereq;
}
const event_list& get_write_prereq() const
{
return write_prereq;
}
void set_read_prereq(event_list prereq)
{
read_prereq = mv(prereq);
}
void set_write_prereq(event_list prereq)
{
write_prereq = mv(prereq);
}
int max_prereq_id() const
{
int res = read_prereq.max_prereq_id();
res = ::std::max(res, write_prereq.max_prereq_id());
return res;
}
template <typename T>
void add_read_prereq(T&& prereq)
{
read_prereq.merge(::std::forward<T>(prereq));
if (read_prereq.size() > 16)
{
read_prereq.optimize();
}
}
template <typename T>
void add_write_prereq(T&& prereq)
{
write_prereq.merge(::std::forward<T>(prereq));
if (write_prereq.size() > 16)
{
write_prereq.optimize();
}
}
void clear_read_prereq()
{
read_prereq.clear();
}
void clear_write_prereq()
{
write_prereq.clear();
}
size_t hash() const
{
return hash_all(allocated, (int) msir);
}
private:
// We need to fulfill these events __and those in read_prereq__ to modify the instance
event_list write_prereq;
// We need to fulfill these events to read the instance without modifying it
event_list read_prereq;
msir_state_id msir = msir_state_id::invalid; // MSIR = msir_state_id::modified, ...
bool allocated = false;
};
} // end namespace reserved
// Overload hash to compute the hash of a per_data_instance_msi_state
// class from the MSI and allocated states.
template <>
struct hash<reserved::per_data_instance_msi_state>
{
::std::size_t operator()(reserved::per_data_instance_msi_state const& s) const noexcept
{
return s.hash();
}
};
} // namespace cuda::experimental::stf