# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.autograd.function import once_differentiable
from typing import Any, Union
import accvlab.batching_helpers.batched_indexing_access_cuda as batched_indexing_access_cuda
from .data_format import RaggedBatch


class BatchedIndexMapping(torch.autograd.Function):
    """Implementation of mapping values between two tensors using a mapping defined by pairs of indices as well as the corresponding backward operation.

    Please see documentation of :func:`batched_index_mapping` for more details.

    Warning:
        This class is not intended to be used directly. It is an implementation detail of :func:`batched_index_mapping`.
    """

    @staticmethod
    def forward(
        ctx: Any,
        input_data: torch.Tensor,
        input_indices: torch.Tensor,
        output_indices: torch.Tensor,
        nums_indices: torch.Tensor,
        to_insert_into: torch.Tensor,
    ) -> torch.Tensor:
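        # The CUDA kernel expects contiguous memory layouts, so make all arguments
        # contiguous up front (a no-op for tensors that already are).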
        input_data = input_data.contiguous()
        input_indices = input_indices.contiguous()
        output_indices = output_indices.contiguous()
        nums_indices = nums_indices.contiguous()
        to_insert_into = to_insert_into.contiguous()
        result = batched_indexing_access_cuda.map_values_by_index_pairs(
            input_data, input_indices, output_indices, nums_indices, to_insert_into, backward_accumulate=False
        )
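        # Save what backward needs: the index pairs, the per-sample counts, and the size
        # of the input's non-uniform dimension (used to shape `grad_input`).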
        ctx.save_for_backward(input_indices, output_indices, nums_indices)
        ctx.num_batch_dims = nums_indices.dim()
        ctx.input_max_sample_size = input_data.shape[ctx.num_batch_dims]
        return result

    @staticmethod
    @once_differentiable
    def backward(ctx: Any, grad: Union[torch.Tensor, None]):
        if grad is None:
            return None, None, None, None, None
        else:
            forward_input_indices, forward_output_indices, nums_indices = ctx.saved_tensors
            # The shape of `grad_input` matches the shape of `grad` except in the
            # non-uniform dimension (`dim == ctx.num_batch_dims`), where the size may
            # differ. The size in this dim is known from the forward step.
            shape_to_set = list(grad.shape)
            shape_to_set[ctx.num_batch_dims] = ctx.input_max_sample_size

            grad = grad.contiguous()

            grad_input = torch.zeros(
                shape_to_set, dtype=grad.dtype, device=grad.device, requires_grad=grad.requires_grad
            )
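            # Scatter the incoming gradient back to the source positions by running the
            # mapping kernel with the index roles swapped. `backward_accumulate=True`
            # accumulates instead of overwriting, since `source_indices` may contain
            # duplicates within a sample.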
            grad_input = batched_indexing_access_cuda.map_values_by_index_pairs(
                grad,
                forward_output_indices,
                forward_input_indices,
                nums_indices,
                grad_input,
                backward_accumulate=True,
            )
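            # The gradient w.r.t. `to_insert_into` equals the incoming gradient with zeros
            # written at the positions the forward pass overwrote; those original values do
            # not reach the output and thus receive no gradient.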
            grad_to_insert_into = batched_indexing_access_cuda.backward_insert_const(
                0.0, forward_output_indices, nums_indices, grad
            )
            return grad_input, None, None, None, grad_to_insert_into

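
# A minimal pure-PyTorch sketch of the forward semantics of the CUDA kernel used above,
# for illustration only. The name `_reference_index_mapping` is hypothetical and not part
# of the public API; it assumes plain tensors with a single batch dimension (dim 0) and
# non-uniform dimension 1.
def _reference_index_mapping(
    source_data: torch.Tensor,
    source_indices: torch.Tensor,
    target_indices: torch.Tensor,
    nums_indices: torch.Tensor,
    target_data: torch.Tensor,
) -> torch.Tensor:
    result = target_data.clone()
    for i in range(source_data.shape[0]):  # loop over samples in the batch
        n = int(nums_indices[i])  # number of valid index pairs for sample i
        result[i, target_indices[i, :n]] = source_data[i, source_indices[i, :n]]
    return result
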

def batched_index_mapping(
    source_data: Union[torch.Tensor, RaggedBatch],
    source_indices: RaggedBatch,
    target_indices: RaggedBatch,
    target_data: Union[torch.Tensor, RaggedBatch],
) -> Union[torch.Tensor, RaggedBatch]:
    """Map values between ``source_data`` and ``target_data`` using a mapping defined by pairs of
    indices for elements in ``source_data`` and ``target_data``, and set the corresponding values
    in ``target_data``.

    :gpu:

    For a sample ``i`` and a valid index pair ``j`` (valid means
    ``target_indices.sample_sizes[i] > j`` and ``source_indices.sample_sizes[i] > j``), the
    operation can be expressed as (assuming non-uniform dimension ``dim == 1`` for both
    ``source_data`` and ``target_data``):

    ``target_data[i, target_indices[i, j]] = source_data[i, source_indices[i, j]]``

    This function sets the values in ``target_data`` in a way which corresponds to the line above
    for all valid matches in all samples.

    Warning:
        It is expected that for each sample, the number of valid indices for the source and the
        target matches, i.e. ``target_indices.sample_sizes == source_indices.sample_sizes``.
        If this is not the case, the behavior is undefined.

    Warning:
        This function assumes that for each sample, there are no duplicate indices in
        ``target_indices``, i.e. there are no duplicates among the valid entries in
        ``target_indices[i, 0:target_indices.sample_sizes[i]]``. If this is not the case, the
        behavior is undefined. There are no such restrictions on ``source_indices``.

    Args:
        source_data: Input data. Shape in case of a tensor:
            ``(batch_size, num_entries_input, ...)``. In case of a RaggedBatch, the following
            holds:

            - ``source_data.shape[0] == batch_size``
            - ``source_data.shape[source_data.non_uniform_dim] == num_entries_input``

            The number of dimensions needs to correspond to ``target_data``, and the shape needs
            to be the same except in the non-uniform dimension (``dim == 1`` for tensors), as in
            this dimension the matching is done using the index pairs.
        source_indices: Indices at which to get the input data. Shape:
            ``(batch_size, max_num_matches)``. Note that the batch size and sample sizes need to
            match ``target_indices``, i.e.
            ``target_indices.sample_sizes == source_indices.sample_sizes``, and
            ``max_num_matches`` corresponds to the maximum number of matches among all samples.
        target_indices: Indices at which to fill the data. Shape:
            ``(batch_size, max_num_matches)``. Note that ``max_num_matches`` corresponds to the
            maximum number of matches among all samples.
        target_data: Data to fill the values into. Shape in case of a tensor:
            ``(batch_size, num_entries_output, ...)``. In case of a RaggedBatch, the following
            holds:

            - ``target_data.shape[0] == batch_size``
            - ``target_data.shape[target_data.non_uniform_dim] == num_entries_output``

    Returns:
        As ``target_data``, with the values from ``source_data`` inserted according to the pairs
        of corresponding indices in ``source_indices`` and ``target_indices``. Shape in case of a
        tensor: ``(batch_size, num_entries_output, ...)``. In case of a RaggedBatch, the following
        holds:

        - ``target_data_filled.shape[0] == batch_size``
        - ``target_data_filled.shape[target_data_filled.non_uniform_dim] == num_entries_output``

    Example:
        - ``-`` indicates values in the input which are not used in the mapping.
        - ``*`` indicates filler values in ``source_indices`` and ``target_indices``, which are
          ignored.
        - ``..`` indicates data which remains unchanged, i.e. is the same as in the
          ``target_data`` parameter and the output.

        Each depicted element in ``source_data`` and ``target_data`` may represent a single value
        (in case of 2D tensors), or itself be a non-scalar entry (in case that the data has more
        than 2 dimensions).

        .. image:: images/Mapping_ragged.png
            :alt: Illustration of the mapping operation
            :align: center
    """
    num_batch_dims = target_indices.non_uniform_dim
    assert (
        target_indices.dim() == num_batch_dims + 1 and source_indices.dim() == num_batch_dims + 1
    ), "Indices must have exactly one dimension in addition to the batch dimensions"
    assert (
        target_indices.shape[:num_batch_dims] == source_indices.shape[:num_batch_dims]
    ), "Batch shape mismatch"
    assert (
        target_indices.shape[num_batch_dims] == source_indices.shape[num_batch_dims]
    ), "Maximum number of indices mismatch"

    is_target_data_ragged_batch = isinstance(target_data, RaggedBatch)
    is_source_data_ragged_batch = isinstance(source_data, RaggedBatch)
    if is_target_data_ragged_batch:
        target_data_non_uniform_dim = target_data.non_uniform_dim
        target_data = target_data.get_non_uniform_dimension_transposed_to(num_batch_dims)
        target_data_tensor = target_data.tensor
    else:
        target_data_non_uniform_dim = 1
        target_data_tensor = target_data
    if is_source_data_ragged_batch:
        source_data = source_data.get_non_uniform_dimension_transposed_to(num_batch_dims)
        source_data = source_data.tensor

    # TODO: add check for consistent sample sizes (debug mode?)
    res = BatchedIndexMapping.apply(
        source_data,
        source_indices.tensor,
        target_indices.tensor,
        source_indices.sample_sizes,
        target_data_tensor,
    )
    if is_target_data_ragged_batch:
        res = target_data.create_with_sample_sizes_like_self(res, num_batch_dims)
        res = res.get_non_uniform_dimension_transposed_to(target_data_non_uniform_dim)
    else:
        if target_data_non_uniform_dim != 1:
            res = res.transpose(1, target_data_non_uniform_dim)
    return res
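
# A hypothetical usage sketch for :func:`batched_index_mapping`, kept as a comment because
# the `RaggedBatch` constructor signature is an assumption and is not shown in this file.
# Two samples; CUDA tensors, since the operation is GPU-only:
#
#     src = torch.randn(2, 5, 3, device="cuda")   # (batch_size, num_entries_input, ...)
#     tgt = torch.zeros(2, 4, 3, device="cuda")   # (batch_size, num_entries_output, ...)
#     idx_src = RaggedBatch(torch.tensor([[0, 2, 4], [1, 3, 0]], device="cuda"),
#                           sample_sizes=torch.tensor([3, 2], device="cuda"))
#     idx_tgt = RaggedBatch(torch.tensor([[1, 0, 3], [2, 0, 0]], device="cuda"),
#                           sample_sizes=torch.tensor([3, 2], device="cuda"))
#     out = batched_index_mapping(src, idx_src, idx_tgt, tgt)
#     # e.g. out[0, 1] == src[0, 0] and out[1, 2] == src[1, 1]; rows of `tgt` that are
#     # not targeted by any valid index pair remain unchanged (here: zero).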