include/cuda/experimental/__stf/utility/handle.cuh

File members: include/cuda/experimental/__stf/utility/handle.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/experimental/__stf/utility/traits.cuh>
#include <cuda/experimental/__stf/utility/unittest.cuh>

#include <memory> // for ::std::shared_ptr

namespace cuda::experimental::stf::reserved
{

enum handle_flags : unsigned
{
  defaults,
  non_null,
};

inline enum class use_static_cast {} use_static_cast = {};

inline enum class use_dynamic_cast {} use_dynamic_cast = {};

constexpr handle_flags operator|(handle_flags a, handle_flags b)
{
  return handle_flags(static_cast<unsigned>(a) | static_cast<unsigned>(b));
}

constexpr handle_flags operator&(handle_flags a, handle_flags b)
{
  return handle_flags(static_cast<unsigned>(a) & static_cast<unsigned>(b));
}

template <typename T, handle_flags f = handle_flags::defaults>
class handle
{
public:
  handle(handle&)                  = default;
  handle(const handle&)            = default;
  handle(handle&&)                 = default;
  handle& operator=(handle&)       = default;
  handle& operator=(const handle&) = default;
  handle& operator=(handle&&)      = default;

  handle()
  {
    static_assert(!::std::is_constructible_v<T>, "T's default constructor must be protected.");
    if constexpr (f & handle_flags::non_null)
    {
      static_assert(!::std::is_abstract_v<T>,
                    "A non-nullable handle of an abstract type cannot have a default constructor.");
      impl = ::std::make_shared<Derived<T>>();
    }
  }

  template <typename T1, handle_flags f1>
  handle(handle<T1, f1> rhs)
      : impl(mv(rhs.impl))
  {
    if constexpr (f & handle_flags::non_null)
    {
      static_assert(f1 & handle_flags::non_null, "Cannot initialize a non-nullable handle from a nullable one.");
    }
  }

  template <typename... Args>
  handle(Args&&... args)
      : impl(make(::std::forward<Args>(args)...))
  {
    static_assert(!::std::is_constructible_v<T, Args...>, "T's constructors must be protected.");
  }

  template <typename T1, handle_flags f1>
  handle(handle<T1, f1>& src, decltype(use_static_cast))
      : handle(const_cast<const handle<T1, f1>&>(src), use_static_cast)
  {}

  template <typename T1, handle_flags f1>
  handle(const handle<T1, f1>& src, decltype(use_static_cast))
      : handle(::std::static_pointer_cast<T>(src.impl))
  {
    if constexpr (f & handle_flags::non_null)
    {
      EXPECT(src.impl, "Pointer of static type ", type_name<T1>, " was null upon construction of non-null handle.");
      assert(impl);
    }
  }

  template <typename T1, handle_flags f1>
  handle(handle<T1, f1>& src, decltype(use_dynamic_cast))
      : handle(const_cast<const handle<T1, f1>&>(src), use_dynamic_cast)
  {}

  template <typename T1, handle_flags f1>
  handle(const handle<T1, f1>& src, decltype(use_dynamic_cast))
      : handle(::std::dynamic_pointer_cast<T>(src.impl))
  {
    if constexpr (f & handle_flags::non_null)
    {
      EXPECT(src.impl, "Pointer of static type ", type_name<T1>, " was null upon construction of non-null handle.");
      EXPECT(impl, "dynamic_cast<", type_name<T>, "> failed for pointer of static type ", type_name<T1>);
    }
  }

  template <typename T1>
  handle(const ::std::shared_ptr<T1>& src, decltype(use_dynamic_cast))
      : handle(::std::dynamic_pointer_cast<T>(src))
  {}

  template <typename T1, handle_flags f1>
  handle& operator=(handle<T1, f1> rhs)
  {
    if constexpr (f & handle_flags::non_null)
    {
      static_assert(f1 & handle_flags::non_null, "Cannot assign a non-nullable handle from a nullable one.");
    }
    impl = mv(rhs.impl);
    return *this;
  }

  ~handle() = default;

  T* operator->() const
  {
    assert(*this);
    return impl.get();
  }

  T& operator*()
  {
    return *operator->();
  }

  const T& operator*() const
  {
    return *operator->();
  }

  explicit operator bool() const
  {
    return impl.get() != nullptr;
  }

  template <typename T1>
  operator ::std::shared_ptr<T1>() const
  {
    return impl;
  }

  template <typename T1, handle_flags f1>
  bool operator==(const handle<T1, f1>& rhs)
  {
    return impl == rhs.impl;
  }

  using weak_t = ::std::weak_ptr<T>;

  weak_t weak() const
  {
    return impl;
  }

  template <typename Fun>
  static bool if_valid(const weak_t& wp, Fun&& fun)
  {
    if (auto p = wp.lock())
    {
      handle h{mv(p)};
      ::std::forward<Fun>(fun)(mv(h));
      return true;
    }
    return false;
  }

private:
  // All instantiations of handle are friends with one another
  template <typename T1, handle_flags f1>
  friend class handle;

  // Define a derived class to access the protected ctor
  template <class U>
  struct Derived : public U
  {
    template <typename... Args>
    Derived(Args&&... args)
        : U(::std::forward<Args>(args)...)
    {}
  };

  template <typename Arg, typename... Args>
  static auto make(Arg&& arg, Args&&... args)
  {
    if constexpr (sizeof...(args) == 0 && ::std::is_convertible_v<Arg, ::std::shared_ptr<T>>)
    {
      return ::std::forward<Arg>(arg);
    }
    else
    {
      return ::std::make_shared<Derived<T>>(::std::forward<Arg>(arg), ::std::forward<Args>(args)...);
    }
  }

  ::std::shared_ptr<T> impl;
};

#ifdef UNITTESTED_FILE
UNITTEST("Weak handle")
{
  class test
  {
  protected:
    test(int x)
    {
      a = x;
    }

  public:
    int a;
  };
  handle<test> h(42);
  EXPECT(h->a == 42);
  auto w = h.weak();
  handle<test>::if_valid(w, [](handle<test> x) {
    x->a++;
  });
  EXPECT(h->a == 43);
};
#endif // UNITTESTED_FILE

} // namespace cuda::experimental::stf::reserved