Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions av/video/reformatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,19 @@ def _reformat(
frame = frame_sw
src_format = cython.cast(lib.AVPixelFormat, frame.ptr.format)

# Check for shortcut again, in case dst_format matches the downloaded frame's sw_format
if (
dst_format == src_format
and width == frame.ptr.width
and height == frame.ptr.height
and dst_colorspace == src_colorspace
and src_color_range == dst_color_range
and not set_dst_color_trc
and not set_dst_color_primaries
):
frame._init_user_attributes()
return frame

if self.ptr == cython.NULL:
self.ptr = sws_alloc_context()
if self.ptr == cython.NULL:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,48 @@ def test_hardware_decode_download_preserves_frame_props(is_hw_owned: bool) -> No
assert_video_frame_color_props_match(hw_frame, cpu_frame)


def test_hardware_frame_reformat_matches_downloaded_frame() -> None:
hwdevices_available = av.codec.hwaccel.hwdevices_available()
if "HWACCEL_DEVICE_TYPE" not in os.environ:
pytest.skip(
"Set the HWACCEL_DEVICE_TYPE to run this test. "
f"Options are {' '.join(hwdevices_available)}"
)

hwaccel_device_type = os.environ["HWACCEL_DEVICE_TYPE"]
assert hwaccel_device_type in hwdevices_available, (
f"{hwaccel_device_type} not available"
)

test_video_path = fate_suite("h264/interlaced_crop.mp4")
downloaded_frame = decode_first_video_frame(
test_video_path,
av.codec.hwaccel.HWAccel(
device_type=hwaccel_device_type,
is_hw_owned=False,
allow_software_fallback=False,
),
)
hw_frame = decode_first_video_frame(
test_video_path,
av.codec.hwaccel.HWAccel(
device_type=hwaccel_device_type,
is_hw_owned=True,
allow_software_fallback=False,
),
)
assert downloaded_frame.format.name != hw_frame.format.name # E.g. cuda vs nv12

# Download hw_frame to CPU and ensure the contents match downloaded_frame
sw_format = downloaded_frame.format.name
reformatted_frame = hw_frame.reformat(format=sw_format)
assert reformatted_frame.format.name == downloaded_frame.format.name
assert np.array_equal(
reformatted_frame.to_ndarray(format=sw_format),
downloaded_frame.to_ndarray(format=sw_format),
)


def decode_first_video_frame(
path: str, hwaccel: av.codec.hwaccel.HWAccel | None = None
) -> av.VideoFrame:
Expand Down
Loading