JAX
- class transformer_engine.jax.MajorShardingType
- class transformer_engine.jax.ShardingType
- class transformer_engine.jax.flax.TransformerLayerType
TransformerLayerType is an Enum class that specifies the type of a TransformerLayer.
- Values
ENCODER – Encoder type of TransformerLayer.
DECODER – Decoder type of TransformerLayer.
- class transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
- transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, mesh_resource: Optional[transformer_engine.jax.sharding.MeshResource] = None)
Context manager for FP8 usage.
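    import jax
    import numpy as np
    from flax.linen import partitioning
    from jax.experimental.pjit import pjit
    from jax.sharding import Mesh  # formerly jax.experimental.maps.Mesh
    from transformer_engine.jax import fp8_autocast
    from transformer_engine.jax.sharding import MeshResource
    from transformer_engine.jax.flax import TransformerLayer, extend_logical_axis_rules

    mesh_shape = (4, 2)  # needs 4 * 2 = 8 devices
    dp_mesh_axis_name = 'data_parallel'
    tp_mesh_axis_name = 'tensor_parallel'
    devices = np.asarray(jax.devices()).reshape(*mesh_shape)

    with Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
        mesh_resource = MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
        with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
            rules = extend_logical_axis_rules(tuple())
            transformer = TransformerLayer()

            with partitioning.axis_rules(rules):
                pjit(transformer.init, ...)(...)  # arguments elided in the original example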
Note
Currently, only margin, fp8_format, interval, amax_history_len, and amax_compute_algo (with values 'max' and 'most_recent') are supported in recipe.DelayedScaling. Other parameters in recipe.DelayedScaling will trigger an assertion.
- Parameters
enabled (bool, default = False) – Whether or not to enable FP8.
fp8_recipe (recipe.DelayedScaling, default = None) – Recipe used for FP8 training.
mesh_resource (MeshResource, default = None) – Specify the mesh axes for data and tensor parallelism to shard along. If set to None, then no data or tensor parallelism will be used.
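The Note above restricts amax_compute_algo to 'max' and 'most_recent'. The two values differ only in how one amax is picked from the recorded amax history. The NumPy sketch below illustrates that difference; it is not Transformer Engine code, the helper name `reduce_amax_history` is hypothetical, and it assumes the history stores the most recent amax at index 0.

```python
import numpy as np


def reduce_amax_history(amax_history: np.ndarray, amax_compute_algo: str) -> np.ndarray:
    """Reduce an amax history of shape (amax_history_len, ...) to one amax per tensor.

    Hypothetical helper; assumes index 0 holds the most recent amax.
    """
    if amax_compute_algo == "max":
        # Largest absolute value observed over the whole history window.
        return np.max(amax_history, axis=0)
    if amax_compute_algo == "most_recent":
        # Only the latest observation is used.
        return amax_history[0]
    # Mirrors the documented restriction: other values are not supported.
    raise ValueError(f"unsupported amax_compute_algo: {amax_compute_algo!r}")


history = np.array([[1.5], [4.0], [2.0]])  # amax_history_len = 3, one tensor
print(reduce_amax_history(history, "max"))          # [4.]
print(reduce_amax_history(history, "most_recent"))  # [1.5]
```

'max' reacts more conservatively to outliers seen earlier in the window, while 'most_recent' tracks the latest step only.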
- transformer_engine.jax.update_collections(new: Collection, original: Collection)
A helper to update Flax’s Collection.
Collection = [dict, flax.core.frozen_dict.FrozenDict]
- Parameters
new (Collection) – A collection that includes new data.
original (Collection) – The base collection.
- Returns
outputs – The updated collection.
- Return type
Collection
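A minimal plain-dictionary sketch of such an update, assuming that top-level keys present in `new` replace the same keys in `original` while all other keys of `original` are kept. `merge_collections` is a hypothetical name and the collection keys are illustrative; the real helper also accepts `flax.core.frozen_dict.FrozenDict`, which this sketch omits.

```python
from typing import Any, Dict


def merge_collections(new: Dict[str, Any], original: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict: entries of `original`, overridden by entries of `new`.

    Hypothetical sketch; assumes top-level replacement, not a recursive merge.
    """
    merged = dict(original)  # shallow copy so `original` is left untouched
    merged.update(new)       # top-level keys from `new` win
    return merged


original = {"params": {"w": 1.0}, "fp8_metas": {"scale": 1.0}}  # illustrative keys
new = {"fp8_metas": {"scale": 0.5}}
updated = merge_collections(new, original)
print(updated)  # {'params': {'w': 1.0}, 'fp8_metas': {'scale': 0.5}}
```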
- transformer_engine.jax.update_fp8_metas(state: Collection)
Calculate new FP8 scales and their inverses using the following formula:
    sf = (fp8_max / amax) / (2 ^ margin)
    sf = sf if amax > 0.0, else original_scale
    updated_scale = sf if isfinite(amax), else original_scale
    updated_scale_inv = 1 / updated_scale
Collection = [dict, flax.core.frozen_dict.FrozenDict]
- Parameters
state (Collection) – A collection that includes FP8 metas.
- Returns
outputs – The collection with updated FP8 metas.
- Return type
Collection
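The scale update formula can be written as a vectorized NumPy sketch for checking values by hand. The default `fp8_max = 448.0` assumes the E4M3 format, and the function name and argument layout are illustrative rather than the library's internals.

```python
import numpy as np


def compute_scale(amax, original_scale, fp8_max=448.0, margin=0):
    """Return (updated_scale, updated_scale_inv) following the documented formula.

    fp8_max defaults to 448.0, the largest finite E4M3 value (an assumption here).
    """
    amax = np.asarray(amax, dtype=np.float32)
    original_scale = np.asarray(original_scale, dtype=np.float32)
    with np.errstate(divide="ignore", invalid="ignore"):
        sf = (fp8_max / amax) / (2.0 ** margin)
    # Keep the previous scale when amax is zero ...
    sf = np.where(amax > 0.0, sf, original_scale)
    # ... or when amax is not finite (inf or nan).
    updated_scale = np.where(np.isfinite(amax), sf, original_scale)
    updated_scale_inv = 1.0 / updated_scale
    return updated_scale, updated_scale_inv


scale, scale_inv = compute_scale(amax=[2.0, 0.0, np.inf], original_scale=[1.0, 1.0, 1.0])
print(scale.tolist())  # [224.0, 1.0, 1.0]
```

Each `margin` step halves the scale, leaving one extra power of two of headroom before values saturate the FP8 range.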
- class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
Applies layer normalization over a mini-batch of inputs. This module supports two types of normalization: regular layer normalization and root mean square layer normalization.
The regular layer normalization is as described in the paper Layer Normalization:
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]
where \(\gamma\) and \(\beta\) are learnable affine transform parameters of the same size as each input sample.
The root mean square layer normalization (RMSNorm) is as described in the paper Root Mean Square Layer Normalization:
\[y = \frac{x}{\mathrm{RMS}[x]} * \gamma, \qquad \mathrm{RMS}[x] = \sqrt{\mathrm{E}[x^2] + \epsilon}\]
where \(\gamma\) is a learnable affine transform parameter of the same size as each input sample.
- Parameters
epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.
zero_centered_gamma (bool, default = False) –
If set to True, the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]
This parameter is only applicable for 'layernorm'. The default of scale_init will also be changed. See scale_init.
scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh.
bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\); only used when layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. Only used when layernorm_type='layernorm'.
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
transpose_batch_sequence (bool, default = False) – Indicate whether the batch and sequence length axes of the input tensors are switched. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
- __call__(x: jax.numpy.ndarray)
Applies layer normalization to the input inputs.
- Parameters
inputs (jax.numpy.ndarray) – Input tensors.
- Returns
outputs – Output tensors.
- Return type
jax.numpy.ndarray
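A NumPy reference for the two normalization types and the zero_centered_gamma variant, normalizing over the last (hidden) axis. It is a sketch for checking the formulas, not Transformer Engine's fused kernel; the function names are hypothetical, and placing the RMSNorm epsilon inside the square root is an assumption about the kernel.

```python
import numpy as np


def layernorm(x, gamma, beta, epsilon=1e-6, zero_centered_gamma=False):
    """Regular layer normalization over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / np.sqrt(var + epsilon)
    # With zero_centered_gamma the learned scale is stored as an offset from 1.
    scale = 1.0 + gamma if zero_centered_gamma else gamma
    return normed * scale + beta


def rmsnorm(x, gamma, epsilon=1e-6):
    """Root mean square normalization over the last axis (no centering, no beta)."""
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + epsilon)
    return x / rms * gamma


x = np.random.default_rng(0).normal(size=(2, 4, 8))  # (batch, seqlen, hidden)
hidden = x.shape[-1]

# Default inits: ones for gamma (zeros when zero-centered), zeros for beta.
y_ln = layernorm(x, gamma=np.ones(hidden), beta=np.zeros(hidden))
y_zc = layernorm(x, gamma=np.zeros(hidden), beta=np.zeros(hidden), zero_centered_gamma=True)
y_rms = rmsnorm(x, gamma=np.ones(hidden))

# Both layernorm variants start out identical, which is why the scale_init default changes.
print(np.allclose(y_ln, y_zc))  # True
```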
- class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
Applies a linear transformation to the incoming data: \(y = xA^T + b\).
- Parameters
features (Union[Iterable[int], int]) – The hidden size of each output sample.
kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes (Tuple[str, ...], default = ()) – The name of axes used to shard the weights with a corresponding mesh.
use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.
bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias; only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes (Tuple[str, ...], default = ()) – The name of axes used to shard bias with a corresponding mesh; only used when use_bias=True.
axis (Union[Iterable[int], int], default = -1) – An integer tuple with axes to apply the transformation on.
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence length axes of the input tensors are switched. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
- __call__(inputs: Array)
Apply the linear transformation to the input.
- Parameters
inputs (jax.numpy.ndarray) – Input tensors.
- Returns
outputs – Output tensors.
- Return type
jax.numpy.ndarray
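A NumPy sketch of the contraction that features and axis describe: the input axes listed in axis are contracted against a kernel whose shape is those input dimensions followed by features. The kernel layout (which follows the convention of `flax.linen.DenseGeneral`) and the helper name `dense_general` are assumptions for illustration; the kernel plays the role of \(A^T\) in \(y = xA^T + b\).

```python
import numpy as np


def dense_general(x, kernel, bias=None, axis=-1):
    """Contract x over `axis` with a kernel of shape (*contracted input dims, *features)."""
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    axes = tuple(a % x.ndim for a in axes)  # normalize negative axes
    # Contract the chosen input axes with the leading kernel axes.
    y = np.tensordot(x, kernel, axes=(axes, tuple(range(len(axes)))))
    if bias is not None:
        y = y + bias  # bias has shape `features`
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 16))   # (batch, seqlen, hidden)
features = (4, 8)                 # e.g. (num_heads, head_dim)
kernel = rng.normal(size=(16, *features))
bias = np.zeros(features)

y = dense_general(x, kernel, bias, axis=-1)
print(y.shape)  # (2, 5, 4, 8)
```

A tuple-valued features is equivalent to a flat output of size prod(features) reshaped afterwards, which is why it is convenient for producing per-head projections.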
- class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
Applies layer normalization followed by linear transformation to the incoming data.
- Parameters
features (Union[Iterable[int], int]) – The hidden size of each output sample.
enable_layernorm (bool, default = True) – Indicate whether to enable layer normalization before linear transformation.
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.
epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma (bool, default = False) –
If set to True, the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]
This parameter is only applicable for 'layernorm'. The default of scale_init will also be changed. See scale_init.
scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh; only used when enable_layernorm=True.
ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\); only used when enable_layernorm=True and layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
ln_bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. It is only used when enable_layernorm=True and layernorm_type='layernorm'.
kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes (Tuple[str, ...], default = ()) – The name of axes used to shard the weights with a corresponding mesh.
use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.
bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias; only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes (Tuple[str, ...], default = ()) – The name of axes used to shard bias with a corresponding mesh; only used when use_bias=True.
return_layernorm_output (bool, default = True) – Indicate whether to return the output of layer normalization. If set to False, None is returned as the second tensor in outputs.
axis (Union[Iterable[int], int], default = -1) – An integer tuple with axes to apply the transformation on.
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence length axes of the input tensors are switched. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
depth_scaling (float, default = None) – The factor by which to scale the output of DenseGeneral. If None, no scaling is applied.
- __call__(inputs: Array)
Apply layer normalization to the input followed by a linear transformation.
- Parameters
inputs (jax.numpy.ndarray) – Input tensor.
- Returns
outputs (jax.numpy.ndarray) – Output tensors.
ln_outputs (jax.numpy.ndarray) – The output tensors of layer normalization. If return_layernorm_output=False, this is None.
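The data flow and the two return values can be sketched in NumPy as follows. The sketch assumes a regular layer normalization with unit \(\gamma\) and zero \(\beta\), and assumes depth_scaling multiplies the DenseGeneral output; `layernorm_dense` is a hypothetical stand-in, not the module itself.

```python
import numpy as np


def layernorm_dense(x, kernel, epsilon=1e-6, return_layernorm_output=True, depth_scaling=None):
    """Hypothetical stand-in returning (outputs, ln_outputs) like the module's __call__."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    ln_out = (x - mean) / np.sqrt(var + epsilon)  # gamma = 1, beta = 0
    out = ln_out @ kernel  # linear transformation over the last axis
    if depth_scaling is not None:
        out = out * depth_scaling  # optional scaling of the dense output
    # The second value is None unless the layernorm output is requested.
    return out, (ln_out if return_layernorm_output else None)


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2, 16))  # (seqlen, batch, hidden), transpose_batch_sequence=True
kernel = rng.normal(size=(16, 32))

out, ln_out = layernorm_dense(x, kernel)
print(out.shape, ln_out.shape)  # (3, 2, 32) (3, 2, 16)

out2, ln_out2 = layernorm_dense(x, kernel, return_layernorm_output=False, depth_scaling=0.5)
print(ln_out2)  # None
```

Returning the normalized tensor lets a caller reuse it, for example as the residual input, without recomputing the normalization.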
- class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
Applies layer normalization to the input followed by the MLP module, which consists of two successive linear transformations separated by the given activations.
- Parameters
intermediate_dim (int, default = 2048) – Intermediate size to which input samples are projected.
enable_layernorm (bool, default = True) – Indicate whether to enable layer normalization before linear transformation.
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.
epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma (bool, default = False) –
If set to True, the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]
This parameter is only applicable for 'layernorm'. The default of scale_init will also be changed. See scale_init.
scale_init (Initializer, default = None) – Used for initializing scale factors \(\gamma\). If None is provided, scale_init is set according to the value of zero_centered_gamma. If zero_centered_gamma is set to True, then scale_init is flax.linen.initializers.zeros. Otherwise, scale_init is flax.linen.initializers.ones. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the scale factors \(\gamma\) with a corresponding mesh; only used when enable_layernorm=True.
ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing shift factors \(\beta\); only used when enable_layernorm=True and layernorm_type='layernorm'. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
ln_bias_axes (Tuple[str, ...], default = ('embed', )) – The name of axes used to shard the shift factors \(\beta\) with a corresponding mesh. Only used when enable_layernorm=True and layernorm_type='layernorm'.
kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing the weights of both linear transformations. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes_1 (Tuple[str, ...], default = ('embed', 'act', 'mlp')) – The name of axes used to shard the weights of the first linear transformation with a corresponding mesh.
kernel_axes_2 (Tuple[str, ...], default = ('mlp', 'embed')) – The name of axes used to shard the weights of the second linear transformation with a corresponding mesh.
use_bias (bool, default = False) – Indicate whether to enable bias shifting. If set to False, the layer will not learn an additive bias.
bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias; only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes_1 (Tuple[str, ...], default = ('mlp',)) – The name of axes used to shard the bias of the first linear transformation with a corresponding mesh. Only used when use_bias=True.
bias_axes_2 (Tuple[str, ...], default = ('embed',)) – The name of axes used to shard the bias of the second linear transformation with a corresponding mesh. Only used when use_bias=True.
return_layernorm_output (bool, default = True) – Indicate whether to return the output of layer normalization. If set to False, None is returned as the second tensor in outputs.
activations (Sequence[Union[str, Callable]], default = ('relu',)) – The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer.
intermediate_dropout_rng_name (str, default = 'dropout') – The key in the RNGs given to flax.linen.Module.apply that is used to generate Dropout masks.
intermediate_dropout_rate (float, default = 0.1) – Dropout probability for the dropout op after the activations.
intermediate_hidden_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden states after the activations.
axis (Union[Iterable[int], int], default = -1) – An integer tuple with axes to apply the transformation on.
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence length axes of the input tensors are switched. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
- __call__(inputs: Array, deterministic: bool = False)
Apply layer normalization to the input followed by a feedforward network (MLP Block).
- Parameters
inputs (jax.numpy.ndarray) – Input tensor.
deterministic (bool, default = False) – Disable dropout ops if set to True.
- Returns
outputs (jax.numpy.ndarray) – Output tensors.
ln_outputs (jax.numpy.ndarray) – The output tensors of layer normalization. If return_layernorm_output=False, this is None.
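A NumPy sketch of the MLP block under one reading of "each activation has its own transformation layer": the first linear transformation produces one projection per entry in activations (matching the 'act' axis in the default kernel_axes_1), each projection goes through its own activation, and the results are multiplied elementwise before the second transformation, as in T5-style gated MLPs such as ('gelu', 'linear'). This combination rule is an assumption, the helper names are hypothetical, layer normalization is omitted, and dropout is skipped as with deterministic=True.

```python
import numpy as np

# Hypothetical activation table for this sketch (gelu uses the tanh approximation).
ACTIVATIONS = {
    "relu": lambda v: np.maximum(v, 0.0),
    "linear": lambda v: v,
    "gelu": lambda v: 0.5 * v * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (v + 0.044715 * v ** 3))),
}


def mlp(x, kernel_1, kernel_2, activations=("relu",)):
    """x: (..., hidden); kernel_1: (hidden, len(activations), intermediate_dim);
    kernel_2: (intermediate_dim, hidden). Layernorm and dropout are omitted."""
    # First linear transformation: one projection per activation.
    h = np.einsum("...e,eam->...am", x, kernel_1)
    # Apply each activation to its own projection and multiply the branches together.
    gated = np.ones(h.shape[:-2] + h.shape[-1:])
    for idx, name in enumerate(activations):
        gated = gated * ACTIVATIONS[name](h[..., idx, :])
    # Second linear transformation back to the hidden size.
    return gated @ kernel_2


rng = np.random.default_rng(0)
hidden, inter = 8, 32
x = rng.normal(size=(2, 4, hidden))

k1 = rng.normal(size=(hidden, 1, inter))
k2 = rng.normal(size=(inter, hidden))
y = mlp(x, k1, k2, activations=("relu",))
print(y.shape)  # (2, 4, 8)

k1_gated = rng.normal(size=(hidden, 2, inter))
y_gated = mlp(x, k1_gated, k2, activations=("gelu", "linear"))
print(y_gated.shape)  # (2, 4, 8)
```

With a single activation the block reduces to the familiar FC1, activation, FC2 pattern.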
- class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_attention_heads, **kwargs)
T5-style relative positional embeddings, used as biases on the attention logits.
- Parameters
num_buckets (int) – The number of buckets to bucket distances between key and query positions into.
max_distance (int) – The maximum distance before everything is lumped into the last distance bucket.
num_attention_heads (int) – Number of attention heads in the transformer layer.
embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – Used for initializing relative embedding tables.
embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – The name of axes used to shard embedding attention bias with a corresponding mesh.
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
- __call__(q_seqlen, k_seqlen, bidirectional=True)
Generate relative position embedding attention biases.
- Parameters
q_seqlen (int) – The sequence length of query.
k_seqlen (int) – The sequence length of key.
bidirectional (bool, default = True) – Indicate whether to allow positive memory-query relative position embeddings.
- Returns
output – An attention bias with shape (1, num_attention_heads, q_seqlen, k_seqlen).
- Return type
jax.numpy.ndarray
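The bucketing controlled by num_buckets, max_distance, and bidirectional follows the T5 scheme: half of the buckets hold exact small distances and the rest cover larger distances on a logarithmic scale up to max_distance. The NumPy sketch below follows the public T5 reference logic and then gathers a per-head bias table into the documented (1, num_attention_heads, q_seqlen, k_seqlen) shape. The helper names and the randomly filled table are illustrative, not Transformer Engine's code.

```python
import numpy as np


def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    """Map (key_pos - query_pos) to bucket ids in [0, num_buckets), T5-style."""
    ret = np.zeros_like(relative_position)
    n = -relative_position
    if bidirectional:
        num_buckets //= 2
        # Keys after the query (negative n) use the upper half of the buckets.
        ret += (n < 0).astype(relative_position.dtype) * num_buckets
        n = np.abs(n)
    else:
        # Unidirectional: all keys after the query fall into bucket 0.
        n = np.maximum(n, 0)
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # Larger distances are spread logarithmically up to max_distance.
    val_if_large = max_exact + (
        np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(relative_position.dtype)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    return ret + np.where(is_small, n, val_if_large)


def relative_position_bias(table, q_seqlen, k_seqlen, bidirectional=True, max_distance=128):
    """table: (num_attention_heads, num_buckets) -> bias (1, heads, q_seqlen, k_seqlen)."""
    q_pos = np.arange(q_seqlen)[:, None]
    k_pos = np.arange(k_seqlen)[None, :]
    buckets = relative_position_bucket(k_pos - q_pos, bidirectional,
                                       table.shape[1], max_distance)
    return table[:, buckets][None, ...]


table = np.random.default_rng(0).normal(size=(8, 32))  # 8 heads, 32 buckets
bias = relative_position_bias(table, q_seqlen=5, k_seqlen=7)
print(bias.shape)  # (1, 8, 5, 7)
```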
- class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
Multi-head Attention (MHA), including Query, Key, Value and Output projection.
Note
Argument mask will be ignored when attn_mask_type is set to 'causal'.
- Parameters
head_dim (int) – The hidden dimension of each attention head.
num_heads (int) – The number of attention heads.
dropout_rate (float, default = 0.0) – Dropout probability for the dropout op during multi-head attention.
dropout_rng_name (str, default = 'dropout') – The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the core attention.
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.
layernorm_epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma (bool, default = False) –
If set to True, the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]
This parameter is only applicable for 'layernorm'.
kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')) – Used for initializing the QKV and Output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
use_bias (bool, default = False) – Indicate whether or not to enable bias shifting for QKVO projections. If set to False, the layer will not learn additive biases.
bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias of QKVO projections; only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
apply_residual_connection_post_layernorm (bool, default = False) – Indicate whether the residual connection is taken from the output of layer normalization instead of its input.
output_layernorm (bool, default = False) – Indicate whether to apply layer normalization at the end of MHA.
attn_mask_type ({'causal', 'padding'}, default = 'causal') – Type of attention mask passed into softmax operation. Introduced in v0.10.0.
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
fuse_qkv (bool, default = True) – If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention.
transpose_batch_sequence (bool, default = True) – Indicate whether the batch and sequence length axes of the input tensors are switched. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits (bool, default = False) – Indicate whether to scale attention logits. If set to True, the logits are \(\frac{QK^T}{\sqrt{\mathrm{head\_dim}}}\); otherwise, \(QK^T\).
scaled_query_init (bool, default = True) – Whether to divide the initial query projection weights \(W_Q\) by \(\sqrt{\mathrm{head\_dim}}\).
float32_logits (bool, default = False) – Whether to compute attention logits in float32.
- __call__(inputs_q: Array, inputs_kv: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, decode: bool = False, deterministic: bool = False)
MultiHeadAttention Layer: [Query, Key, Value projection] -> Dot Product Attention -> Output projection.
- Parameters
inputs_q (jax.numpy.ndarray) – Input tensor for query projection.
inputs_kv (jax.numpy.ndarray) – Input tensor for key/value projection.
mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out self-attention softmax input.
bias (jax.numpy.ndarray, default = None) – A tensor used to shift self-attention softmax input.
decode (bool, default = False) – Indicate whether to prepare and use an autoregressive cache.
deterministic (bool, default = False) – Disable dropout layers if set to True.
- Returns
outputs – Output tensors.
- Return type
jax.numpy.ndarray
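The relation between scale_attn_logits and scaled_query_init can be checked numerically: assuming scaled_query_init divides the initial query weights by \(\sqrt{\mathrm{head\_dim}}\), it yields the same logits as scaling \(QK^T\) by \(1/\sqrt{\mathrm{head\_dim}}\). The single-head NumPy sketch below also applies a causal mask before the softmax. It is an illustration, not the module's implementation; in particular, a True entry in `mask_out` here means "masked out", following the wording of the mask parameter.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v, mask_out=None, scale_logits=False):
    """q: (q_len, d); k, v: (k_len, d). mask_out[i, j] == True hides key j from query i."""
    logits = q @ k.T
    if scale_logits:
        logits = logits / np.sqrt(q.shape[-1])
    if mask_out is not None:
        logits = np.where(mask_out, -1e10, logits)  # large negative -> ~0 probability
    return softmax(logits) @ v, logits


rng = np.random.default_rng(0)
seqlen, hidden, head_dim = 4, 16, 8
x = rng.normal(size=(seqlen, hidden))
w_q, w_k, w_v = (rng.normal(size=(hidden, head_dim)) for _ in range(3))

# Causal mask: query i may not attend to keys j > i.
causal = np.triu(np.ones((seqlen, seqlen), dtype=bool), k=1)

# Option A: scale the logits at run time.
out_a, logits_a = attention(x @ w_q, x @ w_k, x @ w_v, causal, scale_logits=True)
# Option B: fold 1/sqrt(head_dim) into the query weights instead.
out_b, logits_b = attention(x @ (w_q / np.sqrt(head_dim)), x @ w_k, x @ w_v, causal)

print(np.allclose(out_a, out_b))  # True
```

Folding the factor into the initialization avoids one elementwise multiply per step, which is why both flags exist with complementary defaults.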
- class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
TransformerLayer is made up of a relative embedding, an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.
Note
Argument attention_mask will be ignored when self_attn_mask_type is set to 'causal'.
- Parameters
hidden_size (int, default = 512) – The hidden size of each input sample.
mlp_hidden_size (int, default = 2048) – Intermediate size to which input samples are projected.
num_attention_heads (int, default = 8) – Number of attention heads in the transformer layer.
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – Indicate the type of layer normalization.
layernorm_epsilon (float, default = 1e-6) – A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma (bool, default = False) –
If set to True, the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]
This parameter is only applicable for 'layernorm'.
hidden_dropout (float, default = 0.1) – Dropout probability for the dropout op after FC2 layer.
hidden_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden states after the FC2 layer.
attention_dropout (float, default = 0.1) – Dropout probability for the dropout op during multi-head attention.
intermediate_dropout (float, default = 0.1) – Dropout probability for the dropout op after FC1 layer.
intermediate_dropout_dims (Sequence[int], default = ()) – Dimensions that will share the same dropout mask for the hidden states after the FC1 layer.
dropout_rng_name (str, default = 'dropout') – The key in the RNGs given to flax.linen.Module.apply that is used to generate Dropout masks in the Multi-Head Attention.
mha_kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')) – Used for initializing the QKV and Output projection weights. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
mlp_kernel_init (Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')) – Used for initializing the weights of the FC1 and FC2 layers. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
mlp_activations (Sequence[str], default = ('relu', )) – The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer.
use_bias (bool, default = False) – Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases.
bias_init (Initializer, default = flax.linen.initializers.zeros) – Used for initializing bias of QKVO projections, FC1 and FC2. It is only used when use_bias=True. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
apply_residual_connection_post_layernorm (bool, default = False) – If set to True, residual connections are taken from the output of layer normalization (by default, they are taken from the input of layer normalization).
output_layernorm (bool, default = False) – If set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.
float32_attention_logits (bool, default = False) – If set to True, attention logits are computed in jax.numpy.float32.
layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – If set to TransformerLayerType.DECODER, an additional cross-attention block is added after self-attention. This can be used for structures like the T5 Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type ({'causal', 'padding'}, default = 'causal') – Type of attention mask passed into softmax operation. Introduced in v0.10.0.
enable_relative_embedding (bool, default = True) – Whether to enable relative embedding as shifting of attention logits.
relative_embedding (flax.linen.Module, default = None) – The module for relative embedding execution; only used when enable_relative_embedding=True. The default, None, creates the following instance of RelativePositionBiases when enable_relative_embedding=True: RelativePositionBiases(num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'), name='relpos_bias').
- Optimization parameters
dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
drop_path (float, default = 0.0) – When > 0.0, applies stochastic depth per sample in the main path of the residual block.
fuse_qkv_params (bool, default = True) – If set to True, TransformerLayer module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention.
transpose_batch_sequence (bool, default = False) – Indicate whether the batch and sequence length axes of the input tensors are switched. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits (bool, default = False) – Indicate whether to scale attention logits. If set to True, the logits are \(\frac{QK^T}{\sqrt{\mathrm{head\_dim}}}\); otherwise, \(QK^T\).
scaled_query_init (bool, default = True) – Whether to divide the initial query projection weights \(W_Q\) by \(\sqrt{\mathrm{head\_dim}}\).
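Two of the stochastic options above, in sketch form. A *_dropout_dims entry makes the dropout mask size 1 along that dimension, so one mask value is shared (broadcast) across it, similar to broadcast_dims in flax.linen.Dropout; drop_path drops the whole residual branch per sample (stochastic depth). Both helpers are hypothetical NumPy illustrations that use inverted-dropout rescaling by \(1/(1-p)\), and the batch axis is assumed to be 0, i.e. transpose_batch_sequence=False.

```python
import numpy as np


def dropout(x, rate, rng, broadcast_dims=()):
    """Inverted dropout; the mask is shared along every axis in broadcast_dims."""
    if rate == 0.0:
        return x
    keep = 1.0 - rate
    dims = {d % x.ndim for d in broadcast_dims}  # normalize negative axes
    mask_shape = tuple(1 if ax in dims else size for ax, size in enumerate(x.shape))
    mask = rng.random(mask_shape) < keep
    return np.where(mask, x / keep, 0.0)


def drop_path(branch, rate, rng, batch_axis=0):
    """Stochastic depth: zero the whole residual branch for some samples."""
    if rate == 0.0:
        return branch
    keep = 1.0 - rate
    mask_shape = [1] * branch.ndim
    mask_shape[batch_axis] = branch.shape[batch_axis]
    mask = rng.random(tuple(mask_shape)) < keep
    return np.where(mask, branch / keep, 0.0)


rng = np.random.default_rng(0)
h = np.ones((2, 6, 4))  # (batch, seqlen, hidden)

# Share the dropout mask across the sequence dimension (axis 1).
y = dropout(h, rate=0.5, rng=rng, broadcast_dims=(1,))
print(np.all(y == y[:, :1, :]))  # True: every position in a sequence sees the same mask

# Each sample keeps or drops its entire branch.
z = drop_path(h, rate=0.5, rng=rng)
```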
- __call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)
Transformer Layer: attention block and a feedforward network (MLP).
- Parameters
inputs (jax.numpy.ndarray) – Input tensor.
encoded (jax.numpy.ndarray, default = None) – Output tensors of the encoder block to be fed into the decoder block if using layer_type=TransformerLayerType.DECODER.
attention_mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out self-attention softmax input.
encoder_decoder_mask (jax.numpy.ndarray, default = None) – Boolean tensor used to mask out cross-attention softmax input when layer_type=TransformerLayerType.DECODER.
deterministic (bool, default = False) – Disable dropout layers if set to True.
decode (bool, default = False) – Indicate whether to prepare and use an autoregressive cache in Multi-head attention (MHA).
max_decode_length (bool, default = None) – The maximum length to generate relative embedding biases when layer_type=TransformerLayerType.DECODER and enable_relative_embedding=True.
- Returns
outputs – Output tensors.
- Return type
jax.numpy.ndarray
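How output_layernorm and apply_residual_connection_post_layernorm change the residual wiring around one sublayer, following their parameter descriptions. `sublayer` stands in for the attention or MLP block and `ln` for layer normalization without learned parameters; both, and the helper name `residual_block`, are hypothetical simplifications, and dropout and drop_path are omitted.

```python
import numpy as np


def ln(x, epsilon=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + epsilon)


def residual_block(x, sublayer, output_layernorm=False,
                   apply_residual_connection_post_layernorm=False):
    if output_layernorm:
        # Output-side LN: the sublayer sees the raw input; normalize after the residual add.
        # (The interaction with the other flag is not modeled in this sketch.)
        return ln(x + sublayer(x))
    normed = ln(x)  # Default: normalize on the input side, before the sublayer.
    residual = normed if apply_residual_connection_post_layernorm else x
    return residual + sublayer(normed)


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8))


def sublayer(h):
    return np.tanh(h @ w)  # stand-in for attention or MLP


x = rng.normal(size=(2, 3, 8))
pre_ln = residual_block(x, sublayer)
post_ln = residual_block(x, sublayer, output_layernorm=True)
resid_from_ln = residual_block(x, sublayer, apply_residual_connection_post_layernorm=True)
print(pre_ln.shape, post_ln.shape, resid_from_ln.shape)  # (2, 3, 8) (2, 3, 8) (2, 3, 8)
```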
- transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules)
Extend the given Flax logical axis rules with the predefined TransformerLayer’s logical axis rules.
Note
We currently only support logical axis rules for single GPU training, data parallel training and 1D-sharding tensor parallel training. Refer to Figure 3 in Megatron-LM tensor parallel for 1D-sharding tensor parallelism.
Warning
Please make sure ShardingResource is set via fp8_autocast before calling this function.
Note
This function is only needed when using TransformerLayer. For other modules, such as DenseGeneral, set the sharding axes of kernels and biases explicitly.
- Parameters
rules (Sequence[Tuple[str, Union[str, None]]]) – The base Flax logical axis rules to extend.
- Returns
extended_rules – The extended Flax logical axis rules.
- Return type
Sequence[Tuple[str, Union[str, None]]]
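A hypothetical, pure-Python sketch of what extending and then resolving logical axis rules amounts to. It assumes that rules supplied by the caller take precedence and that a predefined rule is appended only for logical axis names not yet present; the predefined rules here are made up for illustration and are not Transformer Engine's actual rule set. Resolution uses the first matching rule per name, a simplification of how Flax's flax.linen.partitioning resolves rules.

```python
from typing import Optional, Sequence, Tuple

Rule = Tuple[str, Optional[str]]

# Made-up "predefined" rules for this sketch only (not TE's actual rule set).
PREDEFINED_RULES: Tuple[Rule, ...] = (
    ("batch", "data_parallel"),
    ("mlp", "tensor_parallel"),
    ("heads", "tensor_parallel"),
    ("embed", None),
)


def extend_rules(rules: Sequence[Rule]) -> Tuple[Rule, ...]:
    """Keep caller rules first; append predefined rules for names not yet covered."""
    known = {name for name, _ in rules}
    extra = tuple(rule for rule in PREDEFINED_RULES if rule[0] not in known)
    return tuple(rules) + extra


def resolve(logical_axes: Sequence[str], rules: Sequence[Rule]) -> Tuple[Optional[str], ...]:
    """Map each logical axis name to a mesh axis using the first matching rule."""
    lookup = {}
    for name, mesh_axis in rules:
        lookup.setdefault(name, mesh_axis)  # first rule wins
    return tuple(lookup.get(name) for name in logical_axes)


rules = extend_rules(tuple())
print(resolve(("embed", "mlp"), rules))  # (None, 'tensor_parallel')
```

A mesh axis of None means the corresponding dimension stays replicated rather than sharded.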