// thrust/iterator/shuffle_iterator.h
//
// File members: thrust/iterator/shuffle_iterator.h

/*
 *  Copyright 2025 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <thrust/detail/random_bijection.h>
#include <thrust/detail/type_traits.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_traits.h>

#include <cuda/std/cassert>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

THRUST_NAMESPACE_BEGIN

// Forward declaration: make_shuffle_iterator_base (below) names shuffle_iterator
// as the derived type of its iterator_adaptor before the class is defined.
// The default for BijectionFunc is supplied on the definition, not here.
template <typename IndexType, typename BijectionFunc>
class shuffle_iterator;

namespace detail
{
//! Metafunction producing the \p iterator_adaptor specialization that
//! \p shuffle_iterator<IndexType, BijectionFunc> derives from.
//!
//! The adapted base (the iterator position) is an \p IndexType, and the
//! iterator yields \p IndexType values by value. It is tagged with
//! \p any_system_tag and random-access traversal.
template <class IndexType, class BijectionFunc>
struct make_shuffle_iterator_base
{
  static_assert(::cuda::std::is_integral_v<IndexType>, "IndexType must be an integral type");

  // Not bound to a specific backend system.
  using system = any_system_tag;

  using traversal = random_access_traversal_tag;

  // Index types narrower than int use an int difference; all others use ptrdiff_t.
  using difference =
    ::cuda::std::conditional_t<(sizeof(IndexType) < sizeof(int)), int, ::cuda::std::ptrdiff_t>;

  using type = iterator_adaptor<shuffle_iterator<IndexType, BijectionFunc>, // derived iterator
                                IndexType, // base: the current position
                                IndexType, // value type
                                system,
                                traversal,
                                IndexType, // reference: produced by value
                                difference>;
};
} // namespace detail

/*! \p shuffle_iterator is a random-access iterator that owns no storage.
 *
 *  Its position starts at zero, and the value at position \c i is
 *  <tt>bijection(i)</tt>. The iterator stores:
 *  - a copy of the bijection functor, and
 *  - the current \p IndexType position (held by the \p iterator_adaptor base).
 *
 *  Dereferencing at a position that is not less than <tt>bijection.size()</tt>
 *  trips an \c assert when assertions are enabled.
 *
 *  With the default \p BijectionFunc (<tt>thrust::detail::random_bijection</tt>),
 *  the produced sequence is presumably a pseudo-random permutation of the
 *  indices below \c n, determined by the engine given at construction.
 *  Confirm the exact guarantees against that class.
 *
 *  \tparam IndexType     Integral type used for positions and produced values.
 *  \tparam BijectionFunc Mapping from position to value. Must be callable as
 *                        <tt>bijection(IndexType) const</tt> and provide <tt>size() const</tt>.
 */
template <class IndexType, class BijectionFunc = thrust::detail::random_bijection<IndexType>>
class shuffle_iterator : public detail::make_shuffle_iterator_base<IndexType, BijectionFunc>::type
{
  using super_t = typename detail::make_shuffle_iterator_base<IndexType, BijectionFunc>::type;
  // Grants the iterator-facade access hook use of the private dereference() below.
  friend class iterator_core_access;

public:
  /*! Constructs the bijection from <tt>(n, g)</tt> and starts at position 0.
   *
   *  Participates in overload resolution only if \p BijectionFunc is
   *  constructible from <tt>(IndexType, URBG&&)</tt>.
   *
   *  \param n Size argument forwarded to the bijection's constructor.
   *  \param g Random engine, perfectly forwarded to the bijection's constructor.
   */
  template <class URBG,
            class Enable = ::cuda::std::enable_if_t<::cuda::std::is_constructible_v<BijectionFunc, IndexType, URBG&&>>>
  _CCCL_HOST_DEVICE shuffle_iterator(IndexType n, URBG&& g)
      : super_t(IndexType{0})
      , bijection(n, ::cuda::std::forward<URBG>(g))
  {}

  /*! Starts at position 0 using an already-constructed bijection.
   *
   *  \param bijection Mapping to copy/move into the iterator.
   */
  _CCCL_HOST_DEVICE shuffle_iterator(BijectionFunc bijection)
      : super_t(IndexType{0})
      // Use cuda::std::move rather than std::move: it is usable from device code,
      // it matches the cuda::std::forward above, and <utility> is not included here.
      , bijection(::cuda::std::move(bijection))
  {}

private:
  // Returns the value at the current position by mapping it through the bijection.
  _CCCL_HOST_DEVICE IndexType dereference() const
  {
    assert(this->base() < bijection.size());
    // Explicit cast: a user-supplied bijection may return a wider integer type,
    // so make the conversion to the iterator's reference type visible.
    return static_cast<IndexType>(bijection(this->base()));
  }

  // Position -> value mapping. The position itself lives in super_t.
  BijectionFunc bijection;
};

//! Convenience factory for a \p shuffle_iterator that uses the default bijection.
//!
//! \p IndexType is deduced from \p n. The result is constructed from
//! <tt>(n, g)</tt>, with \p g perfectly forwarded.
//!
//! \param n Size argument passed to the iterator's bijection.
//! \param g Random engine forwarded to the iterator's bijection.
//! \return A \p shuffle_iterator<IndexType> positioned at 0.
template <class IndexType, class RandomEngine>
_CCCL_HOST_DEVICE auto make_shuffle_iterator(IndexType n, RandomEngine&& g) -> shuffle_iterator<IndexType>
{
  using result_iterator = shuffle_iterator<IndexType>;
  return result_iterator(n, ::cuda::std::forward<RandomEngine>(g));
}

THRUST_NAMESPACE_END