Skip to content

Commit

Permalink
Add libp2p connection Firewall (#2478)
Browse files Browse the repository at this point in the history
* Add libp2p connection Firewall
* Handle all exceptions from Firewall handlers
  • Loading branch information
Nashatyrev authored Jul 30, 2020
1 parent 0141b55 commit 4c4a83b
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2020 ConsenSys AG.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package tech.pegasys.teku.networking.p2p.libp2p;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.handler.timeout.WriteTimeoutException;
import io.netty.handler.timeout.WriteTimeoutHandler;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import tech.pegasys.teku.infrastructure.async.FutureUtil;

/**
* The very first Netty handler in the Libp2p connection pipeline. Sets up Netty Channel options and
* doing other duties preventing DoS attacks
*/
@Sharable
public class Firewall extends ChannelInboundHandlerAdapter {
private static final Logger LOG = LogManager.getLogger();

private final Duration writeTimeout;
private final List<ChannelHandler> additionalHandlers;

public Firewall(Duration writeTimeout, List<ChannelHandler> additionalHandlers) {
this.writeTimeout = writeTimeout;
this.additionalHandlers = additionalHandlers;
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) {
additionalHandlers.forEach(h -> ctx.pipeline().addLast(h));
ctx.channel().config().setWriteBufferWaterMark(new WriteBufferWaterMark(100, 1024));
ctx.pipeline().addLast(new WriteTimeoutHandler(writeTimeout.toMillis(), TimeUnit.MILLISECONDS));
ctx.pipeline().addLast(new FirewallExceptionHandler());
}

@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) {
ctx.channel().config().setAutoRead(ctx.channel().isWritable());
ctx.fireChannelWritabilityChanged();
}

class FirewallExceptionHandler extends ChannelInboundHandlerAdapter {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause instanceof WriteTimeoutException) {
LOG.debug("Firewall closed channel by write timeout. No writes during " + writeTimeout);
} else {
LOG.debug("Error in Firewall, disconnecting" + cause);
FutureUtil.ignoreFuture(ctx.close());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -138,9 +140,13 @@ public LibP2PNetwork(
b.getProtocols().addAll(getDefaultProtocols());
b.getProtocols().addAll(rpcHandlers.values());

List<ChannelHandler> beforeSecureLogHandler = new ArrayList<>();
if (config.getWireLogsConfig().isLogWireCipher()) {
b.getDebug().getBeforeSecureHandler().setLogger(LogLevel.DEBUG, "wire.ciphered");
beforeSecureLogHandler.add(new LoggingHandler("wire.ciphered", LogLevel.DEBUG));
}
Firewall firewall = new Firewall(Duration.ofSeconds(30), beforeSecureLogHandler);
b.getDebug().getBeforeSecureHandler().setHandler(firewall);

if (config.getWireLogsConfig().isLogWirePlain()) {
b.getDebug().getAfterSecureHandler().setLogger(LogLevel.DEBUG, "wire.plain");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2020 ConsenSys AG.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package tech.pegasys.teku.networking.p2p.libp2p;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import java.time.Duration;
import java.util.Collections;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

public class FirewallTest {

@Test
@SuppressWarnings("FutureReturnValueIgnored")
void testFirewallNotPropagateTimeoutExceptionUpstream() throws Exception {
Firewall firewall = new Firewall(Duration.ofMillis(100), Collections.emptyList());
EmbeddedChannel channel =
new EmbeddedChannel(
firewall,
new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
super.exceptionCaught(ctx, cause);
}
});
channel.writeOneOutbound("a");
executeAllScheduledTasks(channel, 5);
Assertions.assertThatCode(channel::checkException).doesNotThrowAnyException();
Assertions.assertThat(channel.isOpen()).isFalse();
}

private void executeAllScheduledTasks(EmbeddedChannel channel, long maxWaitSeconds)
throws TimeoutException, InterruptedException {
long waitTime = 0;
while (waitTime < maxWaitSeconds * 1000) {
long l = channel.runScheduledPendingTasks();
if (l < 0) break;
long ms = l / 1_000_000;
waitTime += ms;
Thread.sleep(ms);
}
if (waitTime >= maxWaitSeconds * 1000) {
throw new TimeoutException();
}
}
}

0 comments on commit 4c4a83b

Please sign in to comment.