// thrust/iterator/strided_iterator.h
// File members: thrust/iterator/strided_iterator.h
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA Corporation
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <thrust/detail/config.h>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_traits.h>
THRUST_NAMESPACE_BEGIN
//! \brief Stride holder whose stride is an ordinary data member.
//!
//! This is a plain aggregate, so it can be brace-initialized from a single value, e.g. `runtime_value<int>{3}`.
//!
//! \tparam T Type of the stored value.
template <class T>
struct runtime_value
{
  //! The stored value.
  T value;
};
//! \brief Stride holder whose value is fixed at compile time and exposed as a static member.
//!
//! ::cuda::std::integral_constant is deliberately not used here: its conversion operator to T causes an ambiguity
//! with operator+(counting_iterator, counting_iterator::difference_type) in any expression of the form
//! `counting_iterator + integral`.
//!
//! \tparam Value The compile-time constant to expose.
template <auto Value>
struct compile_time_value
{
  //! The constant passed as the template argument.
  static constexpr auto value = Value;
};
namespace detail
{
//! Trait that is `true` exactly when its argument is a specialization of compile_time_value.
//! Primary template: any other type yields `false`.
template <class>
inline constexpr bool is_compile_time_value = false;

//! Partial specialization matching `compile_time_value<V>` for any constant `V`.
template <auto V>
inline constexpr bool is_compile_time_value<compile_time_value<V>> = true;
} // namespace detail
//! \brief Iterator adaptor that moves an underlying random access iterator in steps of a fixed stride.
//!
//! Incrementing, decrementing, or advancing a strided_iterator by `n` moves the wrapped iterator by `stride()`,
//! `-stride()`, or `n * stride()` positions respectively. The stride is supplied by \p StrideHolder, which must
//! expose a member `value` convertible to the iterator's `difference_type` (enforced by a static_assert).
//!
//! \tparam RandomAccessIterator The wrapped iterator type; it must be a random access iterator.
//! \tparam StrideHolder Type that carries the stride in a member named `value`. It is used as a (private) base
//!         class of the iterator.
template <typename RandomAccessIterator, typename StrideHolder>
class _CCCL_DECLSPEC_EMPTY_BASES strided_iterator
    : public iterator_adaptor<strided_iterator<RandomAccessIterator, StrideHolder>, RandomAccessIterator>
    , StrideHolder
{
  using super_t = iterator_adaptor<strided_iterator, RandomAccessIterator>;

  // The traversal primitives below are private; this friend declaration lets iterator_core_access reach them.
  friend class iterator_core_access;

public:
  using difference_type = typename super_t::difference_type;

  static_assert(::cuda::std::random_access_iterator<RandomAccessIterator>,
                "The iterator underlying a strided_iterator must be a random access iterator.");
  static_assert(::cuda::std::is_same_v<iterator_traversal_t<RandomAccessIterator>, random_access_traversal_tag>);
  static_assert(::cuda::std::is_convertible_v<decltype(StrideHolder::value), difference_type>,
                "The stride must be convertible to the iterator's difference_type");

  strided_iterator() = default;

  //! \brief Wraps \p it and stores \p stride as the stride holder.
  //! \param it The underlying iterator marking the current position.
  //! \param stride The stride holder; value-initialized when omitted.
  _CCCL_HOST_DEVICE strided_iterator(RandomAccessIterator it, StrideHolder stride = {})
      : super_t(it)
      , StrideHolder(stride)
  {}

  //! `true` when \p StrideHolder is a specialization of compile_time_value.
  static constexpr bool has_static_stride = detail::is_compile_time_value<StrideHolder>;

  //! \brief Returns a reference to the stride holder base subobject.
  _CCCL_HOST_DEVICE const auto& stride_holder() const
  {
    return static_cast<const StrideHolder&>(*this);
  }

  //! \brief Returns the stride converted to `difference_type`.
  _CCCL_HOST_DEVICE auto stride() const -> difference_type
  {
    return static_cast<difference_type>(stride_holder().value);
  }

private:
  //! Moves the underlying iterator by `n` strides.
  _CCCL_EXEC_CHECK_DISABLE
  _CCCL_HOST_DEVICE void advance(difference_type n)
  {
    this->base_reference() += n * stride();
  }

  //! Moves the underlying iterator forward by one stride.
  _CCCL_EXEC_CHECK_DISABLE
  _CCCL_HOST_DEVICE void increment()
  {
    this->base_reference() += stride();
  }

  //! Moves the underlying iterator backward by one stride.
  _CCCL_EXEC_CHECK_DISABLE
  _CCCL_HOST_DEVICE void decrement()
  {
    this->base_reference() -= stride();
  }

  //! Two strided iterators over the same underlying iterator type compare equal when their underlying iterators
  //! compare equal; the strides themselves are not compared, and the holder types may differ.
  // Execution-space checking is disabled because this calls the underlying iterator's operator==, exactly like
  // the traversal primitives above call its compound assignment operators.
  _CCCL_EXEC_CHECK_DISABLE
  template <typename OtherStrideHolder>
  _CCCL_HOST_DEVICE bool equal(strided_iterator<RandomAccessIterator, OtherStrideHolder> const& other) const
  {
    return this->base() == other.base();
  }

  //! Returns the number of strides from `*this` to \p other, i.e. the difference of the underlying iterators
  //! divided by the stride. The stride must be non-zero and must divide that difference evenly.
  // Execution-space checking is disabled because this calls the underlying iterator's operator-.
  _CCCL_EXEC_CHECK_DISABLE
  _CCCL_HOST_DEVICE difference_type distance_to(strided_iterator const& other) const
  {
    const difference_type step = stride();
    // Guard the modulo and division below against a zero stride.
    _CCCL_ASSERT(step != 0, "The stride must be non-zero to compute a distance between strided iterators");
    const difference_type dist = other.base() - this->base();
    _CCCL_ASSERT(dist % step == 0, "Underlying iterator difference must be divisible by the stride");
    return dist / step;
  }
};
//! \brief Creates a strided_iterator whose stride is stored as a runtime value.
//! \param it The underlying iterator to wrap.
//! \param stride The stride, stored in a `runtime_value<Stride>` holder.
//! \return A `strided_iterator<Iterator, runtime_value<Stride>>` positioned at \p it.
template <typename Iterator, typename Stride>
_CCCL_HOST_DEVICE auto make_strided_iterator(Iterator it, Stride stride)
{
  using holder_t   = runtime_value<Stride>;
  using iterator_t = strided_iterator<Iterator, holder_t>;
  return iterator_t(it, holder_t{stride});
}
//! \brief Creates a strided_iterator whose stride is the compile-time constant \p Stride.
//! \tparam Stride The stride, carried by a `compile_time_value<Stride>` holder.
//! \param it The underlying iterator to wrap.
//! \return A `strided_iterator<Iterator, compile_time_value<Stride>>` positioned at \p it.
template <auto Stride, typename Iterator>
_CCCL_HOST_DEVICE auto make_strided_iterator(Iterator it)
{
  using holder_t   = compile_time_value<Stride>;
  using iterator_t = strided_iterator<Iterator, holder_t>;
  return iterator_t(it, holder_t{});
}
THRUST_NAMESPACE_END