Skip to content

Commit

Permalink
Check for null object IDs
Browse files Browse the repository at this point in the history
Problem: at various places in the rust codebase we receive an object ID
from the java side in the form of a `jobject` which is a pointer to a
`ObjectId`. This pointer can be null but we were not checking whether it
was null and consequently were panicking and crashing the entire JVM.

Solution: check for null when loading object IDs and throw an
IllegalArgumentException if the `ObjectId` is null. This required a
macro to avoid lots of code handling the early return.
  • Loading branch information
alexjg committed Jan 11, 2024
1 parent d92e28a commit 312d97d
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 39 deletions.
11 changes: 11 additions & 0 deletions lib/src/test/java/org/automerge/TestDocument.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,15 @@ public void testFree() {
Document doc = new Document();
doc.free();
}

@Test
public void testNullObjectIdThrows() {
// this test is actually a test for any path which uses an `ObjectId`
// as the code which throws the exception is in the rust implementation
// which converts the `ObjectId` into an internal rust type
Document doc = new Document();
Assertions.assertThrows(IllegalArgumentException.class, () -> {
doc.get(null, "key");
});
}
}
1 change: 0 additions & 1 deletion lib/src/test/java/org/automerge/TestSplice.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public final class TestSplice {
interface InsertedAssertions {
void assertInserted(Object elem1, Object elem2);
}

public TestSplice() {
super();
}
Expand Down
3 changes: 1 addition & 2 deletions rust/src/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ impl Cursor {
.map_err(errors::FromRaw::GetByteArray)?;
let bytes =
std::slice::from_raw_parts(arr.as_ptr() as *const u8, arr.size().unwrap() as usize);
let cursor: automerge::Cursor =
bytes.try_into().map_err(errors::FromRaw::Invalid)?;
let cursor: automerge::Cursor = bytes.try_into().map_err(errors::FromRaw::Invalid)?;
Ok(Self(cursor))
}
}
Expand Down
80 changes: 73 additions & 7 deletions rust/src/obj_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,17 @@ impl JavaObjId {
Ok(raw_obj)
}

pub(crate) unsafe fn from_raw(env: &JNIEnv<'_>, raw: jobject) -> Result<Self, errors::FromRaw> {
pub(crate) unsafe fn from_raw(
env: &JNIEnv<'_>,
raw: jobject,
) -> Result<Option<Self>, errors::FromRaw> {
let obj = JObject::from_raw(raw);
let id_is_null = env
.is_same_object(obj, JObject::null())
.map_err(errors::FromRaw::GetRaw)?;
if id_is_null {
return Ok(None);
}
let bytes_jobject = env
.get_field(obj, "raw", "[B")
.map_err(errors::FromRaw::GetRaw)?
Expand All @@ -62,10 +71,57 @@ impl JavaObjId {
let bytes =
std::slice::from_raw_parts(arr.as_ptr() as *const u8, arr.size().unwrap() as usize);
let obj: automerge::ObjId = bytes.try_into()?;
Ok(Self(obj))
//Ok(Some(Self(obj)))
Ok(Some(Self(obj)))
}
}

/// Get the `ObjId` from a `jobject` or throw an exception and return the given value.
///
/// This macro performs an early return if the `jobject` is null, which means the macro has two
/// forms. The first form, which looks like this:
///
/// ```rust,ignore
/// let obj = obj_id_or_throw!(env, some_obj_id);
/// ```
///
/// Takes a [`jni::JNIEnv`] and a `jobject` and returns a [`JavaObjId`] or throws an exception and
/// early returns a `jobject` from the surrounding function.
///
/// The second form, which looks like this:
///
/// ```rust,ignore
/// let obj = obj_id_or_throw!(env, some_obj_id, false); // the `false` here can be anything
/// ```
///
/// Takes a [`jni::JNIEnv`], a `jobject`, and a value to return from the surrounding function if
/// the `jobject` is null.
macro_rules! obj_id_or_throw {
($env:expr, $obj_id:expr) => {
obj_id_or_throw!($env, $obj_id, JObject::null().into_raw())
};
($env:expr, $obj_id:expr,$returning:expr) => {
match JavaObjId::from_raw($env, $obj_id) {
Ok(Some(id)) => id,
Ok(None) => {
$env.throw_new(
"java/lang/IllegalArgumentException",
"Object ID cannot be null",
)
.unwrap();
return $returning;
}
Err(e) => {
use crate::AUTOMERGE_EXCEPTION;
$env.throw_new(AUTOMERGE_EXCEPTION, format!("{}", e))
.unwrap();
return $returning;
}
}
};
}
pub(crate) use obj_id_or_throw;

#[no_mangle]
#[jni_fn]
pub unsafe extern "C" fn rootObjectId(
Expand All @@ -82,7 +138,7 @@ pub unsafe extern "C" fn isRootObjectId(
_class: jni::objects::JClass,
obj: jni::sys::jobject,
) -> bool {
let obj = JavaObjId::from_raw(&env, obj).unwrap();
let obj = obj_id_or_throw!(&env, obj, false);
obj.as_ref() == &automerge::ROOT
}

Expand All @@ -93,7 +149,7 @@ pub unsafe extern "C" fn objectIdToString(
_class: jni::objects::JClass,
obj: jni::sys::jobject,
) -> jobject {
let obj = JavaObjId::from_raw(&env, obj).unwrap();
let obj = obj_id_or_throw!(&env, obj);
let s = obj.as_ref().to_string();
let jstr = env.new_string(s).unwrap();
jstr.into_raw()
Expand All @@ -104,9 +160,9 @@ pub unsafe extern "C" fn objectIdToString(
pub unsafe extern "C" fn objectIdHash(
env: jni::JNIEnv,
_class: jni::objects::JClass,
left: jni::sys::jobject,
obj: jni::sys::jobject,
) -> jint {
let obj = JavaObjId::from_raw(&env, left).unwrap();
let obj = obj_id_or_throw!(&env, obj, 0);
let mut hasher = DefaultHasher::new();
obj.as_ref().hash(&mut hasher);
hasher.finish() as i32
Expand All @@ -122,7 +178,17 @@ pub unsafe extern "C" fn objectIdsEqual(
) -> jboolean {
let left = JavaObjId::from_raw(&env, left).unwrap();
let right = JavaObjId::from_raw(&env, right).unwrap();
(left.as_ref() == right.as_ref()).into()
match (left, right) {
(None, _) | (_, None) => {
env.throw_new(
"java/lang/IllegalArgumentException",
"Object ID cannot be null",
)
.unwrap();
return false.into();
}
(Some(left), Some(right)) => (left.as_ref() == right.as_ref()).into(),
}
}

pub mod errors {
Expand Down
26 changes: 13 additions & 13 deletions rust/src/read_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::cursor::Cursor;
use crate::interop::{changehash_to_jobject, heads_from_jobject, CHANGEHASH_CLASS};
use crate::java_option::{make_empty_option, make_optional};
use crate::mark::mark_to_java;
use crate::obj_id::JavaObjId;
use crate::obj_id::{obj_id_or_throw, JavaObjId};
use crate::prop::JProp;
use crate::AUTOMERGE_EXCEPTION;
use crate::{interop::AsPointerObj, read_ops::ReadOps};
Expand Down Expand Up @@ -63,7 +63,7 @@ impl SomeReadPointer {
key: P,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);

let key = catch!(env, key.into().try_into_prop(env));
let result = catch!(env, read.get(obj, key));
Expand All @@ -79,7 +79,7 @@ impl SomeReadPointer {
heads_pointer: jobject,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = heads_from_jobject(&env, heads_pointer).unwrap();

let key = catch!(env, key.into().try_into_prop(env));
Expand All @@ -96,7 +96,7 @@ impl SomeReadPointer {
heads: Option<jobject>,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);

let key = catch!(env, key.into().try_into_prop(env));
let heads = heads.map(|h| heads_from_jobject(&env, h).unwrap());
Expand Down Expand Up @@ -155,7 +155,7 @@ impl SomeReadPointer {
heads: Option<jobject>,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = heads.map(|h| heads_from_jobject(&env, h).unwrap());
let keys = match read.object_type(&obj) {
Ok(automerge::ObjType::Map) => match heads {
Expand Down Expand Up @@ -187,7 +187,7 @@ impl SomeReadPointer {
heads: Option<jobject>,
) -> jlong {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer, 0);
match heads {
Some(h) => {
let heads = heads_from_jobject(&env, h).unwrap();
Expand All @@ -204,7 +204,7 @@ impl SomeReadPointer {
heads: Option<jobject>,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = heads.map(|h| heads_from_jobject(&env, h).unwrap());
let items = match read.object_type(&obj) {
Ok(am::ObjType::List) => match heads {
Expand Down Expand Up @@ -244,7 +244,7 @@ impl SomeReadPointer {
heads: Option<jobject>,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = heads.map(|h| heads_from_jobject(&env, h).unwrap());

let entries = match read.object_type(&obj) {
Expand Down Expand Up @@ -300,7 +300,7 @@ impl SomeReadPointer {
heads: Option<jobject>,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = heads.map(|h| heads_from_jobject(&env, h).unwrap());
let text = match read.object_type(&obj) {
Ok(am::ObjType::Text) => match heads {
Expand Down Expand Up @@ -328,7 +328,7 @@ impl SomeReadPointer {
heads_option: jobject,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = maybe_heads(env, heads_option).unwrap();
let marks = if let Some(h) = heads {
read.marks_at(obj, &h)
Expand Down Expand Up @@ -359,7 +359,7 @@ impl SomeReadPointer {
heads_option: jobject,
) -> jobject {
let read = SomeRead::from_pointer(env, self);
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = maybe_heads(env, heads_option).unwrap();
let marks = if let Some(h) = heads {
read.get_marks(obj, index as usize, Some(&h))
Expand Down Expand Up @@ -395,7 +395,7 @@ impl SomeReadPointer {
index: jlong,
maybe_heads_pointer: jobject,
) -> jobject {
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer);
let heads = maybe_heads(env, maybe_heads_pointer).unwrap();
let read = SomeRead::from_pointer(env, self);
if index < 0 {
Expand All @@ -421,7 +421,7 @@ impl SomeReadPointer {
cursor_pointer: jobject,
maybe_heads_pointer: jobject,
) -> jlong {
let obj = JavaObjId::from_raw(&env, obj_pointer).unwrap();
let obj = obj_id_or_throw!(&env, obj_pointer, 0);
let heads = maybe_heads(env, maybe_heads_pointer).unwrap();
let read = SomeRead::from_pointer(env, self);

Expand Down
7 changes: 5 additions & 2 deletions rust/src/transaction/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use jni::{
sys::{jlong, jobject},
};

use crate::{obj_id::JavaObjId, AUTOMERGE_EXCEPTION};
use crate::{
obj_id::{obj_id_or_throw, JavaObjId},
AUTOMERGE_EXCEPTION,
};

use super::{do_tx_op, TransactionOp};

Expand All @@ -18,7 +21,7 @@ impl TransactionOp for DeleteOp {
type Output = ();

unsafe fn execute<T: Transactable>(self, env: jni::JNIEnv, tx: &mut T) -> Self::Output {
let obj = JavaObjId::from_raw(&env, self.obj).unwrap();
let obj = obj_id_or_throw!(&env, self.obj, ());
match tx.delete(obj, self.key) {
Ok(_) => {}
Err(e) => {
Expand Down
7 changes: 5 additions & 2 deletions rust/src/transaction/increment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use jni::{
sys::{jlong, jobject},
};

use crate::{obj_id::JavaObjId, AUTOMERGE_EXCEPTION};
use crate::{
obj_id::{obj_id_or_throw, JavaObjId},
AUTOMERGE_EXCEPTION,
};

use super::{do_tx_op, TransactionOp};

Expand All @@ -22,7 +25,7 @@ impl TransactionOp for IncrementOp {
env: jni::JNIEnv,
tx: &mut T,
) -> Self::Output {
let obj = JavaObjId::from_raw(&env, self.obj).unwrap();
let obj = obj_id_or_throw!(&env, self.obj, ());
match tx.increment(obj, self.key, self.value) {
Ok(_) => {}
Err(e) => {
Expand Down
10 changes: 7 additions & 3 deletions rust/src/transaction/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ use jni::{
sys::{jboolean, jbyteArray, jdouble, jlong, jobject},
};

use crate::{obj_id::JavaObjId, obj_type::JavaObjType, AUTOMERGE_EXCEPTION};
use crate::{
obj_id::{obj_id_or_throw, JavaObjId},
obj_type::JavaObjType,
AUTOMERGE_EXCEPTION,
};

use super::{do_tx_op, TransactionOp};

Expand All @@ -23,7 +27,7 @@ impl TransactionOp for InsertOp<am::ScalarValue> {
env: jni::JNIEnv,
tx: &mut T,
) -> Self::Output {
let obj = JavaObjId::from_raw(&env, self.obj).unwrap();
let obj = obj_id_or_throw!(&env, self.obj, ());
let idx = match usize::try_from(self.index) {
Ok(i) => i,
Err(_) => {
Expand All @@ -49,7 +53,7 @@ impl TransactionOp for InsertOp<ObjType> {
env: jni::JNIEnv,
tx: &mut T,
) -> Self::Output {
let obj = JavaObjId::from_raw(&env, self.obj).unwrap();
let obj = obj_id_or_throw!(&env, self.obj);
let idx = match usize::try_from(self.index) {
Ok(i) => i,
Err(_) => {
Expand Down
10 changes: 7 additions & 3 deletions rust/src/transaction/mark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ use jni::{
sys::{jboolean, jdouble, jlong, jobject, jstring},
};

use crate::{expand_mark, obj_id::JavaObjId, AUTOMERGE_EXCEPTION};
use crate::{
expand_mark,
obj_id::{obj_id_or_throw, JavaObjId},
AUTOMERGE_EXCEPTION,
};

use super::{do_tx_op, TransactionOp};

Expand All @@ -28,7 +32,7 @@ impl TransactionOp for MarkOp {
let name_str = JString::from_raw(self.name);
let name: String = env.get_string(name_str).unwrap().into();
let mark = am::marks::Mark::new(name, self.value, self.start, self.end);
let obj = JavaObjId::from_raw(&env, self.obj).unwrap();
let obj = obj_id_or_throw!(&env, self.obj, ());
match tx.mark(obj, mark, expand) {
Ok(_) => {}
Err(e) => {
Expand Down Expand Up @@ -315,7 +319,7 @@ impl TransactionOp for Unmark {
let expand = expand_mark::from_java(&env, expand_obj).unwrap();
let name_str = JString::from_raw(self.name);
let name: String = env.get_string(name_str).unwrap().into();
let obj = JavaObjId::from_raw(&env, self.obj).unwrap();
let obj = obj_id_or_throw!(&env, self.obj, ());
match tx.unmark(obj, &name, self.start, self.end, expand) {
Ok(_) => {}
Err(e) => {
Expand Down
Loading

0 comments on commit 312d97d

Please sign in to comment.