#include <curand_kernel.h>

// Note: cutlass::Distribution (defined in distribution.h) describes the requested initialization.

/// Kernel to initialize tensor to uniform random distribution
template <typename T>
__global__ void TensorInitializeUniform(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      double range = dist.uniform.max - dist.uniform.min;

      double rnd = curand_uniform(&rng_state[threadIdx.x]);

      rnd = dist.uniform.min + range * rnd;

      // Random values are cast to integer after scaling by this power of two to facilitate
      // error testing
      if (dist.int_scale >= 0) {
        rnd = double(int(rnd * double(1 << dist.int_scale)));
        *tensor = T(rnd / double(1 << dist.int_scale));
      } else {
        *tensor = T(rnd);
      }
      tensor += ldm;
    }
  }
}
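For context, the kernel maps blockIdx.x * blockDim.x + threadIdx.x to the contiguous dimension and gives each block a blockDim.x-deep slice of the strided dimension via blockIdx.y. A minimal host-side launch sketch under those assumptions follows (the wrapper and kThreadsPerBlock constant are hypothetical, not part of this header; blockDim.x must stay within the 1024-entry rng_state array):

// Hypothetical launch wrapper, shown only to illustrate the grid/block mapping assumed above.
template <typename T>
void LaunchTensorInitializeUniform(Distribution dist, int64_t seed, int dim_contiguous,
                                   int dim_strided, T *device_tensor, int ldm) {
  int const kThreadsPerBlock = 256;  // <= 1024 so rng_state[threadIdx.x] stays in bounds

  dim3 block(kThreadsPerBlock, 1, 1);
  dim3 grid((dim_contiguous + kThreadsPerBlock - 1) / kThreadsPerBlock,  // covers c_idx
            (dim_strided + kThreadsPerBlock - 1) / kThreadsPerBlock,     // each block loops over blockDim.x strided rows
            1);

  TensorInitializeUniform<T><<< grid, block >>>(
      dist, seed, dim_contiguous, dim_strided, device_tensor, ldm);
}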
/// Kernel to initialize tensor to Gaussian random distribution
template <typename T>
__global__ void TensorInitializeGaussian(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      double rnd = curand_normal(&rng_state[threadIdx.x]);

      rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd;

      // Random values are cast to integer after scaling by this power of two to facilitate
      // error testing
      if (dist.int_scale >= 0) {
        rnd = double(int(rnd * double(1 << dist.int_scale)));
        *tensor = T(rnd / double(1 << dist.int_scale));
      } else {
        *tensor = T(rnd);
      }
      tensor += ldm;
    }
  }
}
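Both random kernels quantize their output in the same way when dist.int_scale is non-negative: the value is scaled by 2^int_scale, truncated to an integer, and rescaled, so every element becomes an exact multiple of 2^-int_scale. A host-side mirror of that arithmetic, shown only for illustration (the helper name is hypothetical):

// Illustrative host-side mirror of the device-side int_scale rounding above
// (hypothetical helper, not part of CUTLASS).
inline double QuantizeToIntScale(double rnd, int int_scale) {
  // Scale by 2^int_scale, truncate toward zero, then rescale.
  return double(int(rnd * double(1 << int_scale))) / double(1 << int_scale);
}

// Example: QuantizeToIntScale(1.337, 4) == 1.3125, i.e. 1.337 truncated to a multiple of 1/16.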
/// Kernel to initialize tensor to a linear combination of its coordinates
template <typename T>
__global__ void TensorInitializeLinear(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // RNG state is set up to match the other kernels' interface but is not used here.
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      *tensor =
          dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx;

      tensor += ldm;
    }
  }
}
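Since each kernel addresses the element at contiguous index c and strided index s as tensor[s * ldm + c], the linear kernel's output can be recomputed on the host with the same formula. A sketch of such a check, assuming a float element type and a hypothetical helper name:

// Hypothetical verification helper: recompute the linear ramp on the host and compare against
// a tensor copied back from the device.
bool VerifyLinear(float const *host_tensor, int dim_contiguous, int dim_strided, int ldm,
                  Distribution const &dist) {
  for (int s = 0; s < dim_strided; ++s) {
    for (int c = 0; c < dim_contiguous; ++c) {
      // Same expression the kernel evaluates, converted to the element type.
      float expected =
          dist.linear.offset + dist.linear.delta_row * c + dist.linear.delta_column * s;
      if (host_tensor[s * ldm + c] != expected) {
        return false;
      }
    }
  }
  return true;
}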
/// Kernel to initialize tensor to an identity matrix
template <typename T>
__global__ void TensorInitializeIdentity(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // RNG state is set up to match the other kernels' interface but is not used here.
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      *tensor = (c_idx == s_idx ? T(1) : T(0));

      tensor += ldm;
    }
  }
}
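All four kernels share one signature and launch geometry, so a caller can select among them with an ordinary switch. A hedged dispatch sketch (the InitKind enum and LaunchTensorInitialize helper are hypothetical, not part of CUTLASS):

// Hypothetical dispatch helper: since the four kernels share a signature, selection reduces to
// a switch on the caller's side; grid/block are sized as in the earlier launch sketch.
enum class InitKind { kUniform, kGaussian, kLinear, kIdentity };

template <typename T>
void LaunchTensorInitialize(InitKind kind, dim3 grid, dim3 block, Distribution dist, int64_t seed,
                            int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  switch (kind) {
    case InitKind::kUniform:
      TensorInitializeUniform<T><<< grid, block >>>(dist, seed, dim_contiguous, dim_strided, tensor, ldm);
      break;
    case InitKind::kGaussian:
      TensorInitializeGaussian<T><<< grid, block >>>(dist, seed, dim_contiguous, dim_strided, tensor, ldm);
      break;
    case InitKind::kLinear:
      TensorInitializeLinear<T><<< grid, block >>>(dist, seed, dim_contiguous, dim_strided, tensor, ldm);
      break;
    case InitKind::kIdentity:
      TensorInitializeIdentity<T><<< grid, block >>>(dist, seed, dim_contiguous, dim_strided, tensor, ldm);
      break;
  }
}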