Skip to content

Commit

Permalink
Fix miscellaneous (#53)
Browse files Browse the repository at this point in the history
* fix morpho calls

* fix misc stuff

* fix fine_dehalo2
  • Loading branch information
Ichunjo authored Jan 19, 2025
1 parent d4143ff commit ff88863
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 142 deletions.
192 changes: 72 additions & 120 deletions vsdehalo/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
from vsdenoise import Prefilter
from vsexprtools import ExprOp, combine, complexpr_available, norm_expr
from vskernels import Bilinear, BSpline, Lanczos, Mitchell, NoShift, Point, Scaler, ScalerT
from vsmasktools import EdgeDetect, Morpho, Robinson3, XxpandMode, grow_mask, retinex
from vsmasktools import EdgeDetect, Morpho, RadiusT, Robinson3, XxpandMode, grow_mask, retinex
from vsrgtools import (
RemoveGrainMode, RepairMode, box_blur, contrasharpening, contrasharpening_dehalo, gauss_blur, limit_filter, repair
BlurMatrix, BlurMatrixBase, RemoveGrainMode, RepairMode, box_blur, contrasharpening,
contrasharpening_dehalo, gauss_blur, limit_filter, repair
)
from vsrgtools.util import norm_rmode_planes
from vstools import (
ConvMode, CustomIndexError, CustomIntEnum, CustomValueError, FieldBased, FuncExceptT, FunctionUtil,
InvalidColorFamilyError, KwargsT, PlanesT, UnsupportedFieldBasedError, check_ref_clip, check_variable, clamp,
cround, fallback, get_peak_value, get_y, join, mod4, normalize_planes, normalize_seq, scale_mask, split, to_arr,
vs
ConvMode, CustomIndexError, CustomIntEnum, CustomValueError, FieldBased, FuncExceptT,
FunctionUtil, InvalidColorFamilyError, KwargsT, OneDimConvModeT, PlanesT,
UnsupportedFieldBasedError, check_ref_clip, check_variable, check_variable_format, clamp,
cround, fallback, get_peak_value, get_y, join, mod4, normalize_planes, normalize_seq,
scale_mask, split, to_arr, vs
)

__all__ = [
Expand All @@ -42,15 +44,15 @@ def _limit_dehalo(

def _dehalo_mask(
clip: vs.VideoNode, ref: vs.VideoNode, lowsens: list[float], highsens: list[float],
sigma_mask: float | bool, mask_radius: int, mask_coords: int | tuple[int, ConvMode] | Sequence[int],
sigma_mask: float | bool, mask_radius: RadiusT, mask_coords: Sequence[int] | None,
planes: PlanesT
) -> vs.VideoNode:
peak = get_peak_value(clip)

mask = norm_expr(
[
Morpho.gradient(clip, mask_radius, planes, coords=mask_coords),
Morpho.gradient(ref, mask_radius, planes, coords=mask_coords)
Morpho.gradient(clip, mask_radius, planes=planes, coords=mask_coords),
Morpho.gradient(ref, mask_radius, planes=planes, coords=mask_coords)
],
'x 0 = 0.0 x y - x / ? {lowsens} - x {peak} / 256 255 / + 512 255 / / {highsens} + * '
'0.0 max 1.0 min {peak} *', planes, peak=peak,
Expand Down Expand Up @@ -94,12 +96,12 @@ def _supersample(work_clip: vs.VideoNode, dehalo: vs.VideoNode, ss: float) -> vs

w, h = mod4(work_clip.width * ss), mod4(work_clip.height * ss)
ss_clip = norm_expr([
supersampler.scale(work_clip, w, h), # type: ignore
supersampler_ref.scale(dehalo.std.Maximum(), w, h), # type: ignore
supersampler_ref.scale(dehalo.std.Minimum(), w, h) # type: ignore
supersampler.scale(work_clip, w, h),
supersampler_ref.scale(dehalo.std.Maximum(), w, h),
supersampler_ref.scale(dehalo.std.Minimum(), w, h)
], 'x y min z max', planes)

return supersampler.scale(ss_clip, work_clip.width, work_clip.height) # type: ignore
return supersampler.scale(ss_clip, work_clip.width, work_clip.height)

if len(set(ss)) == 1 or planes == [0] or clip.format.num_planes == 1: # type: ignore
dehalo = _supersample(clip, ref, ss[0])
Expand Down Expand Up @@ -134,10 +136,10 @@ def __call__(
ss: FloatIterArr = 1.5, contra: int | float | bool = 0.0, exclude: bool = True,
edgeproc: float = 0.0, edgemask: EdgeDetect = Robinson3(), planes: PlanesT = 0,
show_mask: int | FineDehaloMask | bool = False,
mask_radius: int = 1, downscaler: ScalerT = Mitchell, upscaler: ScalerT = BSpline,
mask_radius: RadiusT = 1, downscaler: ScalerT = Mitchell, upscaler: ScalerT = BSpline,
supersampler: ScalerT = Lanczos(3), supersampler_ref: ScalerT = Mitchell, pre_ss: float = 1.0,
pre_supersampler: ScalerT = Nnedi3(0, field=0, shifter=NoShift), pre_downscaler: ScalerT = Point,
mask_coords: int | tuple[int, ConvMode] | Sequence[int] = 3,
mask_coords: Sequence[int] | None = None,
func: FuncExceptT | None = None
) -> vs.VideoNode:
"""
Expand Down Expand Up @@ -288,7 +290,7 @@ def __call__(
dehaloed = contrasharpening_dehalo(dehaloed, work_clip, contra, planes=planes)
else:
dehaloed = contrasharpening(
dehaloed, work_clip, None if contra is True else contra, planes=planes
dehaloed, work_clip, int(contra), planes=planes
)

y_merge = work_clip.std.MaskedMerge(dehaloed, mask, planes)
Expand Down Expand Up @@ -358,25 +360,24 @@ def mask(

def fine_dehalo2(
clip: vs.VideoNode,
mode: ConvMode = ConvMode.HV,
mode: OneDimConvModeT = ConvMode.HV,
radius: int = 2, mask_radius: int = 2,
brightstr: float = 1.0, darkstr: float = 1.0,
dark: bool | None = True, planes: PlanesT = 0,
dark: bool | None = True,
show_mask: bool = False
) -> vs.VideoNode:
"""
Halo removal function for 2nd order halos.
:param clip: Source clip.
:param mode: Horizontal/Vertical or both ways.
:param radius: Radius for mask growing.
:param radius: Radius for the fixing convolution.
:param brightstr: Strength factor for bright halos.
:param darkstr: Strength factor for dark halos.
:param dark: Whether to filter for dark or bright haloing.
None for disable merging with source clip.
:param planes: Planes to process.
:param show_mask: Whether to return the computed mask.
:param clip: Source clip.
:param mode: Horizontal/Vertical or both ways.
:param radius: Radius for the fixing convolution.
:param mask_radius: Radius for mask growing.
:param brightstr: Strength factor for bright halos.
:param darkstr: Strength factor for dark halos.
:param dark: Whether to filter for dark or bright haloing.
None for disable merging with source clip.
:param show_mask: Whether to return the computed mask.
:return: Dehaloed clip.
"""
Expand All @@ -385,81 +386,31 @@ def fine_dehalo2(
if clip.format.color_family not in {vs.YUV, vs.GRAY}:
raise ValueError('fine_dehalo2: format not supported')

planes = normalize_planes(clip, planes)

is_float = clip.format.sample_type == vs.FLOAT

work_clip, *chroma = split(clip) if planes == [0] else (clip, )
work_clip, *chroma = split(clip)

mask_h = mask_v = None

mask_h_conv = [1, 2, 1, 0, 0, 0, -1, -2, -1]
mask_v_conv = [1, 0, -1, 2, 0, -2, 1, 0, -1]

# intended to be reversed
if complexpr_available:
h_mexpr, v_mexpr = [
ExprOp.convolution('x', coord, None, 4, False)
for coord in (mask_h_conv, mask_v_conv)
]

if mode == ConvMode.HV:
do_mv = do_mh = True
else:
do_mv, do_mh = [
mode == m for m in {ConvMode.HORIZONTAL, ConvMode.VERTICAL}
]

mask_args = (h_mexpr, do_mv, do_mh, v_mexpr)

mask_h, mask_v = [
norm_expr(work_clip, [
mexpr, 3, ExprOp.MUL, [omexpr, ExprOp.SUB] if do_om else None, ExprOp.clamp(0, 1) if is_float else None
], planes) if do_m else None
for mexpr, do_m, do_om, omexpr in [mask_args, mask_args[::-1]]
]
else:
if mode in {ConvMode.HV, ConvMode.VERTICAL}:
mask_h = work_clip.std.Convolution(mask_h_conv, None, 4, planes, False)

if mode in {ConvMode.HV, ConvMode.HORIZONTAL}:
mask_v = work_clip.std.Convolution(mask_v_conv, None, 4, planes, False)

if mask_h and mask_v:
mask_h2 = norm_expr([mask_h, mask_v], 'x 3 * y -', planes)
mask_v2 = norm_expr([mask_v, mask_h], 'x 3 * y -', planes)
mask_h, mask_v = mask_h2, mask_v2
elif mask_h:
mask_h = norm_expr(mask_h, 'x 3 *', planes)
elif mask_v:
mask_v = norm_expr(mask_v, 'x 3 *', planes)

if is_float:
mask_h = mask_h and mask_h.std.Limiter(planes=planes)
mask_v = mask_v and mask_v.std.Limiter(planes=planes)
if mode in {ConvMode.HV, ConvMode.VERTICAL}:
mask_h = BlurMatrixBase([1, 2, 1, 0, 0, 0, -1, -2, -1], ConvMode.V)(work_clip, divisor=4, saturate=False)

fix_weights = list(range(-1, -radius - 1, -1))
fix_rweights = list(reversed(fix_weights))
fix_zeros, fix_mweight = [0] * radius, 10 * (radius + 2)
if mode in {ConvMode.HV, ConvMode.HORIZONTAL}:
mask_v = BlurMatrixBase([1, 0, -1, 2, 0, -2, 1, 0, -1], ConvMode.H)(work_clip, divisor=4, saturate=False)

fix_h_conv = [*fix_weights, *fix_zeros, fix_mweight, *fix_zeros, *fix_rweights]
fix_v_conv = [*fix_rweights, *fix_zeros, fix_mweight, *fix_zeros, *fix_weights]
if mask_h and mask_v:
mask_h2 = norm_expr([mask_h, mask_v], ['x 3 * y -', ExprOp.clamp()])
mask_v2 = norm_expr([mask_v, mask_h], ['x 3 * y -', ExprOp.clamp()])
mask_h, mask_v = mask_h2, mask_v2
elif mask_h:
mask_h = norm_expr(mask_h, ['x 3 *', ExprOp.clamp()])
elif mask_v:
mask_v = norm_expr(mask_v, ['x 3 *', ExprOp.clamp()])

fix_h, fix_v = [
norm_expr(work_clip, ExprOp.convolution('x', coord, mode=mode), planes)
if complexpr_available else
work_clip.std.Convolution(coord, planes=planes, mode=mode)
for coord, mode in [(fix_h_conv, ConvMode.HORIZONTAL), (fix_v_conv, ConvMode.VERTICAL)]
]

mask_h, mask_v = [
grow_mask(mask, mask_radius, 1.8, planes, coordinates=coord) if mask else None
for mask, coord in [
(mask_h, [0, 1, 0, 0, 0, 0, 1, 0]), (mask_v, [0, 0, 0, 1, 1, 0, 0, 0])
]
]
if mask_h:
mask_h = grow_mask(mask_h, mask_radius, coord=[0, 1, 0, 0, 0, 0, 1, 0], multiply=1.8)
if mask_v:
mask_v = grow_mask(mask_v, mask_radius, coord=[0, 0, 0, 1, 1, 0, 0, 0], multiply=1.8)

if is_float:
if clip.format.sample_type == vs.FLOAT:
mask_h = mask_h and mask_h.std.Limiter()
mask_v = mask_v and mask_v.std.Limiter()

Expand All @@ -471,27 +422,27 @@ def fine_dehalo2(

return ret_mask

dehaloed = work_clip
op = '' if dark is None else ExprOp.MAX if dark else ExprOp.MIN
fix_weights = list(range(-1, -radius - 1, -1))
fix_rweights = list(reversed(fix_weights))
fix_zeros, fix_mweight = [0] * radius, 10 * (radius + 2)

if complexpr_available and mask_h and mask_v and clip.format.sample_type is vs.FLOAT:
d_clips = [work_clip, fix_h, fix_v, mask_h, mask_v]
d_expr = 'x 1 a - * y a * + 1 b - * z b * +'
fix_h_conv = [*fix_weights, *fix_zeros, fix_mweight, *fix_zeros, *fix_rweights]
fix_v_conv = [*fix_rweights, *fix_zeros, fix_mweight, *fix_zeros, *fix_weights]

if op:
d_clips, d_expr = [*d_clips, clip], f'{d_expr} x {op}'
fix_h = ExprOp.convolution('x', fix_h_conv, mode=ConvMode.HORIZONTAL)(work_clip)
fix_v = ExprOp.convolution('x', fix_v_conv, mode=ConvMode.VERTICAL)(work_clip)

dehaloed = norm_expr(d_clips, d_expr, planes)
else:
for fix, mask in [(fix_h, mask_v), (fix_v, mask_h)]:
if mask:
dehaloed = dehaloed.std.MaskedMerge(fix, mask)
dehaloed = work_clip

for fix, mask in [(fix_h, mask_v), (fix_v, mask_h)]:
if mask:
dehaloed = dehaloed.std.MaskedMerge(fix, mask)

if op:
dehaloed = combine([work_clip, dehaloed], op) # type: ignore
if dark is not None:
dehaloed = combine([work_clip, dehaloed], ExprOp.MAX if dark else ExprOp.MIN)

if darkstr != brightstr != 1.0:
dehaloed = _limit_dehalo(work_clip, dehaloed, darkstr, brightstr, planes)
dehaloed = _limit_dehalo(work_clip, dehaloed, darkstr, brightstr, 0)

if not chroma:
return dehaloed
Expand All @@ -503,10 +454,10 @@ def dehalo_alpha(
clip: vs.VideoNode, rx: FloatIterArr = 2.0, ry: FloatIterArr | None = None, darkstr: FloatIterArr = 0.0,
brightstr: FloatIterArr = 1.0, lowsens: FloatIterArr = 50.0, highsens: FloatIterArr = 50.0,
sigma_mask: float | bool = False, ss: FloatIterArr = 1.5, planes: PlanesT = 0, show_mask: bool = False,
mask_radius: int = 1, downscaler: ScalerT = Mitchell, upscaler: ScalerT = BSpline,
mask_radius: RadiusT = 1, downscaler: ScalerT = Mitchell, upscaler: ScalerT = BSpline,
supersampler: ScalerT = Lanczos(3), supersampler_ref: ScalerT = Mitchell, pre_ss: float = 1.0,
pre_supersampler: ScalerT = Nnedi3(0, field=0, shifter=NoShift), pre_downscaler: ScalerT = Point,
mask_coords: int | tuple[int, ConvMode] | Sequence[int] = 3,
mask_coords: Sequence[int] | None = None,
func: FuncExceptT | None = None
) -> vs.VideoNode:
"""
Expand Down Expand Up @@ -571,7 +522,7 @@ def dehalo_alpha(
)

def _rescale(clip: vs.VideoNode, rx: float, ry: float) -> vs.VideoNode:
return upscaler.scale(downscaler.scale( # type: ignore
return upscaler.scale(downscaler.scale(
clip, mod4(clip.width / rx), mod4(clip.height / ry)
), clip.width, clip.height)

Expand Down Expand Up @@ -618,8 +569,8 @@ def dehalo_sigma(
blur_func: Prefilter = Prefilter.GAUSS, planes: PlanesT = 0,
supersampler: ScalerT = Lanczos(3), supersampler_ref: ScalerT = Mitchell,
pre_ss: float = 1.0, pre_supersampler: ScalerT = Nnedi3(0, field=0, shifter=NoShift),
pre_downscaler: ScalerT = Point, mask_radius: int = 1, sigma_mask: float | bool = False,
mask_coords: int | tuple[int, ConvMode] | Sequence[int] = 3,
pre_downscaler: ScalerT = Point, mask_radius: RadiusT = 1, sigma_mask: float | bool = False,
mask_coords: Sequence[int] | None = None,
show_mask: bool = False, func: FuncExceptT | None = None, **kwargs: Any
) -> vs.VideoNode:
func = func or dehalo_alpha
Expand Down Expand Up @@ -719,7 +670,7 @@ def dehalomicron(
actual_dehalo = dehalo_sigma(
func.work_clip, pre_ss=1 + pre_ss, sigma=sigma, ss=ss - 0.5 * pre_ss, planes=func.norm_planes, **kwargs
)
dehalo_ref = fine_dehalo(func.work_clip, planes=func.norm_planes, **fdehalo_kwargs)
dehalo_ref = fine_dehalo(func.work_clip, planes=func.norm_planes, **fdehalo_kwargs) # type: ignore[arg-type]

dehalo_min = ExprOp.MIN(actual_dehalo, dehalo_ref, planes=func.norm_planes)

Expand Down Expand Up @@ -751,10 +702,10 @@ def dehalomicron(
def dehalo_merge(
clip: vs.VideoNode, dehalo: vs.VideoNode, darkstr: list[float] | float = 0.0, brightstr: list[float] | float = 1.0,
lowsens: list[float] | float = 50.0, highsens: list[float] | float = 50.0, sigma_mask: float | bool = False,
ss: list[float] | float = 1.5, planes: PlanesT = 0, show_mask: bool = False, mask_radius: int = 1,
ss: list[float] | float = 1.5, planes: PlanesT = 0, show_mask: bool = False, mask_radius: RadiusT = 1,
supersampler: ScalerT = Lanczos(3), supersampler_ref: ScalerT = Mitchell, pre_ss: float = 1.0,
pre_supersampler: ScalerT = Nnedi3(0, field=0, shifter=NoShift), pre_downscaler: ScalerT = Point,
mask_coords: int | tuple[int, ConvMode] | Sequence[int] = 3, func: FuncExceptT | None = None
mask_coords: Sequence[int] | None = None, func: FuncExceptT | None = None
) -> vs.VideoNode:
"""
Merge dehaloed clip onto the source clip.
Expand Down Expand Up @@ -785,6 +736,7 @@ def dehalo_merge(
func = func or dehalo_merge

assert check_ref_clip(clip, dehalo, func)
assert check_variable_format(clip, func)

if FieldBased.from_video(clip).is_inter:
raise UnsupportedFieldBasedError('Only progressive video is supported!', func)
Expand All @@ -805,7 +757,7 @@ def dehalo_merge(
)

darkstr_i, brightstr_i, lowsens_i, highsens_i, ss_i = next(
_dehalo_schizo_norm(darkstr, brightstr, lowsens, highsens, ss)
_dehalo_schizo_norm(darkstr, brightstr, lowsens, highsens, ss) # type: ignore[call-overload]
)

if not all(x >= 1 for x in ss_i):
Expand Down
Loading

0 comments on commit ff88863

Please sign in to comment.