thrust/iterator/tabulate_output_iterator.h

File members: thrust/iterator/tabulate_output_iterator.h

// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <thrust/detail/config.h>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/tabulate_output_iterator.h>

THRUST_NAMESPACE_BEGIN

template <typename BinaryFunction, typename System, typename DifferenceT>
class tabulate_output_iterator;

namespace detail
{
// Proxy reference that invokes a BinaryFunction with the index of the dereferenced iterator and the assigned value
template <typename BinaryFunction, typename DifferenceT>
class tabulate_output_iterator_proxy
{
public:
  _CCCL_HOST_DEVICE tabulate_output_iterator_proxy(BinaryFunction fun, DifferenceT index)
      : fun(fun)
      , index(index)
  {}

  _CCCL_EXEC_CHECK_DISABLE
  template <typename T>
  _CCCL_HOST_DEVICE tabulate_output_iterator_proxy operator=(const T& x)
  {
    fun(index, x);
    return *this;
  }

private:
  BinaryFunction fun;
  DifferenceT index;
};

// Alias template for the iterator_adaptor instantiation to be used for tabulate_output_iterator
template <typename BinaryFunction, typename System, typename DifferenceT>
using make_tabulate_output_iterator_base =
  iterator_adaptor<tabulate_output_iterator<BinaryFunction, System, DifferenceT>,
                   counting_iterator<DifferenceT>,
                   void,
                   System,
                   use_default,
                   tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>>;

// Register tabulate_output_iterator_proxy with 'is_proxy_reference' from
// type_traits to enable its use with algorithms.
template <class BinaryFunction, class OutputIterator>
inline constexpr bool is_proxy_reference_v<tabulate_output_iterator_proxy<BinaryFunction, OutputIterator>> = true;
} // namespace detail

template <typename BinaryFunction, typename System = use_default, typename DifferenceT = ptrdiff_t>
class tabulate_output_iterator : public detail::make_tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>
{
public:
  using super_t = detail::make_tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>;

  friend class iterator_core_access;

  tabulate_output_iterator() = default;

  _CCCL_HOST_DEVICE tabulate_output_iterator(BinaryFunction fun)
      : fun(fun)
  {}

private:
  _CCCL_HOST_DEVICE typename super_t::reference dereference() const
  {
    return detail::tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>{fun, *this->base()};
  }

  BinaryFunction fun;
};

template <typename BinaryFunction>
tabulate_output_iterator<BinaryFunction> _CCCL_HOST_DEVICE make_tabulate_output_iterator(BinaryFunction fun)
{
  return tabulate_output_iterator<BinaryFunction>(fun);
}

THRUST_NAMESPACE_END