include/cuda/experimental/__stf/utility/dimensions.cuh
File members: include/cuda/experimental/__stf/utility/dimensions.cuh
//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <cuda/__cccl_config>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <cuda/experimental/__stf/utility/cuda_attributes.cuh>
#include <cuda/experimental/__stf/utility/hash.cuh>
#include <cuda/experimental/__stf/utility/unittest.cuh>
namespace cuda::experimental::stf
{
/**
 * @brief A position described by four signed integer coordinates (x, y, z, t).
 *
 * Coordinates which are not given at construction time are 0.
 */
class pos4
{
public:
  // The origin: every coordinate takes its default member initializer (0)
  constexpr pos4() = default;

  // Build a position from integral values of a common type; each one is converted to int
  template <typename Integral>
  _CCCL_HOST_DEVICE constexpr explicit pos4(Integral x, Integral y = 0, Integral z = 0, Integral t = 0)
      : x{static_cast<int>(x)}
      , y{static_cast<int>(y)}
      , z{static_cast<int>(z)}
      , t{static_cast<int>(t)}
  {}

  // Coordinate along axis `axis_id` (0 -> x, 1 -> y, 2 -> z, 3 -> t). Any other axis trips the assertion.
  _CCCL_HOST_DEVICE constexpr int get(size_t axis_id) const
  {
    if (axis_id == 0)
    {
      return x;
    }
    if (axis_id == 1)
    {
      return y;
    }
    if (axis_id == 2)
    {
      return z;
    }
    assert(axis_id == 3);
    return t;
  }

  // Same as get(), with a signed axis index
  _CCCL_HOST_DEVICE constexpr int operator()(int axis_id) const
  {
    return get(static_cast<size_t>(axis_id));
  }

  // Lexicographic ordering: x is compared first, then y, then z, and finally t
  _CCCL_HOST_DEVICE constexpr bool operator<(const pos4& rhs) const
  {
    return (x != rhs.x) ? (x < rhs.x) : (y != rhs.y) ? (y < rhs.y) : (z != rhs.z) ? (z < rhs.z) : (t < rhs.t);
  }

  // Two positions are equal when all four coordinates match
  _CCCL_HOST_DEVICE constexpr bool operator==(const pos4& rhs) const
  {
    if (x != rhs.x || y != rhs.y)
    {
      return false;
    }
    return z == rhs.z && t == rhs.t;
  }

  // Textual form "pos4(x,y,z,t)"
  ::std::string to_string() const
  {
    ::std::string text = "pos4(";
    text += ::std::to_string(x);
    text += ",";
    text += ::std::to_string(y);
    text += ",";
    text += ::std::to_string(z);
    text += ",";
    text += ::std::to_string(t);
    text += ")";
    return text;
  }

  int x = 0;
  int y = 0;
  int z = 0;
  int t = 0;
};
/**
 * @brief Extents of a space with up to four dimensions, stored in the x, y, z and t members inherited from pos4.
 *
 * Extents omitted from the explicit constructor are 1. A default-constructed dim4 keeps pos4's
 * default member initializers instead, so its four extents are 0.
 */
class dim4 : public pos4
{
public:
  dim4() = default;

  // Build extents from x (mandatory) and optionally y, z, t, which default to 1
  _CCCL_HOST_DEVICE constexpr explicit dim4(int x, int y = 1, int z = 1, int t = 1)
      : pos4(x, y, z, t)
  {}

  // TODO: could coords ever be negative? (if not, maybe they should be unsigned).
  // Total number of elements, i.e. x * y * z * t, computed in ptrdiff_t. A negative product trips the assertion.
  _CCCL_HOST_DEVICE constexpr size_t size() const
  {
    const ::std::ptrdiff_t result = ::std::ptrdiff_t(x) * y * z * t;
    assert(result >= 0);
    return static_cast<size_t>(result);
  }

  // Component-wise minimum of two dim4. Written with conditionals rather than ::std::min so that
  // it does not depend on <algorithm>, which this header does not include itself.
  _CCCL_HOST_DEVICE static constexpr dim4 min(const dim4& a, const dim4& b)
  {
    return dim4(b.x < a.x ? b.x : a.x, b.y < a.y ? b.y : a.y, b.z < a.z ? b.z : a.z, b.t < a.t ? b.t : a.t);
  }

  // Linearized index of position p within these extents: x varies fastest, then y, then z, then t.
  // NOTE(review): the checks below accept a coordinate equal to the extent (<=) - confirm this is intended.
  _CCCL_HOST_DEVICE constexpr size_t get_index(const pos4& p) const
  {
    assert(p.get(0) <= x);
    assert(p.get(1) <= y);
    assert(p.get(2) <= z);
    assert(p.get(3) <= t);

    // Widen every operand to size_t before multiplying: evaluating the expression in int would
    // overflow (undefined behavior) as soon as the extents describe more than INT_MAX elements.
    const size_t px = static_cast<size_t>(p.get(0));
    const size_t py = static_cast<size_t>(p.get(1));
    const size_t pz = static_cast<size_t>(p.get(2));
    const size_t pt = static_cast<size_t>(p.get(3));

    const size_t ex = static_cast<size_t>(x);
    const size_t ey = static_cast<size_t>(y);
    const size_t ez = static_cast<size_t>(z);

    return px + ex * (py + ey * (pz + pt * ez));
  }

  // Index of the highest dimension whose extent is greater than 1 (3 for t, 2 for z, 1 for y), 0 otherwise
  _CCCL_HOST_DEVICE constexpr size_t get_rank() const
  {
    if (t > 1)
    {
      return 3;
    }
    if (z > 1)
    {
      return 2;
    }
    if (y > 1)
    {
      return 1;
    }
    return 0;
  }
};
/**
 * @brief A box with `dimensions` dimensions: for every dimension it stores a pair of signed bounds
 * (begin, end). Iteration and extents treat each pair as the half-open interval [begin, end).
 */
template <size_t dimensions>
class box
{
public:
  // Build a box from one (begin, end) pair per dimension
  template <typename Int1, typename Int2>
  _CCCL_HOST_DEVICE box(const ::std::array<::std::pair<Int1, Int2>, dimensions>& s)
  {
    // Convert element by element: an array of pair<Int1, Int2> cannot initialize the array of
    // pair<ptrdiff_t, ptrdiff_t> directly unless both integral types already are ptrdiff_t.
    for (size_t ind = 0; ind < dimensions; ind++)
    {
      this->s[ind].first  = static_cast<::std::ptrdiff_t>(s[ind].first);
      this->s[ind].second = static_cast<::std::ptrdiff_t>(s[ind].second);
    }
  }

  // Build a box from one size per dimension: dimension `ind` spans [0, sizes[ind])
  template <typename Int>
  _CCCL_HOST_DEVICE box(const ::std::array<Int, dimensions>& sizes)
  {
    for (size_t ind : each(0, dimensions))
    {
      s[ind].first  = 0;
      s[ind].second = sizes[ind];
      if constexpr (::std::is_signed_v<Int>)
      {
        _CCCL_ASSERT(sizes[ind] >= 0, "Invalid shape.");
      }
    }
  }

  // Build a box from exactly `dimensions` arguments. An arithmetic argument `e` describes [0, e),
  // any other argument is expected to expose `.first` and `.second` and describes [first, second).
  template <typename... Int>
  _CCCL_HOST_DEVICE box(Int... args)
  {
    static_assert(sizeof...(Int) == dimensions, "Number of dimensions must match");
    each_in_pack(
      [&](auto i, const auto& e) {
        if constexpr (::std::is_arithmetic_v<::std::remove_reference_t<decltype(e)>>)
        {
          s[i].first  = 0;
          s[i].second = e;
        }
        else
        {
          // Assume a pair
          s[i].first  = e.first;
          s[i].second = e.second;
        }
      },
      args...);
  }

  // Build a box from exactly `dimensions` braced lists. A list {e} describes [0, e) and a list
  // {b, e} describes [b, e); any other list length trips the assertion.
  template <typename... E>
  _CCCL_HOST_DEVICE box(::std::initializer_list<E>... args)
  {
    static_assert(sizeof...(E) == dimensions, "Number of dimensions must match");
    each_in_pack(
      [&](auto i, auto&& e) {
        _CCCL_ASSERT((e.size() == 1 || e.size() == 2), "Invalid arguments for box.");
        if (e.size() > 1)
        {
          s[i].first  = *e.begin();
          s[i].second = e.begin()[1];
        }
        else
        {
          s[i].first  = 0;
          s[i].second = *e.begin();
        }
      },
      args...);
  }

  // _CCCL_HOST_DEVICE box(const typename ::std::experimental::dextents<size_t, dimensions>& extents) {
  //     for (size_t i: each(0, dimensions)) {
  //         s[i].first = 0;
  //         s[i].second = extents[ind];
  //     }
  // }

  // Print the bounds of every dimension (debugging aid)
  _CCCL_HOST_DEVICE void print() const
  {
    printf("EXPLICIT SHAPE\n");
    for (size_t ind = 0; ind < dimensions; ind++)
    {
      assert(s[ind].first <= s[ind].second);
      // ptrdiff_t is not `long` on every platform: cast to long long so that the arguments
      // always match the conversion specifiers.
      printf(" %lld -> %lld\n", static_cast<long long>(s[ind].first), static_cast<long long>(s[ind].second));
    }
  }

  // Number of positions along dimension `dim` (end - begin)
  _CCCL_HOST_DEVICE ::std::ptrdiff_t get_extent(size_t dim) const
  {
    return s[dim].second - s[dim].first;
  }

  // First position (included) along dimension `dim`
  _CCCL_HOST_DEVICE ::std::ptrdiff_t get_begin(size_t dim) const
  {
    return s[dim].first;
  }

  // Last position (excluded) along dimension `dim`
  _CCCL_HOST_DEVICE ::std::ptrdiff_t get_end(size_t dim) const
  {
    return s[dim].second;
  }

  // Total number of positions in the box: the product of the extents of all dimensions
  _CCCL_HOST_DEVICE ::std::ptrdiff_t size() const
  {
    // Accumulate in the (signed) return type rather than mixing size_t and ptrdiff_t
    ::std::ptrdiff_t res = 1;
    for (size_t d = 0; d < dimensions; d++)
    {
      res *= get_extent(d);
    }
    return res;
  }

  // Number of dimensions of the box
  _CCCL_HOST_DEVICE constexpr size_t get_rank() const
  {
    return dimensions;
  }

  // Iterator class for box: visits every position, dimension 0 varying fastest
  class iterator
  {
  private:
    box iterated; // A copy of the box being iterated
    ::std::array<::std::ptrdiff_t, dimensions> current; // Array to store the current position in each dimension

  public:
    // Iterator on the first position of `b`, or on its past-the-end position if `at_end` is set
    _CCCL_HOST_DEVICE iterator(const box& b, bool at_end = false)
        : iterated(b)
    {
      // A box with at least one empty dimension contains no position at all. Its begin iterator
      // must therefore be the end iterator, otherwise incrementing it never reaches end().
      if (!at_end)
      {
        for (size_t i = 0; i < dimensions; ++i)
        {
          if (iterated.get_begin(i) >= iterated.get_end(i))
          {
            at_end = true;
            break;
          }
        }
      }

      for (size_t i = 0; i < dimensions; ++i)
      {
        current[i] = at_end ? iterated.get_end(i) : iterated.get_begin(i);
      }
    }

    // Overload the dereference operator to get the current position: a single index for a
    // 1-dimensional box, the array of indices otherwise
    _CCCL_HOST_DEVICE auto& operator*()
    {
      if constexpr (dimensions == 1UL)
      {
        return current[0];
      }
      else
      {
        return current;
      }
    }

    // Overload the pre-increment operator to move to the next position
    _CCCL_HOST_DEVICE iterator& operator++()
    {
      if constexpr (dimensions == 1UL)
      {
        current[0]++;
      }
      else
      {
        // Increment current with carry to next dimension. When every dimension carries, all
        // entries are left at their end value, which is the past-the-end position.
        for (size_t i : each(0, dimensions))
        {
          _CCCL_ASSERT(current[i] < iterated.get_end(i), "Attempt to increment past the end.");
          if (++current[i] < iterated.get_end(i))
          {
            // Found the new position, now reset all lower dimensions to their first index
            for (size_t j : each(0, i))
            {
              current[j] = iterated.get_begin(j);
            }
            break;
          }
        }
      }
      return *this;
    }

    // Overload the equality operator to check if two iterators are equal. Both iterators must
    // refer to equal boxes.
    _CCCL_HOST_DEVICE bool operator==(const iterator& rhs) const
    {
      _CCCL_ASSERT(iterated == rhs.iterated, "Cannot compare iterators in different boxes.");
      for (auto i : each(0, dimensions))
      {
        if (current[i] != rhs.current[i])
        {
          return false;
        }
      }
      return true;
    }

    // Overload the inequality operator to check if two iterators are not equal
    _CCCL_HOST_DEVICE bool operator!=(const iterator& other) const
    {
      return !(*this == other);
    }
  };

  // Functions to create the begin and end iterators. They are const so that a const box can be
  // iterated too; the iterator works on its own copy of the box.
  _CCCL_HOST_DEVICE iterator begin() const
  {
    return iterator(*this);
  }

  _CCCL_HOST_DEVICE iterator end() const
  {
    return iterator(*this, true);
  }

  // Overload the equality operator to check if two shapes are equal
  _CCCL_HOST_DEVICE bool operator==(const box& rhs) const
  {
    for (size_t i : each(0, dimensions))
    {
      if (get_begin(i) != rhs.get_begin(i) || get_end(i) != rhs.get_end(i))
      {
        return false;
      }
    }
    return true;
  }

  _CCCL_HOST_DEVICE bool operator!=(const box& rhs) const
  {
    return !(*this == rhs);
  }

  using coords_t = array_tuple<size_t, dimensions>;

  // This transforms a tuple of (shape, 1D index) into a coordinate.
  // The lambda consumes `index` one dimension at a time, so the result depends on it being
  // invoked for i = 0, 1, ... in increasing order (presumably what make_tuple_indexwise does -
  // TODO confirm).
  _CCCL_HOST_DEVICE coords_t index_to_coords(size_t index) const
  {
    // Help the compiler which may not detect that a device lambda is calling a device lambda
    CUDASTF_NO_DEVICE_STACK
    return make_tuple_indexwise<dimensions>([&](auto i) {
      // included
      const ::std::ptrdiff_t begin_i  = get_begin(i);
      const ::std::ptrdiff_t extent_i = get_extent(i);
      auto result                     = begin_i + (index % extent_i);
      index /= extent_i;
      return result;
    });
    CUDASTF_NO_DEVICE_STACK
  }

private:
  // (begin, end) bounds of every dimension
  ::std::array<::std::pair<::std::ptrdiff_t, ::std::ptrdiff_t>, dimensions> s;
};
// Deduction guides
// box(a, b, c, ...): one dimension per argument
template <typename... Args>
box(Args...) -> box<sizeof...(Args)>;
// box({...}, {...}, ...): one dimension per braced list
template <typename... Elem>
box(::std::initializer_list<Elem>...) -> box<sizeof...(Elem)>;
// box(array): one dimension per array entry
template <typename Elem, size_t rank>
box(::std::array<Elem, rank>) -> box<rank>;
#ifdef UNITTESTED_FILE
UNITTEST("box<3>")
{
  // The box spans {0, 1, 2} x {1, 2} x {10, 11, 12, 13}, so a traversal must visit 3*2*4 = 24 positions
  auto shape = box({0, 3}, {1, 3}, {10, 14});
  static_assert(::std::is_same_v<decltype(shape), box<3>>);

  const size_t expected_cnt = 24;
  size_t cnt                = 0;
  for (auto it = shape.begin(); it != shape.end(); ++it)
  {
    [[maybe_unused]] const auto& pos = *it;
    // Never visit more positions than the box contains
    EXPECT(cnt < expected_cnt);
    ++cnt;
  }

  // ... and never fewer either
  EXPECT(cnt == expected_cnt);
};
UNITTEST("box<3> upper")
{
  // Only upper bounds are given: the box spans {0, 1, 2} x {0, 1} x {0, 1, 2, 3}, i.e. 3*2*4 = 24 positions
  auto shape = box(3, 2, 4);
  static_assert(::std::is_same_v<decltype(shape), box<3>>);

  size_t cnt                = 0;
  const size_t expected_cnt = 24;
  for ([[maybe_unused]] const auto& pos : shape)
  {
    // Never visit more positions than the box contains
    EXPECT(cnt < expected_cnt);
    cnt += 1;
  }

  // ... and never fewer either
  EXPECT(cnt == expected_cnt);
};
UNITTEST("empty box<1>")
{
  // [7, 7) does not contain any position
  auto shape = box({7, 7});
  static_assert(::std::is_same_v<decltype(shape), box<1>>);

  // Both ends of an empty range must compare equal
  auto it_begin = shape.begin();
  auto it_end   = shape.end();
  if (!(it_end == it_begin))
  {
    fprintf(stderr, "Error: begin() != end()\n");
    abort();
  }

  // Iterating over the empty range must not execute the loop body at all
  for ([[maybe_unused]] const auto& pos : shape)
  {
    abort();
  }
};
UNITTEST("mix of integrals and pairs")
{
  // Plain integers give an upper bound, the pair gives both bounds: 3 * (2 - 1) * 4 = 12 positions
  auto shape = box(3, ::std::pair(1, 2), 4);
  static_assert(::std::is_same_v<decltype(shape), box<3>>);

  const size_t expected_cnt = 12;
  size_t cnt                = 0;
  for (auto it = shape.begin(); it != shape.end(); ++it)
  {
    [[maybe_unused]] const auto& pos = *it;
    // Never visit more positions than the box contains
    EXPECT(cnt < expected_cnt);
    ++cnt;
  }

  // ... and never fewer either
  EXPECT(cnt == expected_cnt);
};
#endif // UNITTESTED_FILE
// So that we can create unordered_map of pos4 entries
template <>
struct hash<pos4>
{
::std::size_t operator()(pos4 const& s) const noexcept
{
return hash_all(s.x, s.y, s.z, s.t);
}
};
// So that we can create maps of dim4 entries.
// A dim4 is hashed exactly like a pos4: the call operator is inherited from hash<pos4>,
// which a dim4 argument binds to through its pos4 base class.
template <>
struct hash<dim4> : hash<pos4>
{};
} // end namespace cuda::experimental::stf