Source code for accvlab.draw_heatmap.funtions.draw_heatmap_batched

# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from accvlab.batching_helpers import RaggedBatch
from accvlab.draw_heatmap.draw_heatmap_ext import draw_heatmap_batched_classwise_impl
from accvlab.draw_heatmap.draw_heatmap_ext import draw_heatmap_batched_impl


class _InputMismatchError(ValueError, AssertionError):
    '''
    Raised when the inputs passed to :func:`draw_heatmap_batched` are inconsistent.

    Derives from ``ValueError`` (the conventional type for invalid arguments) and from
    ``AssertionError`` (the type raised by the ``assert`` statements previously used for
    these checks), so existing ``except AssertionError`` handlers keep working.
    '''


def _check(condition: bool, message: str) -> None:
    '''Raise ``_InputMismatchError(message)`` if ``condition`` is false.'''
    # An explicit raise instead of ``assert``: asserts are removed under ``python -O``.
    if not condition:
        raise _InputMismatchError(message)


def draw_heatmap_batched(
    heatmap: torch.Tensor,
    centers: RaggedBatch,
    radii: RaggedBatch,
    diameter_to_sigma_factor: float = 6.0,
    k_scale: float = 1.0,
    labels: Optional[RaggedBatch] = None,
) -> None:
    '''
    Draws heatmaps for a batch of samples.

    :gpu:

    Args:
        heatmap: Tensor of shape (batch_size, height, width) when labels is None.
            Otherwise with shape (batch_size, max_num_classes, height, width).
            The heatmap will be modified in place.
        centers: RaggedBatch of shape (batch_size, max_num_targets, 2). The centers of the
            heatmaps to draw. `max_num_targets` is the maximum number of targets across the batch.
            The number of valid targets per sample is taken from ``centers.sample_sizes``.
        radii: RaggedBatch of shape (batch_size, max_num_targets). The radii of the heatmaps to
            draw. `max_num_targets` is the maximum number of targets across the batch.
        diameter_to_sigma_factor: Factor for converting diameter to sigma.
        k_scale: Scale factor for the Gaussian kernel.
        labels: RaggedBatch of shape (batch_size, max_num_targets). The labels are denoted as the
            class index. `max_num_targets` is the maximum number of targets across the batch.
            If None, all classes of the sample will be drawn in one heatmap.

    Raises:
        ValueError: If the rank of ``heatmap`` does not match the documented layout, or if the
            batch size / maximum number of targets of ``heatmap``, ``radii`` or ``labels`` does
            not match ``centers``. The raised exception is also an ``AssertionError`` for
            backward compatibility.

    Note:
        Only tensor shapes are validated here. The per-sample sample sizes of ``radii`` and
        ``labels`` are not compared against ``centers``, because that comparison may require a
        device-to-host synchronization when they live on the GPU.
    '''
    centers_tensor = centers.tensor
    radii_tensor = radii.tensor
    batch_size = centers_tensor.shape[0]
    max_num_targets = centers_tensor.shape[1]

    _check(radii_tensor.shape[0] == batch_size, "centers and radii must have the same batch size")
    _check(
        radii_tensor.shape[1] == max_num_targets,
        "centers and radii must have the same maximum number of objects",
    )

    # The class-wise variant writes into a per-class channel, hence the extra dimension.
    expected_heatmap_dim = 3 if labels is None else 4
    _check(
        heatmap.dim() == expected_heatmap_dim,
        f"heatmap must be {expected_heatmap_dim}-dimensional "
        f"({'without' if labels is None else 'with'} labels), got {heatmap.dim()} dimensions",
    )
    _check(heatmap.shape[0] == batch_size, "heatmap and centers must have the same batch size")

    # TODO: This conversion can be replaced by type dispatching in the C++ implementation
    nums_targets = centers.sample_sizes.to(torch.int32)

    if labels is None:
        draw_heatmap_batched_impl(
            heatmap, centers_tensor, radii_tensor, nums_targets, diameter_to_sigma_factor, k_scale
        )
        return

    labels_tensor = labels.tensor
    _check(labels_tensor.shape[0] == batch_size, "centers and labels must have the same batch size")
    _check(
        labels_tensor.shape[1] == max_num_targets,
        "centers and labels must have the same maximum number of objects",
    )
    draw_heatmap_batched_classwise_impl(
        heatmap,
        centers_tensor,
        radii_tensor,
        nums_targets,
        labels_tensor,
        diameter_to_sigma_factor,
        k_scale,
    )