diff --git a/src/main/java/org/fungover/haze/Main.java b/src/main/java/org/fungover/haze/Main.java index 74e38fa1..f53c2579 100644 --- a/src/main/java/org/fungover/haze/Main.java +++ b/src/main/java/org/fungover/haze/Main.java @@ -8,6 +8,7 @@ import java.io.InputStreamReader; import java.net.InetSocketAddress; import java.net.ServerSocket; +import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -19,9 +20,12 @@ public class Main { public static void main(String[] args) { Initialize initialize = new Initialize(); initialize.importCliOptions(args); - - HazeDatabase hazeDatabase = new HazeDatabase(); HazeList hazeList = new HazeList(); + HazeDatabase hazeDatabase = new HazeDatabase(); + Auth auth = new Auth(); + initializeServer(args, initialize, auth); + final boolean isPasswordSet = auth.isPasswordSet(); + Thread printingHook = new Thread(() -> shutdown(hazeDatabase)); Runtime.getRuntime().addShutdownHook(printingHook); @@ -36,18 +40,22 @@ public static void main(String[] args) { Runnable newThread = () -> { try { BufferedReader input = new BufferedReader(new InputStreamReader(client.getInputStream())); - + boolean clientAuthenticated = false; while (true) { List inputList = new ArrayList<>(); String firstReading = input.readLine(); readInputStream(input, inputList, firstReading); - client.getOutputStream().write(executeCommand(hazeDatabase, inputList,hazeList ).getBytes()); + clientAuthenticated = authenticateClient(auth, isPasswordSet, client, inputList, clientAuthenticated); + + client.getOutputStream().write(executeCommand(hazeDatabase, inputList, hazeList).getBytes()); inputList.forEach(System.out::println); // For checking incoming message printThreadDebug(); + + inputList.clear(); } } catch (IOException e) { @@ -73,7 +81,7 @@ private static void printThreadDebug() { } public static String executeCommand(HazeDatabase hazeDatabase, List inputList, HazeList hazeList) { - logger.debug("executeCommand: {} {} ", ()-> hazeDatabase, ()-> inputList); + logger.debug("executeCommand: {} {} ", () -> hazeDatabase, () -> inputList); String command = inputList.get(0).toUpperCase(); return switch (command) { @@ -91,13 +99,14 @@ public static String executeCommand(HazeDatabase hazeDatabase, List inpu case "LLEN" -> hazeList.lLen(inputList); case "LMOVE" -> hazeList.lMove(inputList); case "LTRIM" -> hazeList.callLtrim(inputList); + case "AUTH" -> "+OK\r\n"; default -> "-ERR unknown command\r\n"; }; } private static void readInputStream(BufferedReader input, List inputList, String firstReading) throws IOException { - logger.debug("readInputStream: {} {} {}", ()-> input, () -> inputList, () -> firstReading); + logger.debug("readInputStream: {} {} {}", () -> input, () -> inputList, () -> firstReading); int size; if (firstReading.startsWith("*")) { size = Integer.parseInt(firstReading.substring(1)) * 2; @@ -112,4 +121,27 @@ private static void readInputStream(BufferedReader input, List inputList } } + private static void initializeServer(String[] args, Initialize initialize, Auth auth) { + initialize.importCliOptions(args); + auth.setPassword(initialize.getPassword()); + } + + private static boolean authenticateClient(Auth auth, boolean isPasswordSet, Socket client, List inputList, boolean clientAuthenticated) throws IOException { + if (authCommandReceived(isPasswordSet, inputList, clientAuthenticated)) + return auth.authenticate(inputList.get(1), client); + + shutdownClientIfNotAuthenticated(client, clientAuthenticated, isPasswordSet); + return clientAuthenticated; + } + + private static void shutdownClientIfNotAuthenticated(Socket client, boolean clientAuthenticated, boolean isPasswordSet) throws IOException { + if (!clientAuthenticated && isPasswordSet) { + client.getOutputStream().write(Auth.printAuthError()); + client.shutdownOutput(); + } + } + + private static boolean authCommandReceived(boolean isPasswordSet, List inputList, boolean clientAuthenticated) { + return isPasswordSet && !clientAuthenticated && inputList.size() == 2 && inputList.get(0).equals("AUTH"); + } }