# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.autograd.function import once_differentiable
from typing import Any, Union, Optional
import accvlab.batching_helpers.batched_indexing_access_cuda as batched_indexing_access_cuda
from .data_format import RaggedBatch
class BatchedIndexingAccess(torch.autograd.Function):
"""Batched indexing with non-uniform indices.
A wrapper function (:func:`batched_indexing_access`) is available and this class should not be used directly.
For details about the indexing operation, see documentation of :func:`batched_indexing_access` below, which wraps the functionality of this class
and presents it as a function.
"""
@staticmethod
def forward(
ctx: Any,
input_data: torch.Tensor,
input_indices: torch.Tensor,
input_nums_indices: torch.Tensor,
fill_value: float = 0.0,
) -> torch.Tensor:
"""Batched indexing with non-uniform indices.
Detailed documentation see :func:`batched_indexing_access` below, which wraps the functionality of this class
and presents it as a function (with some additional functionality).
"""
input_data = input_data.contiguous()
input_indices = input_indices.contiguous()
input_nums_indices = input_nums_indices.contiguous()
result = batched_indexing_access_cuda.forward(
input_data, input_indices, input_nums_indices, fill_value
)
ctx.save_for_backward(input_indices, input_nums_indices)
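        # Record the size of the indexed dimension (the first dimension after the
        # batch dimensions) so backward can size the scattered gradient tensor.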
num_batch_dim = input_nums_indices.dim()
ctx.input_num_targets = input_data.shape[num_batch_dim]
return result
@staticmethod
@once_differentiable
def backward(ctx: Any, grad: Union[torch.Tensor, None]):
"""Perform back-propagation of gradients for the performed mapping operation (see documentation of method :meth:`forward`)."""
if grad is None:
return None, None, None, None
else:
input_indices, input_nums_indices = ctx.saved_tensors
grad = grad.contiguous()
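            # Scatter the incoming gradient back to the gathered input positions.
            # Accumulation is required because the same input entry may be selected
            # by multiple indices within a sample.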
grad_input = batched_indexing_access_cuda.backward_new_tensor(
grad, input_indices, input_nums_indices, ctx.input_num_targets, 0.0, backward_accumulate=True
)
return grad_input, None, None, None
class BatchedInverseIndexingAccessNewTensor(torch.autograd.Function):
"""Batched inverse indexing access, i.e. writing data into the indexed location, with non-uniform indices
For details about the inverse indexing operation, see documentation of :func:`batched_inverse_indexing_access` below,
which wraps the functionality of this class and presents it as a function (with some additional functionality).
"""
@staticmethod
def forward(
ctx: Any,
input: torch.Tensor,
output_indices: torch.Tensor,
output_nums_indices: torch.Tensor,
        output_num_targets: int,
fill_value: float = 0.0,
) -> torch.Tensor:
"""Batched indexing, i.e. writing data into the indexed location, with non-uniform indices
Detailed documentation see :func:`batched_inverse_indexing_access` below, which wraps the functionality of this class
and presents it as a function.
"""
input = input.contiguous()
output_indices = output_indices.contiguous()
output_nums_indices = output_nums_indices.contiguous()
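        # Allocate an output filled with fill_value and scatter the input rows to
        # the given indices. No accumulation: indices are assumed unique per sample.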
result = batched_indexing_access_cuda.backward_new_tensor(
input,
output_indices,
output_nums_indices,
output_num_targets,
fill_value,
backward_accumulate=False,
)
ctx.save_for_backward(output_indices, output_nums_indices)
return result
@staticmethod
@once_differentiable
def backward(ctx: Any, grad: Union[torch.Tensor, None]):
"""Perform back-propagation of gradients for the performed mapping operation (see documentation of method :meth:`forward`)."""
if grad is None:
return None, None, None, None, None
else:
(output_indices, output_nums_indices) = ctx.saved_tensors
grad = grad.contiguous()
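            # The gradient of a scatter is a gather: pick up the incoming gradient
            # at each index that was written to in the forward pass.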
grad_input = batched_indexing_access_cuda.forward(grad, output_indices, output_nums_indices, 0.0)
return grad_input, None, None, None, None
class BatchedInverseIndexingAccessInsert(torch.autograd.Function):
"""Batched inverse indexing access, i.e. writing data into the indexed location, with non-uniform indices
For details about the inverse indexing operation, see documentation of :func:`batched_inverse_indexing_access` below,
which wraps the functionality of this class and presents it as a function (with some additional functionality).
"""
@staticmethod
def forward(
ctx: Any,
to_fill: torch.Tensor,
output_indices: torch.Tensor,
output_nums_indices: torch.Tensor,
to_fill_into: torch.Tensor,
) -> torch.Tensor:
"""Batched indexing, i.e. writing data into the indexed location, with non-uniform indices
Detailed documentation see :func:`batched_inverse_indexing_access` below, which wraps the functionality of this class
and presents it as a function.
"""
to_fill = to_fill.contiguous()
to_fill_into = to_fill_into.contiguous()
output_indices = output_indices.contiguous()
output_nums_indices = output_nums_indices.contiguous()
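        # Overwrite the entries of to_fill_into at the given indices with to_fill;
        # all other entries keep their original values.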
result = batched_indexing_access_cuda.backward_insert(
to_fill, output_indices, output_nums_indices, to_fill_into
)
ctx.save_for_backward(output_indices, output_nums_indices)
return result
@staticmethod
@once_differentiable
def backward(ctx: Any, grad: Union[torch.Tensor, None]):
"""Perform back-propagation of gradients for the performed mapping operation (see documentation of method :meth:`forward`)."""
if grad is None:
return None, None, None, None
else:
(output_indices, output_nums_indices) = ctx.saved_tensors
grad = grad.contiguous()
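            # Two gradient paths: to_fill receives the gradient gathered at the
            # written indices; to_fill_into receives the gradient with zeros at
            # those indices, since its original values were overwritten there.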
grad_for_to_insert = batched_indexing_access_cuda.forward(
grad, output_indices, output_nums_indices, 0.0
)
grad_for_to_insert_into = batched_indexing_access_cuda.backward_insert_const(
0.0, output_indices, output_nums_indices, grad
)
return grad_for_to_insert, None, None, grad_for_to_insert_into
def batched_indexing_access(
input_data: Union[RaggedBatch, torch.Tensor],
input_indices: RaggedBatch,
filler_value: float = 0.0,
dim_to_index_in: Optional[int] = None,
) -> RaggedBatch:
"""Batched indexing access with non-uniform indices.
:gpu:
Note that for each sample, the number of resulting entries corresponds to the number of indices. This means that in general,
the output size will be non-uniform. Therefore, a :class:`RaggedBatch` is returned regardless of the ``input_data`` type.
    Note:
        Whether ``input_data`` is a :class:`RaggedBatch` or a :class:`torch.Tensor`, the indexing of ``input_data``
        is performed along ``dim_to_index_in``, which is not necessarily the non-uniform dimension of ``input_data``.
    Warning:
        While the ``filler_value`` parameter sets the value used for filler entries,
        that value may change when the resulting :class:`RaggedBatch` is processed further.
        Therefore, do not rely on the output containing a specific filler value.
Args:
input_data: Data to which the indexing is applied.
input_indices: For each sample (element along the batch dimension), the indices of entries to obtain from the input.
Shape: ``(*batch_shape, max_num_indices)``
Here, ``max_num_indices`` corresponds to the maximum number of indices over all samples.
filler_value: Filler values for the remaining elements in the output (corresponding to the fillers in ``input_indices``).
Default: 0.0
        dim_to_index_in: Dimension along which to apply the indexing. Cannot be a batch dimension of the input indices.
            If not set, defaults to ``input_indices.non_uniform_dim``.
Returns:
Result containing the indexed entries from the input tensor. For a sample ``i`` and a valid
index ``j < input_indices.sample_sizes[i]``, the following holds (assuming ``dim_to_index_in == 1``):
``indexed_vals[i, j] == input_data[i, input_indices[i, j]]``
The shape of the resulting data is:
- ``indexed_vals.shape[0] == batch_size``
- ``indexed_vals.shape[dim_to_index_in] == max_num_indices``
- Remaining dimensions correspond to the input data
Example:
In the illustration below:
- Letters indicate data entries that are indexed in the input (and therefore appear in the output)
- '-' indicates entries where the actual values are not relevant (in the input).
- '*' indicates filler values in :class:`RaggedBatch` instances.
.. image:: images/BatchedIndexing_ragged.png
:alt: Illustration of the batched indexing operation
:align: center
Each depicted entry in the data may represent a single value (in case of
2D tensors), or itself be a non-scalar entry (in case that ``input_data`` has more than 2 dimensions).
Note that for ``input_indices``, the entries are always scalar.
Also, we do not show the ``filler_value`` in the example. It is filled into the '*'-entries in the output.
In this case, the ``dim_to_index_in`` is 1.
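
        A minimal reference sketch of these semantics in plain PyTorch (illustration
        only; the actual operation runs as a single batched CUDA kernel, and the
        tensors below stand in for ``input_indices.tensor`` and
        ``input_indices.sample_sizes``):

        .. code-block:: python

            import torch

            # Two samples with 4 entries each; 3 and 2 valid indices (padded).
            input_data = torch.tensor([[10., 11., 12., 13.],
                                       [20., 21., 22., 23.]])
            indices = torch.tensor([[2, 0, 3], [1, 2, 0]])
            sample_sizes = torch.tensor([3, 2])
            filler_value = 0.0

            out = torch.full((2, 3), filler_value)
            for i in range(2):
                n = int(sample_sizes[i])
                out[i, :n] = input_data[i, indices[i, :n]]
            # out == [[12., 10., 13.], [21., 22., 0.]]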
"""
is_input_ragged_batch = isinstance(input_data, RaggedBatch)
if is_input_ragged_batch:
input_data = input_data.tensor
if dim_to_index_in is None:
dim_to_index_in = input_indices.non_uniform_dim
transpose_needed = False
else:
assert (
dim_to_index_in >= input_indices.num_batch_dims
), "Cannot index in a batch dimension of the input indices"
transpose_needed = input_indices.num_batch_dims != dim_to_index_in
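    # The underlying kernel indexes along the first non-batch dimension, so
    # transpose dim_to_index_in into that position (and back afterwards).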
if transpose_needed:
input_data = input_data.transpose(input_indices.num_batch_dims, dim_to_index_in)
res = BatchedIndexingAccess.apply(
input_data, input_indices.tensor, input_indices.sample_sizes, filler_value
)
if transpose_needed:
res = res.transpose(input_indices.num_batch_dims, dim_to_index_in)
res = input_indices.create_with_sample_sizes_like_self(res, dim_to_index_in)
return res
def batched_inverse_indexing_access(
input_data: Union[RaggedBatch, torch.Tensor],
output_indices: RaggedBatch,
output_num_targets: int,
filler_value: float = 0.0,
dim_to_index_in: Optional[int] = None,
) -> torch.Tensor:
"""Batched setting of values at given indices, with non-uniform indices.
:gpu:
Non-uniform indices means that for each sample, the indices, as well as the number of indices, vary.
Note:
This function is similar to :func:`batched_indexing_write`, but instead of using a ``to_write_into``
tensor, a tensor with a uniform filler value is created first, and the values to set are
written into that tensor.
    Note:
        Whether ``input_data`` is a :class:`RaggedBatch` instance or a tensor, the indexing
        is performed along ``dim_to_index_in``, which is not necessarily the non-uniform dimension
        of ``input_data``.
Warning:
This function assumes that for each sample, there are no duplicate indices in
``output_indices``, i.e. there are no duplicates in the valid entries in:
``output_indices[i, 0:output_indices.sample_sizes[i]]``.
If this is not the case, the behavior is undefined.
Args:
        input_data: Data to write at the given indices.
output_indices: For each sample (element along the batch dimension), the indices of entries to write to in the output.
Shape: ``(batch_size, max_num_indices)``
Here, ``max_num_indices`` corresponds to the maximum number of indices over all samples.
output_num_targets: Size of the dimension corresponding to the indexed dimension in the output
filler_value: Filler values for the non-indexed elements in the output. Default: 0.0
dim_to_index_in: Dimension on which to apply the indexing. Optional, default is the non-uniform
dimension of the output indices.
Returns:
Resulting tensor, containing the filled in values from the input, inserted at the corresponding indices,
and the filler values everywhere else.
For each sample ``i`` and each valid index ``j < output_indices.sample_sizes[i]``, the following holds:
``output[i, output_indices[i, j]] == input_data[i, j]``
The shape of the resulting data is:
- ``output.shape[0] == batch_size``
        - ``output.shape[dim_to_index_in] == output_num_targets``
- Remaining dimensions correspond to the input data
Example:
In the illustration below:
- Letters indicate data entries that are indexed in the input (and therefore appear in the output)
- '-' indicates entries where the actual values are not relevant (in the input).
- '*' indicates filler values in :class:`RaggedBatch` instances.
.. image:: images/BatchedInverseIndexing_ragged.png
:alt: Illustration of the batched inverse indexing operation
:align: center
Each depicted entry in the data may represent a single value (in case of
2D tensors), or itself be a non-scalar entry (in case that the data has more than 2 dimensions).
Note that for ``output_indices``, the entries are always scalar.
In this case, the ``dim_to_index_in`` is 1.
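
        A minimal reference sketch of these semantics in plain PyTorch (illustration
        only; the actual operation runs as a single batched CUDA kernel, and the
        tensors below stand in for ``output_indices.tensor`` and
        ``output_indices.sample_sizes``):

        .. code-block:: python

            import torch

            # Two samples; 3 and 2 valid (unique) indices per sample, padded.
            values = torch.tensor([[1., 2., 3.], [4., 5., 0.]])
            indices = torch.tensor([[2, 0, 3], [1, 2, 0]])
            sample_sizes = torch.tensor([3, 2])
            output_num_targets = 4
            filler_value = -1.0

            out = torch.full((2, output_num_targets), filler_value)
            for i in range(2):
                n = int(sample_sizes[i])
                out[i, indices[i, :n]] = values[i, :n]
            # out == [[2., -1., 1., 3.], [-1., 4., 5., -1.]]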
"""
is_input_ragged_batch = isinstance(input_data, RaggedBatch)
if is_input_ragged_batch:
input_data = input_data.tensor
if dim_to_index_in is None:
dim_to_index_in = output_indices.non_uniform_dim
transpose_needed = False
else:
assert (
dim_to_index_in >= output_indices.num_batch_dims
), "Cannot index in a batch dimension of the output indices"
transpose_needed = output_indices.num_batch_dims != dim_to_index_in
if transpose_needed:
input_data = input_data.transpose(output_indices.num_batch_dims, dim_to_index_in)
res = BatchedInverseIndexingAccessNewTensor.apply(
input_data, output_indices.tensor, output_indices.sample_sizes, output_num_targets, filler_value
)
if transpose_needed:
res = res.transpose(output_indices.num_batch_dims, dim_to_index_in)
return res
def batched_indexing_write(
to_write: Union[RaggedBatch, torch.Tensor],
output_indices: RaggedBatch,
to_write_into: Union[RaggedBatch, torch.Tensor],
dim_to_index_in: Optional[int] = None,
) -> Union[RaggedBatch, torch.Tensor]:
"""Batched indexing write, i.e. writing data into the indexed location, with non-uniform indices.
Non-uniform indices means that for each sample, the indices, as well as the number of indices, vary.
:gpu:
Note:
This function is similar to :func:`batched_inverse_indexing_access`, but instead of creating a
constant tensor and filling the values in there, a ``to_write_into`` tensor is used, which
may already contain values, and only the values corresponding to the indices are updated.
    Note:
        Whether ``to_write`` and ``to_write_into`` are :class:`RaggedBatch` or :class:`torch.Tensor` instances, the indexing
        is performed along ``dim_to_index_in``, which is not necessarily the non-uniform dimension
        of ``to_write`` or ``to_write_into``.
Warning:
This function assumes that for each sample, there are no duplicate indices in
``output_indices``, i.e. there are no duplicates in the valid entries in:
``output_indices[i, 0:output_indices.sample_sizes[i]]``.
If this is not the case, the behavior is undefined.
Args:
        to_write: Data to write at the given indices.
output_indices: For each sample (element along the batch dimension), the indices of entries to write to in the output.
Shape: ``(batch_size, max_num_indices)``
Here, ``max_num_indices`` corresponds to the maximum number of indices over all samples.
to_write_into: Tensor or RaggedBatch to write into.
dim_to_index_in: Dimension on which to apply the indexing. Optional, default is the non-uniform
dimension of the output indices.
Returns:
Resulting tensor or :class:`RaggedBatch` instance. Corresponds to ``to_write_into``, with the values from
``to_write`` inserted at the corresponding indices, and the original values from ``to_write_into``
everywhere else.
Example:
In the illustration below:
- Letters indicate data entries that are indexed in the input (and therefore appear in the output)
- '-' indicates entries where the actual values are not relevant (in the input).
- '*' indicates filler values in :class:`RaggedBatch` instances.
- '..' indicates data which remains unchanged, i.e. is the same as in the ``to_write_into`` parameter
and the output.
.. image:: images/BatchedIndexWrite_ragged.png
:alt: Illustration of the batched indexing write operation
:align: center
Each depicted entry in the data may represent a single value (in case of
2D tensors), or itself be a non-scalar entry (in case that the data has more than 2 dimensions).
Note that for ``output_indices``, the entries are always scalar.
In this case, the ``dim_to_index_in`` is 1.
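
        A minimal reference sketch of these semantics in plain PyTorch (illustration
        only; the tensors below stand in for ``output_indices.tensor`` and
        ``output_indices.sample_sizes``):

        .. code-block:: python

            import torch

            # Two samples; 2 and 1 valid (unique) indices per sample, padded.
            to_write = torch.tensor([[1., 2.], [3., 0.]])
            indices = torch.tensor([[3, 0], [2, 0]])
            sample_sizes = torch.tensor([2, 1])
            to_write_into = torch.tensor([[9., 9., 9., 9.],
                                          [8., 8., 8., 8.]])

            out = to_write_into.clone()
            for i in range(2):
                n = int(sample_sizes[i])
                out[i, indices[i, :n]] = to_write[i, :n]
            # out == [[2., 9., 9., 1.], [8., 8., 3., 8.]]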
"""
is_input_ragged_batch = isinstance(to_write, RaggedBatch)
is_output_ragged_batch = isinstance(to_write_into, RaggedBatch)
if dim_to_index_in is None:
dim_to_index_in = output_indices.non_uniform_dim
transpose_needed = False
else:
assert (
dim_to_index_in >= output_indices.num_batch_dims
), "Cannot index in a batch dimension of the output indices"
transpose_needed = output_indices.num_batch_dims != dim_to_index_in
if is_input_ragged_batch:
to_write = to_write.tensor
if is_output_ragged_batch:
to_write_into_data = to_write_into.tensor
else:
to_write_into_data = to_write_into
    if transpose_needed:
        to_write = to_write.transpose(output_indices.num_batch_dims, dim_to_index_in)
        to_write_into_data = to_write_into_data.transpose(output_indices.num_batch_dims, dim_to_index_in)
res = BatchedInverseIndexingAccessInsert.apply(
to_write, output_indices.tensor, output_indices.sample_sizes, to_write_into_data
)
if transpose_needed:
res = res.transpose(output_indices.num_batch_dims, dim_to_index_in)
if is_output_ragged_batch:
res = to_write_into.create_with_sample_sizes_like_self(res)
return res