Skip to content

Commit

Permalink
decouple compensate/degrain tr
Browse files Browse the repository at this point in the history
  • Loading branch information
emotion3459 committed Dec 15, 2024
1 parent d458295 commit 534a9f5
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions vsdenoise/mvtools/mvtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ def compensate( # type: ignore
self, func: Union[
Callable[Concatenate[vs.VideoNode, P], vs.VideoNode],
Callable[Concatenate[list[vs.VideoNode], P], vs.VideoNode]
], thSAD: int = 150, thSCD: int | tuple[int | None, int | None] | None = (None, 51),
] | None = None,
tr: int | None = None, thSAD: int = 150, thSCD: int | tuple[int | None, int | None] | None = (None, 51),
supers: SuperClips | None = None, *args: P.args, ref: vs.VideoNode | None = None,
**kwargs: P.kwargs
) -> vs.VideoNode:
Expand Down Expand Up @@ -579,8 +580,8 @@ def compensate( # type: ignore

@overload
def compensate(
self, func: None,
thSAD: int = 150, thSCD: int | tuple[int | None, int | None] | None = (None, 51),
self, func: None = None,
tr: int | None = None, thSAD: int = 150, thSCD: int | tuple[int | None, int | None] | None = (None, 51),
supers: SuperClips | None = None, ref: vs.VideoNode | None = None
) -> tuple[vs.VideoNode, tuple[int, int]]:
"""
Expand Down Expand Up @@ -626,16 +627,18 @@ def compensate( # type: ignore
self, func: Union[
Callable[Concatenate[vs.VideoNode, P], vs.VideoNode],
Callable[Concatenate[list[vs.VideoNode], P], vs.VideoNode]
] | None, thSAD: int = 150, thSCD: int | tuple[int | None, int | None] | None = (None, 51),
] | None = None,
tr: int | None = None, thSAD: int = 150, thSCD: int | tuple[int | None, int | None] | None = (None, 51),
supers: SuperClips | None = None, *args: P.args, ref: vs.VideoNode | None = None,
**kwargs: P.kwargs
) -> vs.VideoNode | tuple[vs.VideoNode, tuple[int, int]]:
ref = self.get_ref_clip(ref, self.compensate)
tr = min(tr, self.tr) if tr else self.tr

thSCD1, thSCD2 = self.normalize_thscd(thSCD, thSAD, self.compensate)
supers = supers or self.get_supers(ref, inplace=True)

vect_b, vect_f = self.get_vectors_bf(self.vectors)
vect_b, vect_f = self.get_vectors_bf(self.vectors, tr=tr)

compensate_args = dict(
super=supers.render, thsad=thSAD,
Expand All @@ -646,10 +649,10 @@ def compensate( # type: ignore

comp_back, comp_forw = [
[self.mvtools.Compensate(ref, vectors=vect, **compensate_args) for vect in vectors]
for vectors in (reversed(vect_b), vect_f)
for vectors in (vect_b, vect_f)
]

comp_clips = [*comp_forw, ref, *comp_back]
comp_clips = [*comp_back, ref, *comp_forw]
n_clips = len(comp_clips)
offset = (n_clips - 1) // 2

Expand Down Expand Up @@ -732,6 +735,7 @@ def flow( # type: ignore

def degrain(
self,
tr: int | None = None,
thSAD: int | tuple[int | None, int | None] | None = None,
limit: int | tuple[int, int] = 255,
thSCD: int | tuple[int | None, int | None] | None = (None, 51),
Expand Down Expand Up @@ -779,13 +783,14 @@ def degrain(
"""

ref = self.get_ref_clip(ref, self.degrain)
tr = min(tr, self.tr) if tr else self.tr

if isinstance(vectors, MVTools):
vectors = vectors.vectors
elif vectors is None:
vectors = self.vectors

vect_b, vect_f = self.get_vectors_bf(vectors, supers=supers, ref=ref)
vect_b, vect_f = self.get_vectors_bf(vectors, supers=supers, ref=ref, tr=tr)
supers = supers or self.get_supers(ref, inplace=True)

thSAD, thSADC = (thSAD if isinstance(thSAD, tuple) else (thSAD, None))
Expand All @@ -810,7 +815,7 @@ def degrain(
if self.mvtools is MVToolsPlugin.FLOAT_NEW:
output = self.mvtools.Degrain()(ref, supers.render, vectors.vmulti, **degrain_args)
else:
output = self.mvtools.Degrain(self.tr)(
output = self.mvtools.Degrain(tr)(
ref, supers.render, *chain.from_iterable(zip(vect_b, vect_f)), **degrain_args
)

Expand Down Expand Up @@ -937,7 +942,7 @@ def get_supers(self, ref: vs.VideoNode, *, inplace: bool = False) -> SuperClips:

def get_vectors_bf(
self, vectors: MotionVectors, *, supers: SuperClips | None = None,
ref: vs.VideoNode | None = None, inplace: bool = False
ref: vs.VideoNode | None = None, tr: int | None = None, inplace: bool = False
) -> tuple[list[vs.VideoNode], list[vs.VideoNode]]:
"""
Get the backwards and forward vectors.\n
Expand All @@ -954,7 +959,8 @@ def get_vectors_bf(
if not vectors.has_vectors:
vectors = self.analyze(supers=supers, ref=ref, inplace=inplace)

t2 = (self.tr * 2 if self.tr > 1 else self.tr) if self.source_type.is_inter else self.tr
tr = min(tr, self.tr) if tr else self.tr
t2 = (tr * 2 if tr > 1 else tr) if self.source_type.is_inter else tr

vectors_backward = list[vs.VideoNode]()
vectors_forward = list[vs.VideoNode]()
Expand Down

0 comments on commit 534a9f5

Please sign in to comment.