thrust/iterator/shuffle_iterator.h
File members: thrust/iterator/shuffle_iterator.h
/*
* Copyright 2025 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <thrust/detail/config.h>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <thrust/detail/random_bijection.h>
#include <thrust/detail/type_traits.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_traits.h>

#include <cuda/std/cassert>
#include <cuda/std/type_traits>
#include <cuda/std/utility>
THRUST_NAMESPACE_BEGIN
template <class Index, class Bijection>
class shuffle_iterator;

namespace detail
{
// Type-level helper that names the iterator_adaptor specialization shuffle_iterator derives from.
// The adaptor's "base" is the integral position itself (not another iterator); dereference is
// customized by shuffle_iterator, so value and reference are both plain Index values.
template <class Index, class Bijection>
struct make_shuffle_iterator_base
{
  static_assert(::cuda::std::is_integral_v<Index>, "IndexType must be an integral type");

  // The iterator is usable from any backend and supports random-access traversal.
  using system    = any_system_tag;
  using traversal = random_access_traversal_tag;

  // Index types narrower than int get an int difference type; all others use ptrdiff_t.
  using difference = ::cuda::std::conditional_t<(sizeof(Index) < sizeof(int)), int, ::cuda::std::ptrdiff_t>;

  using type = iterator_adaptor<shuffle_iterator<Index, Bijection>, // Derived
                                Index, // Base (the current position)
                                Index, // Value
                                system, // System
                                traversal, // Traversal
                                Index, // Reference (returned by value)
                                difference>; // Difference
};
} // namespace detail
//! \brief Random-access iterator whose element at position \c i is \c bijection(i).
//!
//! The iterator starts at position 0 and advances like a counting iterator over \c IndexType.
//! Dereferencing maps the current position through \c BijectionFunc, which by default is
//! \c thrust::detail::random_bijection<IndexType>. With that default, the iterator presumably
//! enumerates a shuffled ordering of <tt>[0, bijection.size())</tt>; confirm this against
//! random_bijection's contract.
//!
//! \tparam IndexType     Integral type used for positions and produced values.
//! \tparam BijectionFunc Callable <tt>IndexType -> IndexType</tt> that also provides \c size().
template <class IndexType, class BijectionFunc = thrust::detail::random_bijection<IndexType>>
class shuffle_iterator : public detail::make_shuffle_iterator_base<IndexType, BijectionFunc>::type
{
  using super_t = typename detail::make_shuffle_iterator_base<IndexType, BijectionFunc>::type;
  friend class iterator_core_access;

public:
  //! Builds the bijection in place from the range size \p n and the generator \p g.
  //! This overload participates only when \c BijectionFunc is constructible from
  //! <tt>(IndexType, URBG&&)</tt>.
  template <class URBG,
            class Enable = ::cuda::std::enable_if_t<::cuda::std::is_constructible_v<BijectionFunc, IndexType, URBG&&>>>
  _CCCL_HOST_DEVICE shuffle_iterator(IndexType n, URBG&& g)
      : super_t(IndexType{0})
      , bijection(n, ::cuda::std::forward<URBG>(g))
  {}

  //! Adopts an already-constructed bijection \p func and starts at position 0.
  _CCCL_HOST_DEVICE shuffle_iterator(BijectionFunc func)
      : super_t(IndexType{0})
      // ::cuda::std::move (not std::move) so this host/device constructor does not depend on a
      // host-only <utility> function; this matches the ::cuda::std::forward used above.
      , bijection(::cuda::std::move(func))
  {}

private:
  //! Returns the bijection applied to the current position. The position is asserted to lie
  //! below \c bijection.size().
  _CCCL_HOST_DEVICE IndexType dereference() const
  {
    assert(this->base() < bijection.size());
    return bijection(this->base());
  }

  // Maps positions to output indices; copied along with the iterator.
  BijectionFunc bijection;
};
//! \brief Convenience factory for a shuffle_iterator that uses the default bijection.
//!
//! Equivalent to <tt>shuffle_iterator<IndexType>(n, g)</tt>. \c IndexType is deduced from \p n,
//! and the generator is perfectly forwarded to the iterator's constructor.
//!
//! \param n   Range size passed to the bijection's constructor.
//! \param rng Random bit generator forwarded to the bijection's constructor.
//! \return    An iterator positioned at index 0.
template <class IndexType, class URBG>
_CCCL_HOST_DEVICE shuffle_iterator<IndexType> make_shuffle_iterator(IndexType n, URBG&& rng)
{
  using result_iterator = shuffle_iterator<IndexType>;
  return result_iterator(n, ::cuda::std::forward<URBG>(rng));
} // end make_shuffle_iterator
THRUST_NAMESPACE_END