Skip to content

Commit

Permalink
add deseialization filter by reflction, enable commented out test case
Browse files Browse the repository at this point in the history
Signed-off-by: Ceki Gulcu <ceki@qos.ch>
  • Loading branch information
ceki committed Dec 1, 2023
1 parent d1bf54f commit c612b2f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
*/
package ch.qos.logback.core.net;

import ch.qos.logback.core.util.EnvUtil;

import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -43,6 +48,7 @@ public class HardenedObjectInputStream extends ObjectInputStream {

public HardenedObjectInputStream(InputStream in, String[] whitelist) throws IOException {
super(in);
this.initObjectFilter();
this.whitelistedClassNames = new ArrayList<String>();
if (whitelist != null) {
for (int i = 0; i < whitelist.length; i++) {
Expand All @@ -54,11 +60,36 @@ public HardenedObjectInputStream(InputStream in, String[] whitelist) throws IOEx

public HardenedObjectInputStream(InputStream in, List<String> whitelist) throws IOException {
super(in);

this.initObjectFilter();
this.whitelistedClassNames = new ArrayList<String>();
this.whitelistedClassNames.addAll(whitelist);
}

private void initObjectFilter() {

// invoke the following code by reflection
// this.setObjectInputFilter(ObjectInputFilter.Config.createFilter(
// "maxarray=" + ARRAY_LIMIT + ";maxdepth=" + DEPTH_LIMIT + ";"
// ));
if(EnvUtil.isJDK9OrHigher()) {
try {
ClassLoader classLoader = this.getClass().getClassLoader();

Class oifClass = classLoader.loadClass("java.io.ObjectInputFilter");
Class oifConfigClass = classLoader.loadClass("java.io.ObjectInputFilter$Config");
Method setObjectInputFilterMethod = this.getClass().getMethod("setObjectInputFilter", oifClass);

Method createFilterMethod = oifConfigClass.getMethod("createFilter", String.class);
Object filter = createFilterMethod.invoke(null, "maxarray=" + ARRAY_LIMIT + ";maxdepth=" + DEPTH_LIMIT + ";");
setObjectInputFilterMethod.invoke(this, filter);
} catch (NoSuchMethodException | ClassNotFoundException | IllegalAccessException | InvocationTargetException e) {
// this code should be unreachable
throw new RuntimeException("Failed to initialize object filter", e);
}

}
}

@Override
protected Class<?> resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectOutputStream;
import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class HardenedObjectInputStreamTest {

Expand Down Expand Up @@ -53,39 +59,38 @@ private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
oos.close();
}

// @Ignore
// @Test
// public void denialOfService() throws ClassNotFoundException, IOException {
// ByteArrayInputStream bis = new ByteArrayInputStream(payload());
// inputStream = new HardenedObjectInputStream(bis, whitelist);
// try {
// Set set = (Set) inputStream.readObject();
// assertNotNull(set);
// } finally {
// inputStream.close();
// }
// }
//
// private byte[] payload() throws IOException {
// Set root = buildEvilHashset();
// return serialize(root);
// }
//
// private Set buildEvilHashset() {
// Set root = new HashSet();
// Set s1 = root;
// Set s2 = new HashSet();
// for (int i = 0; i < 100; i++) {
// Set t1 = new HashSet();
// Set t2 = new HashSet();
// t1.add("foo"); // make it not equal to t2
// s1.add(t1);
// s1.add(t2);
// s2.add(t1);
// s2.add(t2);
// s1 = t1;
// s2 = t2;
// }
// return root;
// }
@Test
public void denialOfService() throws ClassNotFoundException, IOException {
ByteArrayInputStream bis = new ByteArrayInputStream(payload());
inputStream = new HardenedObjectInputStream(bis, whitelist);
try {
assertThrows(InvalidClassException.class, () -> inputStream.readObject());
} finally {
inputStream.close();
}
}

private byte[] payload() throws IOException {
Set root = buildEvilHashset();
writeObject(oos, root);
return bos.toByteArray();
}

private Set buildEvilHashset() {
Set root = new HashSet();
Set s1 = root;
Set s2 = new HashSet();
for (int i = 0; i < 100; i++) {
Set t1 = new HashSet();
Set t2 = new HashSet();
t1.add("foo"); // make it not equal to t2
s1.add(t1);
s1.add(t2);
s2.add(t1);
s2.add(t2);
s1 = t1;
s2 = t2;
}
return root;
}
}

0 comments on commit c612b2f

Please sign in to comment.