|
|
|
|
|
|
|
import numpy as np
|
|
from decord.video_reader import VideoReader
|
|
from decord.audio_reader import AudioReader
|
|
|
|
from decord.ndarray import cpu
|
|
from decord import ndarray as _nd
|
|
from decord.bridge import bridge_out
|
|
|
|
|
|
class AVReader(object):
    """Individual audio video reader with convenient indexing function.

    Each video frame is returned together with the audio samples whose
    timestamps fall inside that frame's time interval.

    Parameters
    ----------
    uri : str or file-like
        Path of file. An object exposing ``read`` is also accepted; it is
        rewound with ``seek(0)`` before the video reader is constructed.
    ctx : decord.Context
        The context to decode the file, can be decord.cpu() or decord.gpu().
    sample_rate : int, default is 44100
        Desired output sample rate of the audio, unchanged if `-1` is specified.
    mono : bool, default is True
        Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
    width : int, default is -1
        Desired output width of the video, unchanged if `-1` is specified.
    height : int, default is -1
        Desired output height of the video, unchanged if `-1` is specified.
    num_threads : int, default is 0
        Number of decoding threads, auto if `0` is specified.
    fault_tol : int or float, default is -1
        The threshold of corrupted and recovered frames. This is to prevent silent fault
        tolerance when for example 50% frames of a video cannot be decoded and duplicate
        frames are returned. You may find the fault tolerant feature sweet in many cases,
        but not for training models. Say `N = # recovered frames`
        If `fault_tol` < 0, nothing will happen.
        If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
        If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
    """

    def __init__(
        self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
    ):
        self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
        # NOTE(review): add_padding() is defined in AudioReader; presumably it pads
        # the audio so it lines up with the video timeline -- confirm there.
        self.__audio_reader.add_padding()
        if hasattr(uri, "read"):
            # The same file-like object feeds both readers; rewind it so the
            # video reader starts from the beginning of the stream.
            uri.seek(0)
        self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
        self.__video_reader.seek(0)

    def __len__(self):
        """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
        we always follow what FFMPEG reports.

        Returns
        -------
        int
            The number of frames in the video file.
        """
        return len(self.__video_reader)

    def __getitem__(self, idx):
        """Get audio samples and video frame at `idx`.

        Parameters
        ----------
        idx : int or slice
            The frame index, can be negative which means it will index backwards,
            or slice of frame indices.

        Returns
        -------
        (ndarray/list of ndarray, ndarray)
            First element is samples of shape CxS or a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.

            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.

        Raises
        ------
        IndexError
            If an integer `idx` lies outside ``[-len(video), len(video))``.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        num_frames = len(self.__video_reader)
        if isinstance(idx, slice):
            # Slices are delegated to get_batch, which returns a list of audio chunks.
            return self.get_batch(range(*idx.indices(num_frames)))
        requested_idx = idx
        if idx < 0:
            idx += num_frames
        if idx >= num_frames or idx < 0:
            # Report the index the caller asked for, not the wrapped value.
            raise IndexError("Index: {} out of bound: {}".format(requested_idx, num_frames))
        audio_start_idx, audio_end_idx = self._frame_audio_range(idx)
        results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
        # Reset the video decoder position after every read (as done in __init__).
        self.__video_reader.seek(0)
        return results

    def get_batch(self, indices):
        """Get entire batch of audio samples and video frames.

        Parameters
        ----------
        indices : list of integers
            A list of frame indices. If negative indices detected, the indices will be indexed from backward

        Returns
        -------
        (list of ndarray, ndarray)
            First element is a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.

            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.

        Raises
        ------
        IndexError
            If any index is out of range (see ``_validate_indices``).
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        indices = self._validate_indices(indices)
        audio_arr = []
        prev_video_idx = None
        prev_audio_end_idx = None
        for idx in indices:
            frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
            # For consecutive frames, start exactly where the previous audio chunk
            # ended so the chunks tile contiguously. Compare against None rather
            # than relying on truthiness: frame index 0 is a valid predecessor.
            if prev_video_idx is not None and idx == prev_video_idx + 1:
                audio_start_idx = prev_audio_end_idx
            else:
                audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
            audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
            audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
            prev_video_idx = idx
            prev_audio_end_idx = audio_end_idx
        results = (audio_arr, self.__video_reader.get_batch(indices))
        self.__video_reader.seek(0)
        return results

    def _get_slice(self, sl):
        """Return all audio for the frames in `sl` as one array plus the video batch.

        The per-frame audio chunks are joined along axis 1 (the sample axis).
        """
        num_channels = self.__audio_reader.shape()[0]
        # Seed with an empty (C, 0) float32 array so an empty `sl` still yields
        # a well-formed result; collect chunks and concatenate once (linear,
        # instead of re-copying the growing array on every iteration).
        chunks = [np.empty(shape=(num_channels, 0), dtype="float32")]
        for idx in sl:
            audio_start_idx, audio_end_idx = self._frame_audio_range(idx)
            chunks.append(self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy())
        audio_arr = np.concatenate(chunks, axis=1)
        results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
        self.__video_reader.seek(0)
        return results

    def _frame_audio_range(self, idx):
        """Return the ``(start, end)`` audio sample indices covering video frame `idx`."""
        frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
        return (
            self.__audio_reader._time_to_sample(frame_start_time),
            self.__audio_reader._time_to_sample(frame_end_time),
        )

    def _validate_indices(self, indices):
        """Validate int64 integers and convert negative integers to positive by backward search.

        Returns
        -------
        numpy.ndarray of int64
            A copy of `indices` with negative entries wrapped into ``[0, len(video))``.

        Raises
        ------
        IndexError
            If any index is below ``-len(video)`` or at/above ``len(video)``.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        num_frames = len(self.__video_reader)
        indices = np.array(indices, dtype=np.int64)
        indices[indices < 0] += num_frames
        if not (indices >= 0).all():
            # Entries still negative were below -num_frames; undo the wrap so the
            # message shows the values the caller actually passed.
            raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] - num_frames))
        if not (indices < num_frames).all():
            raise IndexError("Out of bound indices: {}".format(indices[indices >= num_frames]))
        return indices
|
|
|