Skip to content

Commit 2f870c7

Browse files
authored
Refactor SafeSerializationUtils for better performance (#4973)
Signed-off-by: shikharj05 <8859327+shikharj05@users.noreply.github.com>
1 parent 79a3299 commit 2f870c7

File tree

2 files changed

+136
-8
lines changed

2 files changed

+136
-8
lines changed

src/main/java/org/opensearch/security/support/SafeSerializationUtils.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
import java.net.SocketAddress;
1818
import java.util.Collection;
1919
import java.util.Collections;
20-
import java.util.List;
2120
import java.util.Map;
2221
import java.util.Set;
22+
import java.util.concurrent.ConcurrentHashMap;
2323
import java.util.regex.Pattern;
2424

25-
import com.google.common.collect.ImmutableList;
2625
import com.google.common.collect.ImmutableSet;
2726

2827
import org.opensearch.security.auth.UserInjector;
@@ -57,7 +56,7 @@ public final class SafeSerializationUtils {
5756
LdapAttribute.class
5857
);
5958

60-
private static final List<Class<?>> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of(
59+
private static final Set<Class<?>> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableSet.of(
6160
InetAddress.class,
6261
Number.class,
6362
Collection.class,
@@ -66,18 +65,28 @@ public final class SafeSerializationUtils {
6665
);
6766

6867
private static final Set<String> SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues");
68+
static final Map<Class<?>, Boolean> safeClassCache = new ConcurrentHashMap<>();
6969

7070
static boolean isSafeClass(Class<?> cls) {
71-
return cls.isArray()
72-
|| SAFE_CLASSES.contains(cls)
73-
|| SAFE_CLASS_NAMES.contains(cls.getName())
74-
|| SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls));
71+
return safeClassCache.computeIfAbsent(cls, SafeSerializationUtils::computeIsSafeClass);
72+
}
73+
74+
static boolean computeIsSafeClass(Class<?> cls) {
75+
return cls.isArray() || SAFE_CLASSES.contains(cls) || SAFE_CLASS_NAMES.contains(cls.getName()) || isAssignableFromSafeClass(cls);
76+
}
77+
78+
private static boolean isAssignableFromSafeClass(Class<?> cls) {
79+
for (Class<?> safeClass : SAFE_ASSIGNABLE_FROM_CLASSES) {
80+
if (safeClass.isAssignableFrom(cls)) {
81+
return true;
82+
}
83+
}
84+
return false;
7585
}
7686

7787
static void prohibitUnsafeClasses(Class<?> clazz) throws IOException {
7888
if (!isSafeClass(clazz)) {
7989
throw new IOException("Unauthorized serialization attempt " + clazz.getName());
8090
}
8191
}
82-
8392
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*
9+
*/
10+
11+
package org.opensearch.security.support;
12+
13+
import java.io.IOException;
14+
import java.net.InetAddress;
15+
import java.net.InetSocketAddress;
16+
import java.util.ArrayList;
17+
import java.util.HashMap;
18+
import java.util.regex.Pattern;
19+
20+
import org.junit.Test;
21+
22+
import org.opensearch.security.auth.UserInjector;
23+
import org.opensearch.security.user.User;
24+
25+
import com.amazon.dlic.auth.ldap.LdapUser;
26+
import org.ldaptive.AbstractLdapBean;
27+
import org.ldaptive.LdapAttribute;
28+
import org.ldaptive.LdapEntry;
29+
import org.ldaptive.SearchEntry;
30+
31+
import static org.junit.Assert.assertEquals;
32+
import static org.junit.Assert.assertFalse;
33+
import static org.junit.Assert.assertTrue;
34+
import static org.junit.Assert.fail;
35+
36+
public class SafeSerializationUtilsTest {
37+
38+
@Test
39+
public void testSafeClasses() {
40+
assertTrue(SafeSerializationUtils.isSafeClass(String.class));
41+
assertTrue(SafeSerializationUtils.isSafeClass(InetSocketAddress.class));
42+
assertTrue(SafeSerializationUtils.isSafeClass(Pattern.class));
43+
assertTrue(SafeSerializationUtils.isSafeClass(User.class));
44+
assertTrue(SafeSerializationUtils.isSafeClass(UserInjector.InjectedUser.class));
45+
assertTrue(SafeSerializationUtils.isSafeClass(SourceFieldsContext.class));
46+
assertTrue(SafeSerializationUtils.isSafeClass(LdapUser.class));
47+
assertTrue(SafeSerializationUtils.isSafeClass(SearchEntry.class));
48+
assertTrue(SafeSerializationUtils.isSafeClass(LdapEntry.class));
49+
assertTrue(SafeSerializationUtils.isSafeClass(AbstractLdapBean.class));
50+
assertTrue(SafeSerializationUtils.isSafeClass(LdapAttribute.class));
51+
}
52+
53+
@Test
54+
public void testSafeAssignableClasses() {
55+
assertTrue(SafeSerializationUtils.isSafeClass(InetAddress.class));
56+
assertTrue(SafeSerializationUtils.isSafeClass(Integer.class));
57+
assertTrue(SafeSerializationUtils.isSafeClass(ArrayList.class));
58+
assertTrue(SafeSerializationUtils.isSafeClass(HashMap.class));
59+
assertTrue(SafeSerializationUtils.isSafeClass(Enum.class));
60+
}
61+
62+
@Test
63+
public void testArraysAreSafe() {
64+
assertTrue(SafeSerializationUtils.isSafeClass(String[].class));
65+
assertTrue(SafeSerializationUtils.isSafeClass(int[].class));
66+
assertTrue(SafeSerializationUtils.isSafeClass(Object[].class));
67+
}
68+
69+
@Test
70+
public void testUnsafeClasses() {
71+
assertFalse(SafeSerializationUtils.isSafeClass(SafeSerializationUtilsTest.class));
72+
assertFalse(SafeSerializationUtils.isSafeClass(Runtime.class));
73+
}
74+
75+
@Test
76+
public void testProhibitUnsafeClasses() {
77+
try {
78+
SafeSerializationUtils.prohibitUnsafeClasses(String.class);
79+
} catch (IOException e) {
80+
fail("Should not throw exception for safe class");
81+
}
82+
83+
try {
84+
SafeSerializationUtils.prohibitUnsafeClasses(SafeSerializationUtilsTest.class);
85+
fail("Should throw exception for unsafe class");
86+
} catch (IOException e) {
87+
assertEquals("Unauthorized serialization attempt " + SafeSerializationUtilsTest.class.getName(), e.getMessage());
88+
}
89+
}
90+
91+
@Test
92+
public void testInheritance() {
93+
class CustomArrayList extends ArrayList<String> {}
94+
assertTrue(SafeSerializationUtils.isSafeClass(CustomArrayList.class));
95+
96+
class CustomMap extends HashMap<String, Integer> {}
97+
assertTrue(SafeSerializationUtils.isSafeClass(CustomMap.class));
98+
}
99+
100+
@Test
101+
public void testCaching() {
102+
// First call should compute the result
103+
boolean result1 = SafeSerializationUtils.isSafeClass(String.class);
104+
assertTrue(result1);
105+
106+
// Second call should use cached result
107+
boolean result2 = SafeSerializationUtils.isSafeClass(String.class);
108+
assertTrue(result2);
109+
110+
// Verify that the cache was used (size should be 1)
111+
assertEquals(1, SafeSerializationUtils.safeClassCache.size());
112+
113+
// Third call for a different class
114+
boolean result3 = SafeSerializationUtils.isSafeClass(Integer.class);
115+
assertTrue(result3);
116+
// Verify that the cache was updated
117+
assertEquals(2, SafeSerializationUtils.safeClassCache.size());
118+
}
119+
}

0 commit comments

Comments
 (0)