From c612b2fa3caf6eef3c75f1cd5859438451d0fd6f Mon Sep 17 00:00:00 2001 From: Ceki Gulcu Date: Fri, 1 Dec 2023 13:00:12 +0100 Subject: [PATCH] add deseialization filter by reflction, enable commented out test case Signed-off-by: Ceki Gulcu --- .../core/net/HardenedObjectInputStream.java | 33 +++++++- .../net/HardenedObjectInputStreamTest.java | 75 ++++++++++--------- 2 files changed, 72 insertions(+), 36 deletions(-) diff --git a/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java b/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java index c01812f768..326a179408 100755 --- a/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java +++ b/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java @@ -13,11 +13,16 @@ */ package ch.qos.logback.core.net; +import ch.qos.logback.core.util.EnvUtil; + import java.io.IOException; import java.io.InputStream; import java.io.InvalidClassException; +import java.io.ObjectInputFilter; import java.io.ObjectInputStream; import java.io.ObjectStreamClass; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; @@ -43,6 +48,7 @@ public class HardenedObjectInputStream extends ObjectInputStream { public HardenedObjectInputStream(InputStream in, String[] whitelist) throws IOException { super(in); + this.initObjectFilter(); this.whitelistedClassNames = new ArrayList(); if (whitelist != null) { for (int i = 0; i < whitelist.length; i++) { @@ -54,11 +60,36 @@ public HardenedObjectInputStream(InputStream in, String[] whitelist) throws IOEx public HardenedObjectInputStream(InputStream in, List whitelist) throws IOException { super(in); - + this.initObjectFilter(); this.whitelistedClassNames = new ArrayList(); this.whitelistedClassNames.addAll(whitelist); } + private void initObjectFilter() { + + // invoke the following code by reflection + // this.setObjectInputFilter(ObjectInputFilter.Config.createFilter( + // "maxarray=" + ARRAY_LIMIT + ";maxdepth=" + DEPTH_LIMIT + ";" + // )); + if(EnvUtil.isJDK9OrHigher()) { + try { + ClassLoader classLoader = this.getClass().getClassLoader(); + + Class oifClass = classLoader.loadClass("java.io.ObjectInputFilter"); + Class oifConfigClass = classLoader.loadClass("java.io.ObjectInputFilter$Config"); + Method setObjectInputFilterMethod = this.getClass().getMethod("setObjectInputFilter", oifClass); + + Method createFilterMethod = oifConfigClass.getMethod("createFilter", String.class); + Object filter = createFilterMethod.invoke(null, "maxarray=" + ARRAY_LIMIT + ";maxdepth=" + DEPTH_LIMIT + ";"); + setObjectInputFilterMethod.invoke(this, filter); + } catch (NoSuchMethodException | ClassNotFoundException | IllegalAccessException | InvocationTargetException e) { + // this code should be unreachable + throw new RuntimeException("Failed to initialize object filter", e); + } + + } + } + @Override protected Class resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException { diff --git a/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java b/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java index 968b4b0fe0..b0dbe69551 100755 --- a/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java +++ b/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java @@ -3,13 +3,19 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InvalidClassException; import java.io.ObjectOutputStream; +import java.util.HashSet; +import java.util.Set; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; public class HardenedObjectInputStreamTest { @@ -53,39 +59,38 @@ private void writeObject(ObjectOutputStream oos, Object o) throws IOException { oos.close(); } -// @Ignore -// @Test -// public void denialOfService() throws ClassNotFoundException, IOException { -// ByteArrayInputStream bis = new ByteArrayInputStream(payload()); -// inputStream = new HardenedObjectInputStream(bis, whitelist); -// try { -// Set set = (Set) inputStream.readObject(); -// assertNotNull(set); -// } finally { -// inputStream.close(); -// } -// } -// -// private byte[] payload() throws IOException { -// Set root = buildEvilHashset(); -// return serialize(root); -// } -// -// private Set buildEvilHashset() { -// Set root = new HashSet(); -// Set s1 = root; -// Set s2 = new HashSet(); -// for (int i = 0; i < 100; i++) { -// Set t1 = new HashSet(); -// Set t2 = new HashSet(); -// t1.add("foo"); // make it not equal to t2 -// s1.add(t1); -// s1.add(t2); -// s2.add(t1); -// s2.add(t2); -// s1 = t1; -// s2 = t2; -// } -// return root; -// } + @Test + public void denialOfService() throws ClassNotFoundException, IOException { + ByteArrayInputStream bis = new ByteArrayInputStream(payload()); + inputStream = new HardenedObjectInputStream(bis, whitelist); + try { + assertThrows(InvalidClassException.class, () -> inputStream.readObject()); + } finally { + inputStream.close(); + } + } + + private byte[] payload() throws IOException { + Set root = buildEvilHashset(); + writeObject(oos, root); + return bos.toByteArray(); + } + + private Set buildEvilHashset() { + Set root = new HashSet(); + Set s1 = root; + Set s2 = new HashSet(); + for (int i = 0; i < 100; i++) { + Set t1 = new HashSet(); + Set t2 = new HashSet(); + t1.add("foo"); // make it not equal to t2 + s1.add(t1); + s1.add(t2); + s2.add(t1); + s2.add(t2); + s1 = t1; + s2 = t2; + } + return root; + } }