cast
- tripy.cast(input: Tensor, dtype: dtype) → Tensor
Returns a tensor with the contents of the input tensor cast to the specified data type.
For casts into quantized data types (int4 and float8), this performs a per-tensor quantization into that data type with a scale of 1.0. For casts from those data types, it performs a per-tensor dequantization with a scale of 1.0. Use quantize() and dequantize() directly for finer control over these parameters.
- Parameters:
  - input (Tensor) – The input tensor.
  - dtype (dtype) – The data type to cast to.
- Returns:
[dtype=T2] A tensor containing the cast values.
- Return type:
Tensor
- TYPE CONSTRAINTS:
- UNSUPPORTED TYPE COMBINATIONS:
Example
import tripy as tp

input = tp.Tensor([1, 2], dtype=tp.int32)
output = tp.cast(input, tp.float32)
>>> input
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
>>> output
tensor([1.0000, 2.0000], dtype=float32, loc=gpu:0, shape=(2,))
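To make the scale-1.0 quantization behavior described above more concrete, here is a minimal pure-Python sketch of per-tensor quantization into int4 and dequantization back to float. It does not use Tripy. The helper names quantize_int4 and dequantize_int4 are hypothetical. The rounding rule (round half to even) and saturation to the signed int4 range [-8, 7] are assumptions modeled on TensorRT-style quantization, not a statement of how tp.cast is implemented. float8 rounding is not modeled.

```python
# Hypothetical helpers illustrating per-tensor quantization with one shared scale.
# Assumptions: round half to even and saturation to the signed int4 range.
# This is an illustration only, not Tripy's implementation.

INT4_MIN, INT4_MAX = -8, 7


def quantize_int4(values, scale=1.0):
    """Per-tensor quantization: every element is divided by the same scale."""
    out = []
    for v in values:
        q = round(v / scale)  # Python's round() resolves ties to the even integer
        out.append(max(INT4_MIN, min(INT4_MAX, q)))  # saturate to [-8, 7]
    return out


def dequantize_int4(values, scale=1.0):
    """Per-tensor dequantization: every element is multiplied by the same scale."""
    return [float(q) * scale for q in values]


x = [0.4, 1.5, 2.5, -3.7, 12.0, -20.0]
q = quantize_int4(x)        # scale 1.0, as in a plain cast
print(q)                    # [0, 2, 2, -4, 7, -8]
print(dequantize_int4(q))   # [0.0, 2.0, 2.0, -4.0, 7.0, -8.0]
```

Under these assumptions, a scale of 1.0 means fractional parts are rounded away and values outside [-8, 7] are clipped, which is why calling quantize() with an explicit scale can preserve more information.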
See also