Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions av/filter/context.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ from av.filter.graph cimport Graph


cdef class FilterContext:

cdef lib.AVFilterContext *ptr
cdef readonly object _graph
cdef readonly Filter filter

cdef object _inputs
cdef object _outputs

cdef bint inited
cdef unsigned char _kind


cdef FilterContext wrap_filter_context(Graph graph, Filter filter, lib.AVFilterContext *ptr)
21 changes: 18 additions & 3 deletions av/filter/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

_cinit_sentinel = cython.declare(object, object())

_KIND_OTHER = cython.declare(cython.uchar, 0)
_KIND_SOURCE = cython.declare(cython.uchar, 1) # buffer / abuffer
_KIND_VIDEO_SINK = cython.declare(cython.uchar, 2) # buffersink
_KIND_AUDIO_SINK = cython.declare(cython.uchar, 3) # abuffersink


@cython.cfunc
def wrap_filter_context(
Expand All @@ -21,6 +26,16 @@ def wrap_filter_context(
self._graph = weakref.ref(graph)
self.filter = filter
self.ptr = ptr

name: str = filter.name
if name == "buffer" or name == "abuffer":
self._kind = _KIND_SOURCE
elif name == "buffersink":
self._kind = _KIND_VIDEO_SINK
elif name == "abuffersink":
self._kind = _KIND_AUDIO_SINK
else:
self._kind = _KIND_OTHER
return self


Expand Down Expand Up @@ -108,7 +123,7 @@ def push(self, frame: Frame | None):
res = lib.av_buffersrc_write_frame(self.ptr, cython.NULL)
err_check(res)
return
elif self.filter.name in ("abuffer", "buffer"):
elif self._kind == _KIND_SOURCE:
with cython.nogil:
res = lib.av_buffersrc_write_frame(self.ptr, frame.ptr)
err_check(res)
Expand All @@ -126,9 +141,9 @@ def push(self, frame: Frame | None):
def pull(self):
frame: Frame
res: cython.int
if self.filter.name == "buffersink":
if self._kind == _KIND_VIDEO_SINK:
frame = alloc_video_frame()
elif self.filter.name == "abuffersink":
elif self._kind == _KIND_AUDIO_SINK:
frame = alloc_audio_frame()
else:
# Delegate to the output.
Expand Down
2 changes: 2 additions & 0 deletions av/filter/graph.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ cdef class Graph:
cdef int _nb_filters_seen
cdef dict[long, FilterContext] _context_by_ptr
cdef dict[str, list[FilterContext]] _context_by_type
cdef list[FilterContext] _video_sources
cdef list[FilterContext] _audio_sources
19 changes: 12 additions & 7 deletions av/filter/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __cinit__(self):
self._nb_filters_seen = 0
self._context_by_ptr = {}
self._context_by_type = {}
self._video_sources = []
self._audio_sources = []

def __dealloc__(self):
if self.ptr:
Expand Down Expand Up @@ -108,8 +110,13 @@ def add(self, filter, args=None, **kwargs):

@cython.cfunc
def _register_context(self, ctx: FilterContext):
name: str = ctx.filter.ptr.name
self._context_by_ptr[cython.cast(cython.long, ctx.ptr)] = ctx
self._context_by_type.setdefault(ctx.filter.ptr.name, []).append(ctx)
self._context_by_type.setdefault(name, []).append(ctx)
if name == "buffer":
self._video_sources.append(ctx)
elif name == "abuffer":
self._audio_sources.append(ctx)

@cython.cfunc
def _auto_register(self):
Expand Down Expand Up @@ -234,13 +241,11 @@ def set_audio_frame_size(self, frame_size):

def push(self, frame, at: cython.int = -1):
if frame is None:
contexts = self._get_context_by_type("buffer") + self._get_context_by_type(
"abuffer"
)
contexts = self._video_sources + self._audio_sources
elif isinstance(frame, VideoFrame):
contexts = self._get_context_by_type("buffer")
contexts = self._video_sources
elif isinstance(frame, AudioFrame):
contexts = self._get_context_by_type("abuffer")
contexts = self._audio_sources
else:
raise ValueError(
f"can only AudioFrame, VideoFrame or None; got {type(frame)}"
Expand All @@ -259,7 +264,7 @@ def push(self, frame, at: cython.int = -1):

def vpush(self, frame: VideoFrame | None, at: cython.int = -1):
"""Like `push`, but only for VideoFrames."""
contexts = self._get_context_by_type("buffer")
contexts = self._video_sources
if at >= 0:
if at >= len(contexts):
raise IndexError(
Expand Down
Loading