Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 114 additions & 44 deletions src/main/java/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class Main {
private enum Mode {
private enum Symbol {
DELAY,
HOST,
UDP,
TCP;

public static Mode fromString(String s) {
return switch (s) {
case "u" -> UDP;
case "t" -> TCP;
default -> throw new IllegalArgumentException("Unknown mode: " + s);
};
}
TCP
}

private static final int TCP_TIMEOUT = 1000; // milliseconds
private static final int TCP_TIMEOUT = 1000; // default TCP timeout in milliseconds

private static void printUsageAndExit() {
System.out.println(
Expand All @@ -29,17 +25,71 @@ private static void printUsageAndExit() {
System.exit(1);
}

private static long parseDelay(String delayStr) {
try {
long delay = Long.parseLong(delayStr);
if (delay < 100 || delay > 60_000) { // 60 seconds max
System.out.println("Delay must be between 100 and 60000 milliseconds.");
printUsageAndExit();
}
return delay;
} catch (NumberFormatException e) {
printUsageAndExit();
return -1; // Unreachable, but required by the compiler
}
}

private static int parsePort(String portStr) {
try {
int port = Integer.parseInt(portStr);
if (port < 1 || port > 65535) {
System.out.println("Port must be between 1 and 65535.");
printUsageAndExit();
}
return port;
} catch (NumberFormatException e) {
printUsageAndExit();
return -1; // Unreachable, but required by the compiler
}
}

public static void main(String[] args) {
if (args.length == 0) {
printUsageAndExit();
}

// Default parameters...
int delay = 100; // default delay
long delay = 100; // default delay
String host = "localhost"; // default host

// Parse arguments...
LinkedList<AbstractMap.SimpleEntry<Mode, Integer>> knockSequence = new LinkedList<>();
Map<Integer, Map<Symbol, Integer>> fsm = new HashMap<>();
fsm.put(
0,
Map.of(
Symbol.DELAY, 1,
Symbol.HOST, 2,
Symbol.UDP, 3,
Symbol.TCP, 3));
fsm.put(
1,
Map.of(
Symbol.HOST, 2,
Symbol.UDP, 3,
Symbol.TCP, 3));
fsm.put(
2,
Map.of(
Symbol.DELAY, 1,
Symbol.UDP, 3,
Symbol.TCP, 3));
fsm.put(
3,
Map.of(
Symbol.UDP, 3,
Symbol.TCP, 3));
List<Runnable> knockSequence = new LinkedList<>();
int state = 0;
for (String arg : args) {
String[] parts = arg.split("=", 2);
if (parts.length != 2) {
Expand All @@ -48,27 +98,45 @@ public static void main(String[] args) {
String key = parts[0];
String value = parts[1];
switch (key) {
case "host" -> host = value;
case "host" -> {
if (!fsm.get(state).containsKey(Symbol.HOST)) {
printUsageAndExit();
}
host = value;
state = fsm.get(state).get(Symbol.HOST);
}
case "delay" -> {
try {
int d = Integer.parseInt(value);
if (d < 0 || d > 60_000) {
printUsageAndExit();
}
delay = d;
} catch (NumberFormatException e) {
if (!fsm.get(state).containsKey(Symbol.DELAY)) {
printUsageAndExit();
}
delay = parseDelay(value);
state = fsm.get(state).get(Symbol.DELAY);
}
case "u", "t" -> {
try {
int port = Integer.parseInt(value);
if (port < 1 || port > 65535) {
printUsageAndExit();
switch (key) {
case "u" -> {
if (!fsm.get(state).containsKey(Symbol.UDP)) {
printUsageAndExit();
}
String finalHost = host;
int finalPort = parsePort(value);
long finalDelay = delay;
knockSequence.add(() -> knockUDP(finalHost, finalPort));
knockSequence.add(() -> knockDelay(finalDelay));
state = fsm.get(state).get(Symbol.UDP);
}
knockSequence.add(new AbstractMap.SimpleEntry<>(Mode.fromString(key), port));
} catch (NumberFormatException e) {
printUsageAndExit();
case "t" -> {
if (!fsm.get(state).containsKey(Symbol.TCP)) {
printUsageAndExit();
}
String finalHost = host;
int finalPort = parsePort(value);
long finalDelay = delay;
knockSequence.add(() -> knockTCP(finalHost, finalPort, TCP_TIMEOUT));
knockSequence.add(() -> knockDelay(finalDelay));
state = fsm.get(state).get(Symbol.TCP);
}
default -> printUsageAndExit(); // Unreachable, but required by SpotBugs
}
}
default -> printUsageAndExit();
Expand All @@ -77,23 +145,10 @@ public static void main(String[] args) {
if (knockSequence.isEmpty()) {
printUsageAndExit();
}
knockSequence.remove(knockSequence.size() - 1); // Remove last delay

// Perform knocking sequence...
for (AbstractMap.SimpleEntry<Mode, Integer> entry : knockSequence) {
Mode mode = entry.getKey();
int port = entry.getValue();
System.out.printf("Knocking %s %s:%d ... ", mode.name(), host, port);
switch (mode) {
case UDP -> System.out.println(knockUDP(host, port) ? "Success" : "Failed");
case TCP -> System.out.println(knockTCP(host, port, TCP_TIMEOUT) ? "Success" : "Failed");
}
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
System.out.println("Sleep interrupted: " + e.getMessage());
return;
}
}
knockSequence.forEach(Runnable::run);
}

/**
Expand All @@ -105,11 +160,14 @@ public static void main(String[] args) {
* @return true if the datagram was sent successfully, false otherwise
*/
public static boolean knockUDP(String host, int port) {
System.out.print("Knocking UDP " + host + ":" + port + " ... ");
try (DatagramSocket socket = new DatagramSocket()) {
socket.connect(new InetSocketAddress(host, port));
socket.send(new java.net.DatagramPacket(new byte[2048], 2048));
System.out.println("done.");
return true;
} catch (IOException e) {
System.out.println("failed.");
return false;
}
}
Expand All @@ -124,11 +182,23 @@ public static boolean knockUDP(String host, int port) {
* @return true if the connection was successful, false otherwise
*/
public static boolean knockTCP(String host, int port, int timeout) {
System.out.print("Knocking TCP " + host + ":" + port + " ... ");
try (Socket socket = new Socket()) {
socket.connect(new InetSocketAddress(host, port), timeout);
System.out.println("done.");
return true;
} catch (IOException e) {
System.out.println("failed.");
return false;
}
}

private static void knockDelay(long delay) {
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
System.out.println("Sleep interrupted: " + e.getMessage());
System.exit(1);
}
}
}
1 change: 0 additions & 1 deletion src/test/java/MainTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import org.junit.jupiter.api.Test;

public class MainTest {

@Test
public void testKnockTCP_InvalidHost() {
// Test with invalid host - should return false
Expand Down
Loading