cast

tripy.cast(input: Tensor, dtype: dtype) → Tensor

Returns a tensor with the contents of the input tensor cast to the specified data type.

For casts into quantized data types (int4 and float8), this function performs a per-tensor quantization into that data type with a scale of 1.0. For casts from those data types, it performs a per-tensor dequantization with a scale of 1.0. Use quantize() and dequantize() directly for finer control over the scale and the quantization granularity.
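As a rough illustration of what "per-tensor quantization with a scale of 1.0" means, the following pure-Python sketch quantizes a list of floats into the signed int4 range and dequantizes it back. It does not call Tripy. The helper names quantize_per_tensor and dequantize_per_tensor are hypothetical, and the rounding mode (round half to even, which is what Python's built-in round() does) and saturation to [-8, 7] are assumptions made for illustration, not documented Tripy or TensorRT behavior.

```python
# Hypothetical illustration of per-tensor int4 quantization; not Tripy's implementation.
# Assumptions: round-half-to-even rounding and saturation to the signed 4-bit range.
from typing import List

INT4_MIN, INT4_MAX = -8, 7  # range of a signed 4-bit integer


def quantize_per_tensor(values: List[float], scale: float = 1.0) -> List[int]:
    """Divide every element by one shared (per-tensor) scale, round, and clamp to int4."""
    return [max(INT4_MIN, min(INT4_MAX, round(v / scale))) for v in values]


def dequantize_per_tensor(values: List[int], scale: float = 1.0) -> List[float]:
    """Multiply every element by the same per-tensor scale to recover floats."""
    return [float(q) * scale for q in values]


if __name__ == "__main__":
    x = [0.4, 1.6, -3.5, 10.0, -20.0]
    q = quantize_per_tensor(x)  # scale 1.0, mirroring what cast() is described to use
    print(q)  # [0, 2, -4, 7, -8]
    print(dequantize_per_tensor(q))  # [0.0, 2.0, -4.0, 7.0, -8.0]
```

With a scale of 1.0, quantization reduces to rounding and clamping: fractional parts are lost and values outside [-8, 7] saturate. Choosing a different scale trades resolution for range, which is the kind of control that calling quantize() and dequantize() directly provides.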

Parameters:
  • input (Tensor) – [dtype=T1] The input tensor.

  • dtype (dtype) – [dtype=T2] The desired data type.

Returns:

[dtype=T2] A tensor containing the cast values.

Return type:

Tensor

TYPE CONSTRAINTS:
UNSUPPORTED TYPE COMBINATIONS:
Example
import tripy as tp

input = tp.Tensor([1, 2], dtype=tp.int32)
output = tp.cast(input, tp.float32)  # int32 -> float32
>>> input
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
>>> output
tensor([1.0000, 2.0000], dtype=float32, loc=gpu:0, shape=(2,))