CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
host/tensor_fill.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 /* \file
26  \brief Provides several functions for filling tensors with data.
27 */
28 
29 #pragma once
30 
31 // Standard Library includes
32 #include <utility>
33 #include <cstdlib>
34 #include <cmath>
35 
36 // Cutlass includes
37 #include "cutlass/cutlass.h"
38 #include "cutlass/complex.h"
39 #include "cutlass/array.h"
40 #include "cutlass/numeric_types.h"
41 
43 #include "tensor_foreach.h"
44 
46 
47 namespace cutlass {
48 namespace reference {
49 namespace host {
50 
53 
54 namespace detail {
55 
56 template <
57  typename Element,
58  typename Layout>
60 
62 
63  //
64  // Data members
65  //
66 
68  Element value;
69 
70  //
71  // Methods
72  //
73 
75  TensorView const &view_ = TensorView(),
76  Element value_ = Element(0)
77  ): view(view_), value(value_) { }
78 
79  void operator()(Coord<Layout::kRank> const & coord) const {
80  view.at(coord) = value;
81  }
82 };
83 
84 } // namespace detail
85 
87 
89 template <
90  typename Element,
91  typename Layout>
94  Element val = Element(0)) {
95 
97 
99  dst.extent(),
100  func
101  );
102 }
103 
106 
107 namespace detail {
108 
109 template <typename Element>
111 
112  uint64_t seed;
113  double mean;
114  double stddev;
116  double pi;
117 
118  //
119  // Methods
120  //
122  uint64_t seed_ = 0,
123  double mean_ = 0,
124  double stddev_ = 1,
125  int int_scale_ = -1
126  ):
127  seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) {
128  std::srand((unsigned)seed);
129  }
130 
132  Element operator()() const {
133 
134  // Box-Muller transform to generate random numbers with Normal distribution
135  double u1 = double(std::rand()) / double(RAND_MAX);
136  double u2 = double(std::rand()) / double(RAND_MAX);
137 
138  // Compute Gaussian random value
139  double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
140  rnd = mean + stddev * rnd;
141 
142  // Scale and convert final result
143  Element result;
144 
145  if (int_scale >= 0) {
146  rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
147  result = static_cast<Element>(rnd);
148  }
149  else {
150  result = static_cast<Element>(rnd);
151  }
152 
153  return result;
154  }
155 };
156 
158 template <typename Element>
159 struct RandomGaussianFunc<complex<Element> > {
160 
161  uint64_t seed;
162  double mean;
163  double stddev;
165  double pi;
166 
167  //
168  // Methods
169  //
171  uint64_t seed_ = 0,
172  double mean_ = 0,
173  double stddev_ = 1,
174  int int_scale_ = -1
175  ):
176  seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) {
177  std::srand((unsigned)seed);
178  }
179 
182 
183  Element reals[2];
184 
185  for (int i = 0; i < 2; ++i) {
186  // Box-Muller transform to generate random numbers with Normal distribution
187  double u1 = double(std::rand()) / double(RAND_MAX);
188  double u2 = double(std::rand()) / double(RAND_MAX);
189 
190  // Compute Gaussian random value
191  double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
192  rnd = mean + stddev * rnd;
193 
194  if (int_scale >= 0) {
195  rnd = double(int(rnd * double(1 << int_scale)));
196  reals[i] = from_real<Element>(rnd / double(1 << int_scale));
197  }
198  else {
199  reals[i] = from_real<Element>(rnd);
200  }
201  }
202 
203  return complex<Element>(reals[0], reals[1]);
204  }
205 };
206 
208 template <
209  typename Element,
210  typename Layout>
212 
214 
215  //
216  // Data members
217  //
218 
221 
222  //
223  // Methods
224  //
225 
228  TensorView view_ = TensorView(),
230  ):
231  view(view_), func(func_) {
232 
233  }
234 
236  void operator()(Coord<Layout::kRank> const &coord) const {
237  view.at(coord) = func();
238  }
239 };
240 
241 } // namespace detail
242 
244 
246 template <
247  typename Element,
248  typename Layout>
251  uint64_t seed,
252  double mean = 0,
253  double stddev = 1,
254  int bits = -1) {
255 
258  detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits);
259 
261  dst,
262  random_func
263  );
264 
266  dst.extent(),
267  func
268  );
269 }
270 
272 
274 template <
275  typename Element
276 >
278  Element *ptr,
279  size_t capacity,
280  uint64_t seed,
281  double mean = 0,
282  double stddev = 1,
283  int bits = -1) {
284 
287 
288  detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits);
289 
290  for (size_t i = 0; i < capacity; ++i) {
291  ptr[i] = random_func();
292  }
293 }
294 
297 
298 namespace detail {
299 
300 template <typename Element>
302 
303  using Real = typename RealType<Element>::Type;
304 
305  uint64_t seed;
306  double range;
307  double min;
309 
310  //
311  // Methods
312  //
313 
315  uint64_t seed_ = 0,
316  double max = 1,
317  double min_ = 0,
318  int int_scale_ = -1
319  ):
320  seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
321  std::srand((unsigned)seed);
322  }
323 
324 
326  Element operator()() const {
327 
328  double rnd = double(std::rand()) / double(RAND_MAX);
329 
330  rnd = min + range * rnd;
331 
332  // Random values are cast to integer after scaling by a power of two to facilitate error
333  // testing
334  Element result;
335 
336  if (int_scale >= 0) {
337  rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
338  result = static_cast<Element>(Real(rnd));
339  }
340  else {
341  result = static_cast<Element>(Real(rnd));
342  }
343 
344  return result;
345  }
346 };
347 
349 template <typename Element>
350 struct RandomUniformFunc<complex<Element> > {
351 
352  using Real = typename RealType<Element>::Type;
353 
354  uint64_t seed;
355  double range;
356  double min;
358 
359  //
360  // Methods
361  //
362 
364  uint64_t seed_ = 0,
365  double max = 1,
366  double min_ = 0,
367  int int_scale_ = -1
368  ):
369  seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
370  std::srand((unsigned)seed);
371  }
372 
373 
375  complex<Element> operator()() const {
376 
377  Element reals[2];
378 
379  for (int i = 0; i < 2; ++i) {
380  double rnd = double(std::rand()) / double(RAND_MAX);
381 
382  rnd = min + range * rnd;
383 
384  // Random values are cast to integer after scaling by a power of two to facilitate error
385  // testing
386 
387  if (int_scale >= 0) {
388  rnd = double(int(rnd * double(1 << int_scale)));
389  reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
390  }
391  else {
392  reals[i] = from_real<Element>(Real(rnd));
393  }
394  }
395 
396  return complex<Element>(reals[0], reals[1]);
397  }
398 };
399 
401 template <
402  typename Element,
403  typename Layout>
405 
407 
408  //
409  // Data members
410  //
411 
414 
415  //
416  // Methods
417  //
418 
421  TensorView view_ = TensorView(),
423  ):
424  view(view_), func(func_) {
425 
426  }
427 
429  void operator()(Coord<Layout::kRank> const &coord) const {
430 
431  view.at(coord) = func();
432  }
433 };
434 
435 } // namespace detail
436 
438 
440 template <
441  typename Element,
442  typename Layout>
445  uint64_t seed,
446  double max = 1,
447  double min = 0,
448  int bits = -1) {
449  detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
452 
454  dst,
455  random_func
456  );
457 
459  dst.extent(),
460  func
461  );
462 }
463 
465 
467 template <
468  typename Element
469 >
471  Element *ptr,
472  size_t capacity,
473  uint64_t seed,
474  double max = 1,
475  double min = 0,
476  int bits = -1) {
477  detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
480 
481  for (size_t i = 0; i < capacity; ++i) {
482  ptr[i] = random_func();
483  }
484 }
485 
488 
489 namespace detail {
490 
491 template <
492  typename Element,
493  typename Layout>
495 
497 
498  //
499  // Data members
500  //
501 
503  Element diag;
504  Element other;
505 
506  //
507  // Methods
508  //
509 
511  TensorView const &view_ = TensorView(),
512  Element diag_ = Element(1),
513  Element other_ = Element(0)
514  ):
515  view(view_), diag(diag_), other(other_) { }
516 
517  void operator()(Coord<Layout::kRank> const & coord) const {
518  bool is_diag = true;
519 
521  for (int i = 1; i < Layout::kRank; ++i) {
522  if (coord[i] != coord[i - 1]) {
523  is_diag = false;
524  break;
525  }
526  }
527 
528  view.at(coord) = (is_diag ? diag : other);
529  }
530 };
531 
532 } // namespace detail
533 
535 
537 template <
538  typename Element,
539  typename Layout>
542  Element diag = Element(1),
543  Element other = Element(0)) {
544 
546  dst,
547  diag,
548  other
549  );
550 
552  dst.extent(),
553  func
554  );
555 }
556 
559 
561 template <
562  typename Element,
563  typename Layout>
566 
567  TensorFillDiagonal(dst, Element(1), Element(0));
568 }
569 
572 
574 template <
575  typename Element,
576  typename Layout>
579  Element val = Element(1)) {
580 
581  typename Layout::Index extent = dst.extent().min();
582 
583  for (typename Layout::Index i = 0; i < extent; ++i) {
584  Coord<Layout::kRank> coord(i);
585  dst.at(coord) = val;
586  }
587 }
588 
591 
592 namespace detail {
593 
594 template <
595  typename Element,
596  typename Layout>
598 
600 
601  //
602  // Data members
603  //
604 
606  Element other;
607 
608  //
609  // Methods
610  //
611 
613  TensorView const &view_ = TensorView(),
614  Element other_ = Element(0)
615  ):
616  view(view_), other(other_) { }
617 
618  void operator()(Coord<Layout::kRank> const & coord) const {
619  bool is_diag = true;
620 
622  for (int i = 1; i < Layout::kRank; ++i) {
623  if (coord[i] != coord[i - 1]) {
624  is_diag = false;
625  break;
626  }
627  }
628 
629  if (!is_diag) {
630  view.at(coord) = other;
631  }
632  }
633 };
634 
635 } // namespace detail
636 
638 
640 template <
641  typename Element,
642  typename Layout>
645  Element other = Element(1)) {
646 
648  dst,
649  other
650  );
651 
653  dst.extent(),
654  func
655  );
656 }
657 
658 
661 
662 namespace detail {
663 
664 template <
665  typename Element,
666  typename Layout>
668 
670 
671  //
672  // Data members
673  //
674 
676  Array<Element, Layout::kRank> v;
677  Element s;
678 
679  //
680  // Methods
681  //
682 
684 
687  TensorView const &view_,
688  Array<Element, Layout::kRank> const & v_,
689  Element s_ = Element(0)
690  ):
691  view(view_), v(v_), s(s_) { }
692 
694  void operator()(Coord<Layout::kRank> const & coord) const {
695 
696  Element sum(s);
697 
699  for (int i = 0; i < Layout::kRank; ++i) {
700  sum += Element(coord[i]) * v[i];
701  }
702 
703  view.at(coord) = sum;
704  }
705 };
706 
707 } // namespace detail
708 
710 
712 template <
713  typename Element,
714  typename Layout>
717  Array<Element, Layout::kRank> const & v,
718  Element s = Element(0)) {
719 
721  dst,
722  v,
723  s
724  );
725 
727  dst.extent(),
728  func
729  );
730 }
731 
733 
735 template <
736  typename Element,
737  typename Layout>
740  Element s = Element(0)) {
741 
742  Array<Element, Layout::kRank> stride;
743 
744  stride[0] = Element(1);
745 
747  for (int i = 1; i < Layout::kRank; ++i) {
748  stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]);
749  }
750 
751  TensorFillLinear(dst, stride, s);
752 }
753 
756 
758 template <
759  typename Element
760 >
762  Element *ptr,
763  int64_t capacity,
764  Element v = Element(1),
765  Element s = Element(0)) {
766  int i = 0;
767 
768  while (i < capacity) {
770  8)>::get(ptr, i) = s;
771 
772  s = Element(s + v);
773  ++i;
774  }
775 }
776 
779 
781 template <
782  typename Element
783 >
785  Element *ptr,
786  size_t capacity,
787  uint64_t seed,
788  Distribution dist) {
789 
790  if (dist.kind == Distribution::Gaussian) {
791  BlockFillRandomGaussian<Element>(
792  ptr,
793  capacity,
794  seed,
795  dist.gaussian.mean,
796  dist.gaussian.stddev,
797  dist.int_scale);
798  }
799  else if (dist.kind == Distribution::Uniform) {
800  BlockFillRandomUniform<Element>(
801  ptr,
802  capacity,
803  seed,
804  dist.uniform.max,
805  dist.uniform.min,
806  dist.int_scale);
807  }
808 }
809 
812 
814 template <
815  typename Element,
816  typename Layout>
819  Element const *ptr) {
820 
821  typename Layout::Index extent = dst.extent().min();
822 
823  for (typename Layout::Index i = 0; i < extent; ++i) {
824  Coord<Layout::kRank> coord(i);
825  dst.at(coord) = ptr[i];
826  }
827 }
828 
831 
833 template <
834  typename Element,
835  typename Layout>
837  Element *ptr,
839 
840  typename Layout::Index extent = src.extent().min();
841 
842  for (typename Layout::Index i = 0; i < extent; ++i) {
843  Coord<Layout::kRank> coord(i);
844  ptr[i] = src.at(coord);
845  }
846 }
847 
850 
851 } // namespace host
852 } // namespace reference
853 } // namespace cutlass
uint64_t seed
Definition: host/tensor_fill.h:112
void operator()(Coord< Layout::kRank > const &coord) const
Computes a random value and writes it to the tensor element at coord.
Definition: host/tensor_fill.h:429
CUTLASS_HOST_DEVICE complex< T > cos(complex< T > const &z)
Computes the cosine of complex z.
Definition: complex.h:401
CUTLASS_HOST_DEVICE constexpr const T & max(const T &a, const T &b)
std::max
Definition: platform.h:189
typename RealType< Element >::Type Real
Definition: host/tensor_fill.h:303
void TensorCopyDiagonalOut(Element *ptr, TensorView< Element, Layout > src)
Copies the diagonal of a tensor into a dense buffer in host memory.
Definition: host/tensor_fill.h:836
Definition: aligned_buffer.h:35
Definition: distribution.h:40
< Layout function
Definition: host/tensor_fill.h:494
RandomUniformFunc(uint64_t seed_=0, double max=1, double min_=0, int int_scale_=-1)
Definition: host/tensor_fill.h:363
Definition: distribution.h:40
struct cutlass::Distribution::@18::@20 uniform
Uniform distribution.
TensorView< Element, Layout > TensorView
Definition: host/tensor_fill.h:61
void operator()(Coord< Layout::kRank > const &coord) const
Definition: host/tensor_fill.h:517
Element diag
Definition: host/tensor_fill.h:503
Element operator()() const
Compute random value and update RNG state.
Definition: host/tensor_fill.h:132
T Type
Definition: real.h:32
Kind kind
Active variant kind.
Definition: distribution.h:64
void TensorFill(TensorView< Element, Layout > dst, Element val=Element(0))
Fills a tensor with a uniform value.
Definition: host/tensor_fill.h:92
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
RandomUniformFunc(uint64_t seed_=0, double max=1, double min_=0, int int_scale_=-1)
Definition: host/tensor_fill.h:314
RandomGaussianFunc(uint64_t seed_=0, double mean_=0, double stddev_=1, int int_scale_=-1)
Definition: host/tensor_fill.h:170
struct cutlass::Distribution::@18::@21 gaussian
Gaussian distribution.
void operator()(Coord< Layout::kRank > const &coord) const
Definition: host/tensor_fill.h:79
STL namespace.
TensorView view
Definition: host/tensor_fill.h:605
void TensorFillDiagonal(TensorView< Element, Layout > dst, Element diag=Element(1), Element other=Element(0))
Fills a tensor everywhere with a unique value for its diagonal.
Definition: host/tensor_fill.h:540
< Layout function
Definition: host/tensor_fill.h:667
int int_scale
Definition: host/tensor_fill.h:115
< Layout function
Definition: host/tensor_fill.h:597
void TensorFillIdentity(TensorView< Element, Layout > dst)
Helper to fill a tensor's diagonal with 1 and 0 everywhere else.
Definition: host/tensor_fill.h:564
complex< Element > operator()() const
Compute random value and update RNG state.
Definition: host/tensor_fill.h:375
CUTLASS_HOST_DEVICE complex< T > log(complex< T > const &z)
Computes the complex natural logarithm of z.
Definition: complex.h:381
void operator()(Coord< Layout::kRank > const &coord) const
Computes a random value and writes it to the tensor element at coord.
Definition: host/tensor_fill.h:236
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
double mean
Definition: host/tensor_fill.h:113
void TensorUpdateOffDiagonal(TensorView< Element, Layout > dst, Element other=Element(1))
Writes a uniform value to all elements in the tensor without modifying diagonal elements.
Definition: host/tensor_fill.h:643
Element s
Definition: host/tensor_fill.h:677
Array< Element, Layout::kRank > v
Definition: host/tensor_fill.h:676
typename RealType< Element >::Type Real
Definition: host/tensor_fill.h:352
TensorView view
Definition: host/tensor_fill.h:219
void TensorFillRandomGaussian(TensorView< Element, Layout > dst, uint64_t seed, double mean=0, double stddev=1, int bits=-1)
Fills a tensor with random values with a Gaussian distribution.
Definition: host/tensor_fill.h:249
complex< Element > operator()() const
Compute random value and update RNG state.
Definition: host/tensor_fill.h:181
void TensorFillLinear(TensorView< Element, Layout > dst, Array< Element, Layout::kRank > const &v, Element s=Element(0))
Fills tensor with a linear combination of its coordinate and another vector.
Definition: host/tensor_fill.h:715
Computes a random Gaussian distribution.
Definition: host/tensor_fill.h:211
void TensorUpdateDiagonal(TensorView< Element, Layout > dst, Element val=Element(1))
Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements...
Definition: host/tensor_fill.h:577
TensorView view
Definition: host/tensor_fill.h:412
Definition: subbyte_reference.h:557
void operator()(Coord< Layout::kRank > const &coord) const
Updates the tensor.
Definition: host/tensor_fill.h:694
void BlockFillRandomGaussian(Element *ptr, size_t capacity, uint64_t seed, double mean=0, double stddev=1, int bits=-1)
Fills a block of data with random values drawn from a Gaussian distribution.
Definition: host/tensor_fill.h:277
This header contains a class to parametrize a statistical distribution function.
TensorView view
Definition: host/tensor_fill.h:502
Element value
Definition: host/tensor_fill.h:68
Element operator()() const
Compute random value and update RNG state.
Definition: host/tensor_fill.h:326
Definition: host/tensor_fill.h:301
Top-level include for all CUTLASS numeric types.
TensorFillLinearFunc(TensorView const &view_, Array< Element, Layout::kRank > const &v_, Element s_=Element(0))
Constructs functor.
Definition: host/tensor_fill.h:686
CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)
std::min
Definition: platform.h:183
double pi
Definition: host/tensor_fill.h:116
void operator()(Coord< Layout::kRank > const &coord) const
Definition: host/tensor_fill.h:618
void BlockFillSequential(Element *ptr, int64_t capacity, Element v=Element(1), Element s=Element(0))
Fills a block of data with sequential elements.
Definition: host/tensor_fill.h:761
double stddev
Definition: host/tensor_fill.h:114
TensorUpdateOffDiagonalFunc(TensorView const &view_=TensorView(), Element other_=Element(0))
Definition: host/tensor_fill.h:612
TensorFillGaussianFunc(TensorView view_=TensorView(), RandomGaussianFunc< Element > func_=RandomGaussianFunc< Element >())
Construction of Gaussian RNG functor.
Definition: host/tensor_fill.h:227
Element other
Definition: host/tensor_fill.h:606
TensorFillDiagonalFunc(TensorView const &view_=TensorView(), Element diag_=Element(1), Element other_=Element(0))
Definition: host/tensor_fill.h:510
RandomGaussianFunc(uint64_t seed_=0, double mean_=0, double stddev_=1, int int_scale_=-1)
Definition: host/tensor_fill.h:121
TensorView view
Definition: host/tensor_fill.h:67
Definition: complex.h:92
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
double range
Definition: host/tensor_fill.h:306
TensorFillLinearFunc()
Definition: host/tensor_fill.h:683
void BlockFillRandomUniform(Element *ptr, size_t capacity, uint64_t seed, double max=1, double min=0, int bits=-1)
Fills a block of data with random values drawn from a uniform distribution.
Definition: host/tensor_fill.h:470
RandomUniformFunc< Element > func
Definition: host/tensor_fill.h:413
void TensorForEach(Coord< Rank > extent, Func &func)
Iterates over the index space of a tensor.
Definition: host/tensor_foreach.h:87
void TensorFillRandomUniform(TensorView< Element, Layout > dst, uint64_t seed, double max=1, double min=0, int bits=-1)
Fills a tensor with random values with a uniform random distribution.
Definition: host/tensor_fill.h:443
void TensorCopyDiagonalIn(TensorView< Element, Layout > dst, Element const *ptr)
Copies a diagonal in from host memory without modifying off-diagonal elements.
Definition: host/tensor_fill.h:817
double min
Definition: host/tensor_fill.h:307
Distribution type.
Definition: distribution.h:38
RandomGaussianFunc< Element > func
Definition: host/tensor_fill.h:220
void TensorFillSequential(TensorView< Element, Layout > dst, Element s=Element(0))
Fills a tensor with sequential values starting at s, ordered by the linearized coordinate (the first dimension varies fastest).
Definition: host/tensor_fill.h:738
< Layout function
Definition: host/tensor_fill.h:59
int int_scale
Random values are cast to integer after scaling by this power of two.
Definition: distribution.h:67
TensorFillFunc(TensorView const &view_=TensorView(), Element value_=Element(0))
Definition: host/tensor_fill.h:74
Fills a tensor with random values drawn from a uniform distribution.
Definition: host/tensor_fill.h:404
Basic include for CUTLASS.
TensorView view
Definition: host/tensor_fill.h:675
uint64_t seed
Definition: host/tensor_fill.h:305
TensorFillRandomUniformFunc(TensorView view_=TensorView(), RandomUniformFunc< Element > func_=RandomUniformFunc< Element >())
Construction of uniform RNG functor.
Definition: host/tensor_fill.h:420
CUTLASS_HOST_DEVICE complex< T > sqrt(complex< T > const &z)
Computes the square root of complex number z.
Definition: complex.h:393
int int_scale
Definition: host/tensor_fill.h:308
Element other
Definition: host/tensor_fill.h:504
void BlockFillRandom(Element *ptr, size_t capacity, uint64_t seed, Distribution dist)
Fills a block of data with random values drawn from the given distribution (Gaussian or uniform).
Definition: host/tensor_fill.h:784