File size: 7,543 Bytes
aa0c2cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# This is a modified copy of decord's AVReader class that works around a memory leak.
# For more details, refer to: https://github.com/dmlc/decord/issues/208
import numpy as np
from decord.video_reader import VideoReader
from decord.audio_reader import AudioReader
from decord.ndarray import cpu
from decord import ndarray as _nd
from decord.bridge import bridge_out
class AVReader(object):
    """Individual audio video reader with convenient indexing function.

    Wraps one ``AudioReader`` and one ``VideoReader`` opened on the same
    source and returns, for each requested video frame, the audio samples
    that fall between that frame's start and end timestamps.

    Parameters
    ----------
    uri : str or file-like object
        Path of file. If the object has a ``read`` attribute it is rewound
        with ``seek(0)`` between opening the audio and the video reader.
    ctx : decord.Context, optional
        The context to decode the file, can be decord.cpu() or decord.gpu().
        ``decord.cpu(0)`` is used when omitted or ``None``.
    sample_rate : int, default is 44100
        Desired output sample rate of the audio, unchanged if `-1` is specified.
    mono : bool, default is True
        Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
    width : int, default is -1
        Desired output width of the video, unchanged if `-1` is specified.
    height : int, default is -1
        Desired output height of the video, unchanged if `-1` is specified.
    num_threads : int, default is 0
        Number of decoding thread, auto if `0` is specified.
    fault_tol : int, default is -1
        The threshold of corrupted and recovered frames. This is to prevent silent fault
        tolerance when for example 50% frames of a video cannot be decoded and duplicate
        frames are returned. You may find the fault tolerant feature sweet in many cases,
        but not for training models. Say `N = # recovered frames`
        If `fault_tol` < 0, nothing will happen.
        If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
        If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
    """

    def __init__(
        self, uri, ctx=None, sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
    ):
        # Resolve the default context lazily instead of calling cpu(0) once at
        # class-definition time.
        if ctx is None:
            ctx = cpu(0)
        self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
        self.__audio_reader.add_padding()
        # A file-like source has been consumed by the audio reader; rewind it
        # so the video reader starts from the beginning.
        if hasattr(uri, "read"):
            uri.seek(0)
        self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
        self.__video_reader.seek(0)

    def __len__(self):
        """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
        we always follow what FFMPEG reports.

        Returns
        -------
        int
            The number of frames in the video file.
        """
        return len(self.__video_reader)

    def __getitem__(self, idx):
        """Get audio samples and video frame at `idx`.

        Parameters
        ----------
        idx : int or slice
            The frame index, can be negative which means it will index backwards,
            or slice of frame indices.

        Returns
        -------
        (ndarray/list of ndarray, ndarray)
            First element is samples of shape CxS or a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.
            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.

        Raises
        ------
        IndexError
            If an integer `idx` is outside ``[-len(self), len(self))``.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        num_frames = len(self.__video_reader)
        if isinstance(idx, slice):
            # Slices are expanded to explicit frame indices and served by get_batch.
            return self.get_batch(range(*idx.indices(num_frames)))
        if idx < 0:
            idx += num_frames
        if idx >= num_frames or idx < 0:
            raise IndexError("Index: {} out of bound: {}".format(idx, num_frames))
        # Map the frame's start/end timestamps to audio sample positions.
        start_time, end_time = self.__video_reader.get_frame_timestamp(idx)
        audio_start_idx = self.__audio_reader._time_to_sample(start_time)
        audio_end_idx = self.__audio_reader._time_to_sample(end_time)
        results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
        # Rewind after decoding; presumably the memory-leak workaround
        # referenced in the file header (decord issue #208).
        self.__video_reader.seek(0)
        return results

    def get_batch(self, indices):
        """Get entire batch of audio samples and video frames.

        Parameters
        ----------
        indices : list of integers
            A list of frame indices. If negative indices detected, the indices will be indexed from backward

        Returns
        -------
        (list of ndarray, ndarray)
            First element is a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.
            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.

        Raises
        ------
        IndexError
            If any index is out of range (see `_validate_indices`).
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        indices = self._validate_indices(indices)
        audio_arr = []
        prev_video_idx = None
        prev_audio_end_idx = None
        for idx in indices:
            frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
            # timestamp and sample conversion could have some error that could cause non-continuous audio
            # we detect if retrieving continuous frame and make the audio continuous.
            # The explicit `is not None` matters: frame 0 is a valid previous
            # frame, and a plain truthiness test would skip it.
            if prev_video_idx is not None and idx == prev_video_idx + 1:
                audio_start_idx = prev_audio_end_idx
            else:
                audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
            audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
            audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
            prev_video_idx = idx
            prev_audio_end_idx = audio_end_idx
        results = (audio_arr, self.__video_reader.get_batch(indices))
        # Rewind after decoding; presumably the memory-leak workaround
        # referenced in the file header (decord issue #208).
        self.__video_reader.seek(0)
        return results

    def _get_slice(self, sl):
        """Get the audio of all frames in `sl` joined into one array, plus the frame batch.

        Parameters
        ----------
        sl : iterable of int
            Frame indices; passed unchanged to the video reader's `get_batch`.

        Returns
        -------
        (ndarray, ndarray)
            First element is the per-frame audio concatenated along axis 1 and
            converted with `bridge_out`. Second element is the video batch.
        """
        # Start from an empty (channels, 0) float32 chunk so an empty `sl`
        # still yields a correctly shaped array.
        num_channels = self.__audio_reader.shape()[0]
        chunks = [np.empty(shape=(num_channels, 0), dtype="float32")]
        for idx in sl:
            start_time, end_time = self.__video_reader.get_frame_timestamp(idx)
            audio_start_idx = self.__audio_reader._time_to_sample(start_time)
            audio_end_idx = self.__audio_reader._time_to_sample(end_time)
            chunks.append(self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy())
        # Concatenate once instead of re-copying the accumulated array per frame.
        audio_arr = np.concatenate(chunks, axis=1)
        results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
        self.__video_reader.seek(0)
        return results

    def _validate_indices(self, indices):
        """Validate int64 integers and convert negative integers to positive by backward search.

        Parameters
        ----------
        indices : array-like of int
            Frame indices, possibly negative.

        Returns
        -------
        numpy.ndarray
            A new int64 array in which every negative index has been shifted
            by the number of video frames.

        Raises
        ------
        IndexError
            If an index is still negative after the shift, or is greater than
            or equal to the number of video frames.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        num_frames = len(self.__video_reader)
        indices = np.array(indices, dtype=np.int64)
        # process negative indices
        indices[indices < 0] += num_frames
        still_negative = indices < 0
        if still_negative.any():
            # Undo the shift so the message shows the indices the caller passed.
            raise IndexError("Invalid negative indices: {}".format(indices[still_negative] - num_frames))
        if not (indices < num_frames).all():
            raise IndexError("Out of bound indices: {}".format(indices[indices >= num_frames]))
        return indices
|