Skip to content

Commit

Permalink
[CALCITE-6623] The MongoDB adapter throws a java.lang.ClassCastException when Decimal128 or Binary types are used, or when a primitive value is cast to a string
Browse files Browse the repository at this point in the history
  • Loading branch information
dssysolyatin authored and dssysolyatin committed Oct 11, 2024
1 parent 7ce986f commit 9add280
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,28 @@
*/
package org.apache.calcite.adapter.mongodb;

import org.apache.calcite.avatica.util.ByteString;
import org.apache.calcite.avatica.util.DateTimeUtils;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.tree.Primitive;

import com.mongodb.client.MongoCursor;

import org.bson.BsonTimestamp;
import org.bson.Document;
import org.bson.types.Binary;
import org.bson.types.Decimal128;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.BigDecimal;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static java.lang.String.format;

/** Enumerator that reads from a MongoDB collection. */
class MongoEnumerator implements Enumerator<Object> {
private final Iterator<Document> cursor;
Expand Down Expand Up @@ -89,7 +96,7 @@ static Function1<Document, Map> mapGetter() {
/** Returns a function that projects a single field. */
static Function1<Document, Object> singletonGetter(final String fieldName,
final Class fieldClass) {
return a0 -> convert(a0.get(fieldName), fieldClass);
return a0 -> convert(fieldName, a0.get(fieldName), fieldClass);
}

/** Returns a function that projects fields.
Expand All @@ -103,7 +110,7 @@ static Function1<Document, Object[]> listGetter(
for (int i = 0; i < fields.size(); i++) {
final Map.Entry<String, Class> field = fields.get(i);
final String name = field.getKey();
objects[i] = convert(a0.get(name), field.getValue());
objects[i] = convert(name, a0.get(name), field.getValue());
}
return objects;
};
Expand All @@ -119,8 +126,34 @@ static Function1<Document, Object> getter(
: (Function1) listGetter(fields);
}

/**
* Converts the given object to a specific runtime type based on the provided class.
*
* @param fieldName The name of the field being processed, used for error reporting if
* conversion fails.
* @param o The object to be converted. If {@code null}, the method returns {@code null}
* immediately.
* @param clazz The target class to which the object {@code o} should be converted.
* @return The converted object as an instance of the specified {@code clazz}, or {@code null}
* if {@code o} is {@code null}.
*
* @throws IllegalArgumentException if the object {@code o} cannot be converted to the desired
* {@code clazz} type, including a message indicating the field name, expected data type, and
* the invalid value.
*
* <h3>Conversion Details</h3>
*
* <p>If the target type is one of the following, the method performs specific conversions:
* <ul>
* <li>{@code Long}: Converts a {@code Date} or {@code BsonTimestamp} object into the respective
* epoch time (milliseconds).
* <li>{@code BigDecimal}: Converts a {@code Decimal128} object into a {@code BigDecimal}
* instance.
* <li>{@code String}: Converts arrays to their string form and uses {@code String.valueOf(o)}
* for other objects.
* <li>{@code ByteString}: Converts a {@code Binary} object into a {@code ByteString} instance.
* </ul>
*
*/
@SuppressWarnings("JavaUtilDate")
private static Object convert(Object o, Class clazz) {
private static Object convert(String fieldName, Object o, Class clazz) {
if (o == null) {
return null;
}
Expand All @@ -133,14 +166,41 @@ private static Object convert(Object o, Class clazz) {
if (clazz.isInstance(o)) {
return o;
}
if (o instanceof Date && clazz == Long.class) {
o = ((Date) o).getTime();
} else if (o instanceof Date && primitive != null) {
o = ((Date) o).getTime() / DateTimeUtils.MILLIS_PER_DAY;

if (clazz == Long.class) {
if (o instanceof Date) {
return ((Date) o).getTime();
} else if (o instanceof BsonTimestamp) {
return ((BsonTimestamp) o).getTime() * DateTimeUtils.MILLIS_PER_SECOND;
}
} else if (clazz == BigDecimal.class) {
if (o instanceof Decimal128) {
return new BigDecimal(((Decimal128) o).toString());
}
} else if (clazz == String.class) {
if (o.getClass().isArray()) {
return Primitive.OTHER.arrayToString(o);
} else {
return String.valueOf(o);
}
} else if (clazz == ByteString.class) {
if (o instanceof Binary) {
return new ByteString(((Binary) o).getData());
}
}
if (o instanceof Number && primitive != null) {
return primitive.number((Number) o);

if (primitive != null) {
if (o instanceof String) {
return primitive.parse((String) o);
} else if (o instanceof Number) {
return primitive.number((Number) o);
} else if (o instanceof Date) {
return primitive.number(((Date) o).getTime() / DateTimeUtils.MILLIS_PER_DAY);
}
}
return o;

throw new IllegalArgumentException(
format("Invalid field: '%s'. The dataType '%s' is invalid for '%s'.", fieldName,
clazz.getSimpleName(), o));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import net.hydromatic.foodmart.data.json.FoodmartJson;

import org.bson.BsonArray;
import org.bson.BsonBinary;
import org.bson.BsonDateTime;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
Expand Down Expand Up @@ -105,7 +106,7 @@ public static void setUp() throws Exception {
"url"));

// Manually insert data for date-time test.
MongoCollection<BsonDocument> datatypes = database.getCollection("datatypes")
MongoCollection<BsonDocument> datatypes = database.getCollection("datatypes")
.withDocumentClass(BsonDocument.class);
if (datatypes.countDocuments() > 0) {
datatypes.deleteMany(new BsonDocument());
Expand All @@ -117,6 +118,7 @@ public static void setUp() throws Exception {
doc.put("value", new BsonInt32(1231));
doc.put("ownerId", new BsonString("531e7789e4b0853ddb861313"));
doc.put("arr", new BsonArray(Arrays.asList(new BsonString("a"), new BsonString("b"))));
doc.put("binaryData", new BsonBinary("binaryData".getBytes(StandardCharsets.UTF_8)));
datatypes.insertOne(doc);

schema = new MongoSchema(database);
Expand Down Expand Up @@ -697,68 +699,26 @@ private void checkPredicate(int expected, String q) {
* <a href="https://issues.apache.org/jira/browse/CALCITE-286">[CALCITE-286]
* Error casting MongoDB date</a>. */
@Test void testDate() {
assertModel("{\n"
+ " version: '1.0',\n"
+ " defaultSchema: 'test',\n"
+ " schemas: [\n"
+ " {\n"
+ " type: 'custom',\n"
+ " name: 'test',\n"
+ " factory: 'org.apache.calcite.adapter.mongodb.MongoSchemaFactory',\n"
+ " operand: {\n"
+ " host: 'localhost',\n"
+ " database: 'test'\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}")
.query("select cast(_MAP['date'] as DATE) from \"datatypes\"")
assertModel(MODEL)
.query("select cast(_MAP['date'] as DATE) from \"mongo_raw\".\"datatypes\"")
.returnsUnordered("EXPR$0=2012-09-05");
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-5405">[CALCITE-5405]
* Error casting MongoDB dates to TIMESTAMP</a>. */
@Test void testDateConversion() {
assertModel("{\n"
+ " version: '1.0',\n"
+ " defaultSchema: 'test',\n"
+ " schemas: [\n"
+ " {\n"
+ " type: 'custom',\n"
+ " name: 'test',\n"
+ " factory: 'org.apache.calcite.adapter.mongodb.MongoSchemaFactory',\n"
+ " operand: {\n"
+ " host: 'localhost',\n"
+ " database: 'test'\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}")
.query("select cast(_MAP['date'] as TIMESTAMP) from \"datatypes\"")
assertModel(MODEL)
.query("select cast(_MAP['date'] as TIMESTAMP) from \"mongo_raw\".\"datatypes\"")
.returnsUnordered("EXPR$0=2012-09-05 00:00:00");
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-5407">[CALCITE-5407]
* Error casting MongoDB array to VARCHAR ARRAY</a>. */
@Test void testArrayConversion() {
assertModel("{\n"
+ " version: '1.0',\n"
+ " defaultSchema: 'test',\n"
+ " schemas: [\n"
+ " {\n"
+ " type: 'custom',\n"
+ " name: 'test',\n"
+ " factory: 'org.apache.calcite.adapter.mongodb.MongoSchemaFactory',\n"
+ " operand: {\n"
+ " host: 'localhost',\n"
+ " database: 'test'\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}")
.query("select cast(_MAP['arr'] as VARCHAR ARRAY) from \"datatypes\"")
assertModel(MODEL)
.query("select cast(_MAP['arr'] as VARCHAR ARRAY) from \"mongo_raw\".\"datatypes\"")
.returnsUnordered("EXPR$0=[a, b]");
}

Expand All @@ -778,6 +738,50 @@ private void checkPredicate(int expected, String q) {
});
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6623">[CALCITE-6623]
 * MongoDB adapter throws a java.lang.ClassCastException when Decimal128 or Binary types are
 * used, or when a primitive value is cast to a string</a>.
 *
 * <p>Each query casts a {@code _MAP} field to a different SQL type and checks the
 * returned value; the final query expects the cast to be rejected with an
 * "Invalid field:" error instead of a {@code ClassCastException}. */
@Test void testRuntimeTypes() {
  // An array-valued field cast to VARCHAR comes back in its string form.
  assertModel(MODEL)
      .query("select cast(_MAP['loc'] AS varchar) "
          + "from \"mongo_raw\".\"zips\" where _MAP['_id']='99801'")
      .returnsCount(1)
      .returnsValue("[-134.529429, 58.362767]");

  // The same field cast to BIGINT must yield a BIGINT column.
  assertModel(MODEL)
      .query("select cast(_MAP['warehouse_postal_code'] AS bigint) AS postal_code_as_bigint"
          + " from \"mongo_raw\".\"warehouse\" where _MAP['warehouse_id']=1")
      .returnsCount(1)
      .returnsValue("55555")
      .typeIs("[POSTAL_CODE_AS_BIGINT BIGINT]");

  // ... and cast to VARCHAR must yield a VARCHAR column with the same digits.
  assertModel(MODEL)
      .query("select cast(_MAP['warehouse_postal_code'] AS varchar) AS postal_code_as_varchar"
          + " from \"mongo_raw\".\"warehouse\" where _MAP['warehouse_id']=1")
      .returnsCount(1)
      .returnsValue("55555")
      .typeIs("[POSTAL_CODE_AS_VARCHAR VARCHAR]");

  // Binary data must round-trip: the bytes read back decode to the original text.
  assertModel(MODEL)
      .query("select cast(_MAP['binaryData'] AS binary) from \"mongo_raw\".\"datatypes\"")
      .returnsCount(1)
      .returns(resultSet -> {
        try {
          // Check that a row is actually available before reading from it, so an
          // empty result fails with a clear assertion rather than a driver error.
          assertThat(resultSet.next(), is(true));
          //CHECKSTYLE: IGNORE 1
          assertThat(new String(resultSet.getBytes(1), StandardCharsets.UTF_8), is("binaryData"));
        } catch (Throwable e) {
          throw new RuntimeException(e);
        }
      });

  // An array-valued field cannot be cast to BIGINT; expect the descriptive error.
  assertModel(MODEL)
      .query("select cast(_MAP['loc'] AS bigint) "
          + "from \"mongo_raw\".\"zips\" where _MAP['_id']='99801'")
      .throws_("Invalid field:");
}

/**
* Returns a function that checks that a particular MongoDB query
* has been called.
Expand Down

0 comments on commit 9add280

Please sign in to comment.