# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import io
import warnings
from collections.abc import Iterator
from fractions import Fraction
from typing import BinaryIO, Literal, Optional, Sequence, Union, overload
import numpy as np
import torch
# Try importing optional dependencies
try:
import av
import av.container
import filetype
from megatron.energon.av.fastseek import Fastseek
AV_DECODE_AVAILABLE = True
except ImportError as e:
AV_DECODE_AVAILABLE = False
MISSING_DEPENDENCY = str(e)
from megatron.energon.av.av_data import AVData
class AVDecoder:
"""A class that provides a flexible interface for decoding audio and video data.
This class allows users to control decoding parameters at runtime rather than having them fixed
during initialization. It's particularly useful for cases where different samples may need different
decoding parameters.
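Example (a minimal sketch; assumes ``media_bytes`` holds the raw bytes of a complete media file):
>>> decoder = AVDecoder(io.BytesIO(media_bytes))
>>> av_data = decoder.get_clips(
... video_clip_ranges=[(0.0, 2.0)],
... video_unit="seconds",
... )
>>> av_data.video_clips[0].shape # (frames, channels, height, width)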
"""
seeker: "Fastseek"
stream: BinaryIO
suppress_warnings: bool
def __init__(self, stream: BinaryIO, suppress_warnings: bool = False) -> None:
if not AV_DECODE_AVAILABLE:
raise ImportError(
f"AV decoding is not available. Please install the required dependencies with:\n"
f"pip install megatron-energon[av_decode]\n"
f"Missing dependency: {MISSING_DEPENDENCY}. Install megatron-energon[av_decode] to use AVDecoder."
)
self.stream = stream
self.suppress_warnings = suppress_warnings
try:
self.seeker = Fastseek(self.stream)
except ValueError:
self.stream.seek(0)
self.seeker = Fastseek(self.stream, probe=True)
self.stream.seek(0)
def get_video(self) -> AVData:
"""Get the entire video data from the stream (without audio)."""
video_clips, video_timestamps = self.get_video_clips(video_clip_ranges=[(0, float("inf"))])
return AVData(
video_clips=video_clips,
video_timestamps=video_timestamps,
audio_clips=[],
audio_timestamps=[],
)
def get_video_clips(
self,
video_clip_ranges: Sequence[tuple[float, float]],
video_unit: Literal["frames", "seconds"] = "seconds",
video_out_frame_size: Optional[tuple[int, int]] = None,
) -> tuple[list[torch.Tensor], list[tuple[float, float]]]:
"""Get video clips from the video stream.
Args:
video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit)
video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp)
video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size
Returns:
A tuple containing:
- video_clips: List of video clips
- video_clips_timestamps: List of timestamps for each video clip start and end in seconds
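Example (a sketch; assumes ``decoder`` is an AVDecoder wrapping a video container):
>>> clips, timestamps = decoder.get_video_clips(
... video_clip_ranges=[(0, 30), (100, 130)],
... video_unit="frames",
... video_out_frame_size=(224, 224),
... )
>>> clips[0].shape # (frames, channels, height, width)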
"""
assert video_unit in ("frames", "seconds")
self.stream.seek(0) # Reset the video stream so that pyav can read the entire container
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
assert len(input_container.streams.video) > 0, (
"No video stream found, but video_clips are requested"
)
video_stream = input_container.streams.video[0]
# Pre-calculate timing info for video
average_rate: Fraction = video_stream.average_rate # Frames per second
assert average_rate, "Video stream has no FPS."
time_base: Fraction = video_stream.time_base # Seconds per PTS unit
average_frame_duration: int = int(1 / average_rate / time_base) # PTS units per frame
if video_clip_ranges is not None:
# Convert video_clip_ranges to seeker unit
if video_unit == "frames" and self.seeker.unit == "pts":
# Convert from frames to pts units
video_clip_ranges = [
(
clip[0] / average_rate / time_base,
clip[1] / average_rate / time_base,
)
for clip in video_clip_ranges
]
if not self.suppress_warnings:
warnings.warn(
"Video container unit is frames, but seeking in time units. The resulting frames may be slightly off.",
RuntimeWarning,
)
elif video_unit == "seconds" and self.seeker.unit == "frames":
# Convert from seconds to frames
video_clip_ranges = [
(
clip[0] * average_rate,
clip[1] * average_rate,
)
for clip in video_clip_ranges
]
if not self.suppress_warnings:
warnings.warn(
"Video container unit is time units, but seeking using frame number. The resulting frames may be slightly off.",
RuntimeWarning,
)
elif video_unit == "seconds" and self.seeker.unit == "pts":
# Convert from seconds to pts units
video_clip_ranges = [
(clip[0] / time_base, clip[1] / time_base) for clip in video_clip_ranges
]
frame_iterator: Iterator[av.VideoFrame] = input_container.decode(video=0)
previous_frame_index: int = 0
video_clips_frames: list[list[torch.Tensor]] = []
video_clips_timestamps: list[tuple[float, float]] = []
for video_clip_range in video_clip_ranges:
start_frame_index, end_frame_index = video_clip_range
# Convert to int if possible, set end to None if infinite
start_frame_index = int(start_frame_index)
end_frame_index = int(end_frame_index) if end_frame_index != float("inf") else None
clip_frames: list[torch.Tensor] = []
clip_timestamp_start = None
clip_timestamp_end = None
# Find start frame
if (
iframe_info := self.seeker.should_seek(previous_frame_index, start_frame_index)
) is not None:
input_container.seek(iframe_info.pts, stream=input_container.streams.video[0])
previous_frame_index = iframe_info.index
for frame in frame_iterator:
take_frame = False
last_frame = False
# Container uses frame counts: we can find the exact target frame by counting from the iframe, which is at a known offset
if self.seeker.unit == "frames":
if previous_frame_index >= start_frame_index:
take_frame = True
if end_frame_index is not None and previous_frame_index >= end_frame_index:
last_frame = True
# Container uses time: the target frame might not correspond exactly to any metadata, but the desired
# timestamp should fall within a frame's display period
if self.seeker.unit == "pts":
if start_frame_index <= frame.pts + average_frame_duration:
take_frame = True
if (
end_frame_index is not None
and end_frame_index <= frame.pts + average_frame_duration
):
last_frame = True
if take_frame:
if video_out_frame_size is not None:
frame = frame.reformat(
width=video_out_frame_size[0],
height=video_out_frame_size[1],
format="rgb24",
interpolation="BILINEAR",
)
else:
frame = frame.reformat(format="rgb24")
clip_frames.append(torch.from_numpy(frame.to_ndarray()))
if clip_timestamp_start is None:
clip_timestamp_start = float(frame.pts * frame.time_base)
clip_timestamp_end = float(
(frame.pts + average_frame_duration) * frame.time_base
)
previous_frame_index += 1
if last_frame:
break
if clip_timestamp_start is not None and clip_timestamp_end is not None:
video_clips_frames.append(clip_frames)
video_clips_timestamps.append((clip_timestamp_start, clip_timestamp_end))
# Stack frames within each clip
out_video_clips = [
torch.stack(clip_frames).permute((0, 3, 1, 2)) for clip_frames in video_clips_frames
]
return out_video_clips, video_clips_timestamps
def get_audio(self) -> AVData:
"""Get the entire audio data from the stream."""
audio_clips, audio_timestamps = self.get_audio_clips(audio_clip_ranges=[(0, float("inf"))])
return AVData(
video_clips=[],
video_timestamps=[],
audio_clips=audio_clips,
audio_timestamps=audio_timestamps,
)
def get_audio_clips(
self,
audio_clip_ranges: Sequence[tuple[float, float]],
audio_unit: Literal["samples", "seconds"] = "seconds",
) -> tuple[list[torch.Tensor], list[tuple[float, float]]]:
"""Get audio clips from the audio stream.
Args:
audio_clip_ranges: List of audio clip start and end positions in the given unit (see audio_unit)
audio_unit: Unit of the audio clip positions ("samples" for sample number, "seconds" for timestamp)
Returns:
A tuple containing:
- audio_clips: List of audio clips
- audio_clips_timestamps: List of timestamps for each audio clip start and end in seconds
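Example (a sketch; assumes ``decoder`` wraps a file with an audio stream):
>>> clips, timestamps = decoder.get_audio_clips(
... audio_clip_ranges=[(0.0, 2.0)],
... audio_unit="seconds",
... )
>>> clips[0].shape # (channels, samples)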
"""
assert audio_unit in ("samples", "seconds")
self.stream.seek(0) # Reset the stream so that pyav can read the entire container
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
assert len(input_container.streams.audio) > 0, (
"No audio stream found, but audio_clips are requested"
)
audio_stream = input_container.streams.audio[0]
audio_sample_rate = audio_stream.sample_rate
assert audio_sample_rate, "Audio streams without sample rate are not supported"
if audio_unit == "samples":
# Convert from samples to seconds
audio_clip_ranges = [
(
float(clip[0] / audio_sample_rate),
float(clip[1] / audio_sample_rate),
)
for clip in audio_clip_ranges
]
out_audio_clips: list[torch.Tensor] = []
out_audio_clips_timestamps: list[tuple[float, float]] = []
def audio_frame_array(frame: av.AudioFrame) -> np.ndarray:
if frame.format.is_planar:
arr_processed = frame.to_ndarray() # Already (channels, samples)
else:
# Calculate the number of channels and samples
channels = int(frame.layout.nb_channels)
samples = int(frame.samples)
# Reshape the interleaved data to (samples, channels), then transpose to (channels, samples)
arr_processed = np.reshape(frame.to_ndarray(), (samples, channels)).transpose(
1, 0
)
return arr_processed
for start_time, end_time in audio_clip_ranges:
# Seek near start time, but rounded down to the nearest frame
input_container.seek(int(start_time * av.time_base))
if end_time != float("inf"):
desired_duration = end_time - start_time
desired_sample_count = int(desired_duration * audio_sample_rate + 0.5)
else:
desired_sample_count = None
clip_start_time = None
clip_end_time = None
decoded_samples = []
decoded_sample_count = 0
previous_frame = None
for frame in input_container.decode(audio=0):
assert frame.pts is not None, "Audio frame has no PTS timestamp"
cur_frame_time = float(frame.pts * frame.time_base)
cur_frame_duration = float(frame.samples / audio_sample_rate)
if cur_frame_time < start_time:
# Skip frames before the start time
previous_frame = frame
continue
if clip_start_time is None:
# This is our first matching frame
if previous_frame is not None:
# We have a previous frame that we need to crop to the start time
prev_start_time = float(previous_frame.pts * previous_frame.time_base)
prev_frame_array = audio_frame_array(previous_frame)
prev_frame_array = prev_frame_array[
:, int((start_time - prev_start_time) * audio_sample_rate + 0.5) :
]
decoded_samples.append(prev_frame_array)
decoded_sample_count += prev_frame_array.shape[1]
clip_start_time = start_time
clip_end_time = prev_start_time + cur_frame_duration
else:
clip_start_time = cur_frame_time
# Stop decoding if the end of the frame is past the end time
if cur_frame_time + cur_frame_duration >= end_time:
# Crop the last frame to the end time
last_frame_array = audio_frame_array(frame)
additional_samples = int(
(end_time - cur_frame_time) * audio_sample_rate + 0.5
)
projected_total_samples = decoded_sample_count + additional_samples
if (
desired_sample_count is not None
and 0 < abs(projected_total_samples - desired_sample_count) < 2
):
# We are within 2 samples of the desired duration, let's adjust
# the last frame so that we get the desired duration
additional_samples = desired_sample_count - decoded_sample_count
last_frame_array = last_frame_array[:, :additional_samples]
decoded_samples.append(last_frame_array)
decoded_sample_count += last_frame_array.shape[1]
clip_end_time = end_time
break
frame_nd = audio_frame_array(frame) # (channels, samples)
decoded_samples.append(frame_nd)
decoded_sample_count += frame_nd.shape[1]
clip_end_time = cur_frame_time + cur_frame_duration
if decoded_samples:
# Combine all channels/samples along samples axis
clip_all = np.concatenate(decoded_samples, axis=-1) # (channels, total_samples)
if clip_start_time is not None and clip_end_time is not None:
out_audio_clips.append(torch.from_numpy(clip_all))
out_audio_clips_timestamps.append((clip_start_time, clip_end_time))
return out_audio_clips, out_audio_clips_timestamps
def get_video_with_audio(self) -> AVData:
"""Get the entire video and audio data from the stream."""
return self.get_clips(
video_clip_ranges=[(0, float("inf"))],
audio_clip_ranges=[(0, float("inf"))],
video_unit="seconds",
audio_unit="seconds",
)
def get_clips(
self,
video_clip_ranges: Optional[Sequence[tuple[float, float]]] = None,
audio_clip_ranges: Optional[Sequence[tuple[float, float]]] = None,
video_unit: Literal["frames", "seconds"] = "seconds",
audio_unit: Literal["samples", "seconds"] = "seconds",
video_out_frame_size: Optional[tuple[int, int]] = None,
) -> AVData:
"""Get clips from the video and/or audio streams.
Given a list of (start, end) tuples, this method will decode the video and/or audio clips
at the specified start and end times. The units of the start and end times are specified by
the `video_unit` and `audio_unit` arguments.
Args:
video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit)
audio_clip_ranges: List of audio clip start and end positions in the given unit (see audio_unit)
video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp)
audio_unit: Unit of the audio clip positions ("samples" for sample number, "seconds" for timestamp)
video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size
Returns:
AVData containing the decoded video and audio clips
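Example (a sketch; assumes ``decoder`` wraps a container with both video and audio streams):
>>> av_data = decoder.get_clips(
... video_clip_ranges=[(10, 20)],
... audio_clip_ranges=[(0.5, 1.5)],
... video_unit="frames",
... audio_unit="seconds",
... )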
"""
if video_clip_ranges is not None:
ret_video_clips, ret_video_clips_timestamps = self.get_video_clips(
video_clip_ranges, video_unit, video_out_frame_size
)
else:
ret_video_clips = []
ret_video_clips_timestamps = []
if audio_clip_ranges is not None:
ret_audio_clips, ret_audio_clips_timestamps = self.get_audio_clips(
audio_clip_ranges, audio_unit
)
else:
ret_audio_clips = []
ret_audio_clips_timestamps = []
return AVData(
video_clips=ret_video_clips,
video_timestamps=ret_video_clips_timestamps,
audio_clips=ret_audio_clips,
audio_timestamps=ret_audio_clips_timestamps,
)
def get_frames(
self,
video_decode_audio: bool = False,
) -> Optional[AVData]:
"""Decode the audio/video data with the specified parameters.
Args:
video_decode_audio: Whether to decode audio from video
Returns:
AVData containing the decoded frames and metadata, or None if the file type is not supported
The video tensor is in the shape (frames, channels, height, width)
The audio tensor is in the shape (channels, samples)
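Example (a sketch; assumes ``decoder`` wraps an mp4 file):
>>> av_data = decoder.get_frames(video_decode_audio=True)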
"""
extension = self._get_extension()
if extension in ("mov", "mp4", "webm", "mkv"):
if video_decode_audio:
return self.get_video_with_audio()
else:
return self.get_video()
elif extension in ("flac", "mp3", "wav"):
return self.get_audio()
else:
return None
def _get_extension(self) -> Optional[str]:
"""Get the file extension from the raw data."""
# Try to guess the file type using the first few bytes
self.stream.seek(0) # Reset stream position before guessing
ftype = filetype.guess(self.stream)
if ftype is None:
return None
return ftype.extension
def get_video_fps(self) -> float:
"""Get the FPS of the video stream."""
self.stream.seek(0)
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
video_stream = input_container.streams.video[0]
assert video_stream.average_rate is not None
return float(video_stream.average_rate)
def get_audio_samples_per_second(self) -> int:
"""Get the number of samples per second of the audio stream."""
self.stream.seek(0)
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
audio_stream = input_container.streams.audio[0]
assert audio_stream.sample_rate is not None
return int(audio_stream.sample_rate)
def has_audio_stream(self) -> bool:
"""Check if the stream has an audio stream."""
self.stream.seek(0)
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
return len(input_container.streams.audio) > 0
def has_video_stream(self) -> bool:
"""Check if the stream has a video stream."""
self.stream.seek(0)
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
return len(input_container.streams.video) > 0
def get_audio_duration(self) -> Optional[float]:
"""Get the duration of the audio stream.
Returns:
The duration of the audio stream in seconds
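Example (a sketch; assumes ``decoder`` wraps a file with an audio stream):
>>> decoder.get_audio_duration() # duration in seconds, or None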
"""
self.stream.seek(0)
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
if input_container.streams.audio:
audio_time_base = input_container.streams.audio[0].time_base
audio_start_pts = input_container.streams.audio[0].start_time
if audio_start_pts is None:
audio_start_pts = 0.0
audio_duration = input_container.streams.audio[0].duration
if audio_time_base is not None and audio_duration is not None:
duration = int(audio_duration - audio_start_pts) * audio_time_base
return float(duration)
return None
@overload
def get_video_duration(self, get_frame_count: Literal[True]) -> tuple[Optional[float], int]: ...
@overload
def get_video_duration(
self, get_frame_count: bool = False
) -> tuple[Optional[float], Optional[int]]: ...
def get_video_duration(
self, get_frame_count: bool = False
) -> tuple[Optional[float], Optional[int]]:
"""Get the duration of the video stream.
Args:
get_frame_count: Whether to return the number of frames in the video. This is a more costly operation.
Returns:
A tuple containing the duration in seconds, and the number of frames in the video
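Example (a sketch; assumes ``decoder`` wraps a video file):
>>> duration, num_frames = decoder.get_video_duration(get_frame_count=True)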
"""
video_duration = None
num_frames = None
duration = None
self.stream.seek(0) # Reset the video stream so that pyav can read the entire container
with av.open(self.stream) as input_container:
initialize_av_container(input_container)
if input_container.streams.video:
video_stream = input_container.streams.video[0]
assert video_stream.time_base is not None
video_start_pts = video_stream.start_time
if video_start_pts is None:
video_start_pts = 0.0
video_duration = video_stream.duration
if video_duration is None:
# If duration isn't found in header the whole video is decoded to
# determine the duration.
num_frames = 0
last_packet = None
for packet in input_container.demux(video=0):
if packet.pts is not None:
num_frames += 1
last_packet = packet
if last_packet is not None and last_packet.duration is not None:
assert last_packet.pts is not None
video_duration = last_packet.pts + last_packet.duration
if video_duration is not None and video_stream.time_base is not None:
duration = int(video_duration - video_start_pts) * video_stream.time_base
if get_frame_count and num_frames is None:
num_frames = sum(1 for p in input_container.demux(video=0) if p.pts is not None)
return float(duration) if duration is not None else None, num_frames
def __repr__(self):
return f"AVDecoder(stream={self.stream!r})"
class AVWebdatasetDecoder:
"""A decoder class for audio and video data that provides a consistent interface for decoding media files.
This class encapsulates the decoding parameters and provides a callable interface that can be used
with webdataset or other data loading pipelines. It supports both video and audio decoding with
configurable parameters for frame extraction, resizing, and audio clip extraction.
Args:
video_decode_audio: Whether to decode audio from video files. If True, audio will be
extracted alongside video frames.
av_decode: If "AVDecoder", returns an AVDecoder instance for flexible decoding. If "torch",
returns decoded VideoData.
Example:
>>> decoder = AVWebdatasetDecoder(
... video_decode_audio=True,
... av_decode="AVDecoder"
... )
>>> result = decoder("video.mp4", video_bytes)
"""
def __init__(
self,
video_decode_audio: bool,
av_decode: Literal["torch", "AVDecoder", "pyav"] = "AVDecoder",
) -> None:
self.video_decode_audio = video_decode_audio
self.av_decode = av_decode
def read_av_data(self, key: str, data: bytes) -> AVDecoder:
"""Decoder function that returns an AVData object for flexible decoding.
Args:
key: The file extension or key
data: The raw bytes of the media file
Returns:
AVDecoder object that can be used to decode the media with custom parameters
"""
return AVDecoder(io.BytesIO(data))
def __call__(
self, key: str, data: bytes
) -> Optional[
Union[AVData, AVDecoder, "av.container.InputContainer", "av.container.OutputContainer"]
]:
"""
Extract the video or audio data from default media extensions.
Args:
key: media file extension
data: raw media bytes
Returns:
If av_decode is "torch", returns VideoData containing the decoded frames and metadata.
If av_decode is "AVDecoder", returns an AVDecoder instance for flexible decoding.
If av_decode is "pyav", returns an av.container.InputContainer or av.container.OutputContainer instance.
Returns None if decoding failed or file type is not supported.
"""
if not any(
key == ext or key.endswith("." + ext)
for ext in ("mp4", "avi", "mov", "webm", "mkv", "flac", "mp3", "wav")
):
return None
av_decoder = self.read_av_data(key, data)
if self.av_decode == "AVDecoder":
return av_decoder
elif self.av_decode == "pyav":
input_container = av.open(av_decoder.stream)
initialize_av_container(input_container)
return input_container
elif self.av_decode == "torch":
return av_decoder.get_frames(
video_decode_audio=self.video_decode_audio,
)
else:
raise ValueError(f"Invalid av_decode value: {self.av_decode}")
def initialize_av_container(input_container: "av.container.InputContainer") -> None:
"""Every PyAV container should be initialized with this function.
This function ensures that no additional threads are created.
This is to avoid deadlocks in ffmpeg when deallocating the container.
Furthermore, we cannot have multiple threads before forking the process when
using torch data loaders with multiple workers.
"""
for stream in input_container.streams:
cc = stream.codec_context
if cc is not None:
cc.thread_type = "NONE"
cc.thread_count = 0