diff --git a/audio_diffusion_pytorch/components.py b/audio_diffusion_pytorch/components.py index bd9dc40..ca1963e 100644 --- a/audio_diffusion_pytorch/components.py +++ b/audio_diffusion_pytorch/components.py @@ -44,6 +44,8 @@ def UNetV0( attention_heads: Optional[int] = None, embedding_features: Optional[int] = None, resnet_groups: int = 8, + resnet_dilation_factor: int=1, + resnet_dropout_rate: float=0.00, use_modulation: bool = True, modulation_features: int = 1024, embedding_max_length: Optional[int] = None, @@ -102,6 +104,8 @@ def UNetV0( embedding_features=embedding_features, modulation_features=modulation_features, resnet_groups=resnet_groups, + resnet_dilation_factor=resnet_dilation_factor, + resnet_dropout_rate=resnet_dropout_rate, )