include/cuda/experimental/__device/logical_device.cuh

File members: include/cuda/experimental/__device/logical_device.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__DEVICE_LOGICAL_DEVICE
#define _CUDAX__DEVICE_LOGICAL_DEVICE

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/experimental/__device/all_devices.cuh>
#include <cuda/experimental/__green_context/green_ctx.cuh>

namespace cuda::experimental
{
struct __logical_device_access;

//! @brief A logical device: either a whole CUDA device or a green context.
//!
//! A logical_device stores the ordinal of the underlying device, which kind of logical device it
//! is, and the `CUcontext` that backs it.
class logical_device
{
public:
  //! Kind of logical device represented by a logical_device.
  enum class kinds
  {
    //! Indicates the logical device is a full device
    device,
    //! Indicates the logical device is a green context
    green_context
  };

  //! @brief Returns the `CUcontext` backing this logical device.
  //!
  //! When constructed from a device ordinal, a `device_ref` or a `device`, this is the context
  //! returned by that device's `primary_context()`. When constructed from a `green_context`, it is
  //! the context stored in that green context's `__transformed` member.
  //!
  //! @note This accessor might become private depending on how this type evolves long term.
  _CCCL_NODISCARD constexpr CUcontext context() const noexcept
  {
    return __ctx;
  }

  //! @brief Returns a `device_ref` to the device this logical device belongs to.
  //!
  //! For a logical_device built from a `green_context`, this is the device whose ordinal is stored
  //! in that green context.
  _CCCL_NODISCARD constexpr device_ref get_underlying_device() const noexcept
  {
    return __dev_id;
  }

  //! @brief Returns whether this logical device is a full device or a green context.
  _CCCL_NODISCARD constexpr kinds get_kind() const noexcept
  {
    return __kind;
  }

  //! @brief Constructs a logical_device representing the whole device with ordinal `__id`.
  //!
  //! The context is obtained from `devices[__id].primary_context()`. This constructor is not
  //! `noexcept`. Any error reported by that lookup (e.g. an out-of-range ordinal) propagates to the
  //! caller. NOTE(review): presumably as an exception; confirm against all_devices / device.
  explicit logical_device(int __id)
      : __dev_id(__id)
      , __kind(kinds::device)
      , __ctx(devices[__id].primary_context())
  {}

  //! @brief Constructs a logical_device representing the whole device referred to by `__dev`.
  //!
  //! Equivalent to `logical_device(__dev.get())`.
  explicit logical_device(device_ref __dev)
      : logical_device(__dev.get())
  {}

  //! @brief Constructs a logical_device representing the whole device `__dev`.
  //!
  //! Reads the primary context directly from `__dev` instead of going through `devices[]`. This is
  //! mostly a micro-optimization and may be removed depending on whether `device_ref` is kept.
  logical_device(const ::cuda::experimental::device& __dev)
      : __dev_id(__dev.get())
      , __kind(kinds::device)
      , __ctx(__dev.primary_context())
  {}

#if CUDART_VERSION >= 12050
  //! @brief Constructs a logical_device representing the green context `__gctx`.
  //!
  //! Only available when building against CUDA runtime 12.5 or newer (`CUDART_VERSION >= 12050`).
  logical_device(const green_context& __gctx)
      : __dev_id(__gctx.__dev_id)
      , __kind(kinds::green_context)
      , __ctx(__gctx.__transformed)
  {}
#endif // CUDART_VERSION >= 12050

  //! @brief Compares two logical_devices for equality.
  //!
  //! Two logical_devices compare equal exactly when they hold the same `CUcontext`. The kind and
  //! the device ordinal are not compared.
  _CCCL_NODISCARD_FRIEND bool operator==(logical_device __lhs, logical_device __rhs) noexcept
  {
    return __lhs.__ctx == __rhs.__ctx;
  }

#if _CCCL_STD_VER <= 2017
  //! @brief Compares two logical_devices for inequality; the negation of `operator==`.
  //!
  //! Only needed up to C++17. Since C++20, `a != b` is rewritten in terms of `operator==`.
  _CCCL_NODISCARD_FRIEND bool operator!=(logical_device __lhs, logical_device __rhs) noexcept
  {
    // Delegate so that equality is defined in exactly one place.
    return !(__lhs == __rhs);
  }
#endif // _CCCL_STD_VER <= 2017

private:
  // Grants library code access to the private raw-parts constructor below.
  friend __logical_device_access;

  // Ordinal of the underlying device. NOTE: this could become a CUdevice instead.
  int __dev_id = 0;
  // Every constructor sets this explicitly; the default keeps it consistent with the other members.
  kinds __kind = kinds::device;
  // Context backing this logical device.
  CUcontext __ctx = nullptr;

  // Builds a logical_device from raw parts, storing the arguments as given without any lookup.
  logical_device(int __id, CUcontext __context, kinds __k)
      : __dev_id(__id)
      , __kind(__k)
      , __ctx(__context)
  {}
};

//! Internal accessor used by library code to reach the private
//! `logical_device(int, CUcontext, kinds)` constructor; it is declared a friend of logical_device.
struct __logical_device_access
{
  //! Creates a logical_device from a device ordinal, a context handle and a kind.
  //! The arguments are forwarded unchanged to the private constructor.
  static logical_device
  make_logical_device(int __ordinal, CUcontext __ctx_handle, logical_device::kinds __which)
  {
    return logical_device{__ordinal, __ctx_handle, __which};
  }
};

} // namespace cuda::experimental

#endif // _CUDAX__DEVICE_LOGICAL_DEVICE