Layers

ESM2QueryScaling

Bases: Module

Source code in bionemo/llm/model/layers.py (lines 45-62)
class ESM2QueryScaling(torch.nn.Module):  # noqa: D101
    def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
        """A custom layer that scales quary values.

        This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2
        which apply 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb()

        Args:
            config (TransformerConfig): The megatron config. This is used for computing projection_size
        """
        super().__init__()
        projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
        self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)

    @torch.compile
    def forward(self, query, *args, **kwargs):  # noqa: D102
        return query / self.sqrt_val
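
A note on the scale factor (an observation, not part of the source): the multiplication by num_attention_heads in projection_size is cancelled by the subsequent divide, so sqrt_val reduces to sqrt(config.kv_channels), the per-attention-head dimension. A quick arithmetic check with hypothetical config values:

kv_channels, num_attention_heads = 64, 16                                 # hypothetical values
projection_size = kv_channels * num_attention_heads                      # 1024
hidden_size_per_attention_head = projection_size // num_attention_heads  # 64 == kv_channels
# forward() therefore computes query / math.sqrt(64), i.e. query / 8.0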

__init__(config, *args, **kwargs)

A custom layer that scales query values.

This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2, which applies 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb().

Parameters:

Name    Type               Description                                                        Default
config  TransformerConfig  The megatron config. This is used for computing projection_size.  required
Source code in bionemo/llm/model/layers.py (lines 46-58)
def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
    """A custom layer that scales quary values.

    This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2
    which apply 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb()

    Args:
        config (TransformerConfig): The megatron config. This is used for computing projection_size
    """
    super().__init__()
    projection_size = config.kv_channels * config.num_attention_heads
    self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
    self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)
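
Usage sketch (an illustration, not from the source): ESM2QueryScaling only reads kv_channels and num_attention_heads from the config, so a lightweight stand-in object is enough to exercise it outside a full megatron setup. All values below are hypothetical.

import math
from types import SimpleNamespace

import torch

from bionemo.llm.model.layers import ESM2QueryScaling

# Stand-in for a TransformerConfig; only the two attributes the layer reads are provided.
config = SimpleNamespace(kv_channels=64, num_attention_heads=16)
layer = ESM2QueryScaling(config)

# Shape is illustrative: [sequence, batch, heads, head_dim].
query = torch.randn(8, 2, 16, 64)
scaled = layer(query)  # forward() is wrapped in torch.compile, so the first call triggers compilation

# The layer divides by sqrt(hidden_size_per_attention_head), which equals sqrt(kv_channels) here.
assert torch.allclose(scaled, query / math.sqrt(64))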

TELayerNorm

Bases: LayerNorm

Source code in bionemo/llm/model/layers.py (lines 27-42)
class TELayerNorm(te.pytorch.LayerNorm):  # noqa: D101
    def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
        """A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig.
            This allows this method to be used in a megatron layerspec.

        Args:
            config (TransformerConfig): The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma.
                The rest of the config is not used.
        """  # noqa: D205
        # Eps tends to get passed through properly, as does hidden_size, but not other params from the config.
        super().__init__(
            *args,
            zero_centered_gamma=config.layernorm_zero_centered_gamma,
            sequence_parallel=config.sequence_parallel,
            **kwargs,
        )

__init__(config, *args, **kwargs)

A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig. This allows this method to be used in a megatron layerspec.

Parameters:

Name    Type               Description                                                             Default
config  TransformerConfig  The megatron config. This is used for extracting sequence_parallel     required
                           and zero_centered_gamma. The rest of the config is not used.
Source code in bionemo/llm/model/layers.py (lines 28-42)
def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
    """A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig.
        This allows this method to be used in a megatron layerspec.

    Args:
        config (TransformerConfig): The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma.
            The rest of the config is not used.
    """  # noqa: D205
    # Eps tends to get passed through properly, as does hidden_size, but not other params from the config.
    super().__init__(
        *args,
        zero_centered_gamma=config.layernorm_zero_centered_gamma,
        sequence_parallel=config.sequence_parallel,
        **kwargs,
    )
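
Usage sketch (an illustration, not from the source): TELayerNorm reads only layernorm_zero_centered_gamma and sequence_parallel from the config and forwards the remaining positional and keyword arguments (such as hidden_size and eps) to te.pytorch.LayerNorm. This assumes Transformer Engine is installed and a CUDA device is available; all values below are hypothetical.

import torch
from types import SimpleNamespace

from bionemo.llm.model.layers import TELayerNorm

# Stand-in for a TransformerConfig; only the two attributes the wrapper reads are provided.
config = SimpleNamespace(layernorm_zero_centered_gamma=False, sequence_parallel=False)

# hidden_size (320) and eps are forwarded to te.pytorch.LayerNorm via *args/**kwargs.
norm = TELayerNorm(config, 320, eps=1e-5).cuda()

x = torch.randn(8, 2, 320, device="cuda")  # [sequence, batch, hidden]; shape is illustrative
out = norm(x)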