Skip to content

Commit

Permalink
Throw an error if the top-level object is not a list.
Browse files Browse the repository at this point in the history
This makes sense as uzuki2 is, after all, designed to load lists, not assorted
vectors or factors. Nonetheless, it can be disabled with a new 'strict_list'
option if the caller doesn't mind dealing with non-list outputs.

This option requires a new class for the HDF5 parsers. While we're making
such a class, we took the opportunity to allow users to modify the HDF5 buffer
size, rather than hard-coding it into the various functions.
  • Loading branch information
LTLA committed Nov 18, 2023
1 parent fe5e6d7 commit ad65ef8
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 46 deletions.
94 changes: 60 additions & 34 deletions include/uzuki2/parse_hdf5.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ inline H5::DataSet get_scalar_dataset(const H5::Group& handle, const std::string
}

template<class Host, class Function>
void parse_integer_like(const H5::DataSet& handle, Host* ptr, Function check, const Version& version) try {
void parse_integer_like(const H5::DataSet& handle, Host* ptr, Function check, const Version& version, hsize_t buffer_size) try {
if (ritsuko::hdf5::exceeds_integer_limit(handle, 32, true)) {
throw std::runtime_error("dataset cannot be represented by 32-bit signed integers");
}
Expand All @@ -71,7 +71,7 @@ void parse_integer_like(const H5::DataSet& handle, Host* ptr, Function check, co
}

hsize_t full_length = ptr->size();
auto block_size = ritsuko::hdf5::pick_1d_block_size(handle.getCreatePlist(), full_length, /* buffer_size = */ 10000);
auto block_size = ritsuko::hdf5::pick_1d_block_size(handle.getCreatePlist(), full_length, buffer_size);
std::vector<int32_t> buffer(block_size);
ritsuko::hdf5::iterate_1d_blocks(
full_length,
Expand All @@ -94,7 +94,7 @@ void parse_integer_like(const H5::DataSet& handle, Host* ptr, Function check, co
}

template<class Host, class Function>
void parse_string_like(const H5::DataSet& handle, Host* ptr, Function check) try {
void parse_string_like(const H5::DataSet& handle, Host* ptr, Function check, hsize_t buffer_size) try {
auto dtype = handle.getDataType();
if (dtype.getClass() != H5T_STRING) {
throw std::runtime_error("expected a string dataset");
Expand All @@ -111,7 +111,7 @@ void parse_string_like(const H5::DataSet& handle, Host* ptr, Function check) try
ritsuko::hdf5::load_1d_string_dataset(
handle,
ptr->size(),
/* buffer_size = */ 10000,
buffer_size,
[&](size_t i, const char* str, size_t len) -> void {
std::string x(str, str + len);
if (has_missing && x == missing_val) {
Expand All @@ -127,7 +127,7 @@ void parse_string_like(const H5::DataSet& handle, Host* ptr, Function check) try
}

template<class Host, class Function>
void parse_numbers(const H5::DataSet& handle, Host* ptr, Function check, const Version& version) try {
void parse_numbers(const H5::DataSet& handle, Host* ptr, Function check, const Version& version, hsize_t buffer_size) try {
if (version.lt(1, 3)) {
if (handle.getTypeClass() != H5T_FLOAT) {
throw std::runtime_error("expected a floating-point dataset");
Expand Down Expand Up @@ -166,7 +166,7 @@ void parse_numbers(const H5::DataSet& handle, Host* ptr, Function check, const V
};

hsize_t full_length = ptr->size();
auto block_size = ritsuko::hdf5::pick_1d_block_size(handle.getCreatePlist(), full_length, /* buffer_size = */ 10000);
auto block_size = ritsuko::hdf5::pick_1d_block_size(handle.getCreatePlist(), full_length, buffer_size);
std::vector<double> buffer(block_size);
ritsuko::hdf5::iterate_1d_blocks(
full_length,
Expand All @@ -189,7 +189,7 @@ void parse_numbers(const H5::DataSet& handle, Host* ptr, Function check, const V
}

template<class Host>
void extract_names(const H5::Group& handle, Host* ptr) try {
void extract_names(const H5::Group& handle, Host* ptr, hsize_t buffer_size) try {
if (handle.childObjType("names") != H5O_TYPE_DATASET) {
throw std::runtime_error("expected a dataset");
}
Expand All @@ -208,8 +208,8 @@ void extract_names(const H5::Group& handle, Host* ptr) try {

ritsuko::hdf5::load_1d_string_dataset(
nhandle,
nlen,
/* buffer_size = */ 10000,
nlen,
buffer_size,
[&](size_t i, const char* val, size_t len) -> void {
ptr->set_name(i, std::string(val, val + len));
}
Expand All @@ -219,7 +219,7 @@ void extract_names(const H5::Group& handle, Host* ptr) try {
}

template<class Provisioner, class Externals>
std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const Version& version) try {
std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const Version& version, hsize_t buffer_size) try {
// Deciding what type we're dealing with.
auto object_type = ritsuko::hdf5::load_scalar_string_attribute(handle, "uzuki_object");
std::shared_ptr<Base> output;
Expand All @@ -241,11 +241,11 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
throw std::runtime_error("expected a group at 'data/" + istr + "'");
}
auto lhandle = dhandle.openGroup(istr);
lptr->set(i, parse_inner<Provisioner>(lhandle, ext, version));
lptr->set(i, parse_inner<Provisioner>(lhandle, ext, version, buffer_size));
}

if (named) {
extract_names(handle, lptr);
extract_names(handle, lptr, buffer_size);
}

} else if (object_type == "vector") {
Expand All @@ -263,7 +263,7 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
if (vector_type == "integer") {
auto iptr = Provisioner::new_Integer(len, named, is_scalar);
output.reset(iptr);
parse_integer_like(dhandle, iptr, [](int32_t) -> void {}, version);
parse_integer_like(dhandle, iptr, [](int32_t) -> void {}, version, buffer_size);

} else if (vector_type == "boolean") {
auto bptr = Provisioner::new_Boolean(len, named, is_scalar);
Expand All @@ -272,7 +272,7 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
if (x != 0 && x != 1) {
throw std::runtime_error("boolean values should be 0 or 1");
}
}, version);
}, version, buffer_size);

} else if (vector_type == "factor" || (version.equals(1, 0) && vector_type == "ordered")) {
auto levhandle = ritsuko::hdf5::get_dataset(handle, "levels");
Expand All @@ -298,13 +298,13 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
if (x < 0 || x >= levlen) {
throw std::runtime_error("factor codes should be non-negative and less than the number of levels");
}
}, version);
}, version, buffer_size);

std::unordered_set<std::string> present;
ritsuko::hdf5::load_1d_string_dataset(
levhandle,
levlen,
/* buffer_size = */ 10000,
buffer_size,
[&](size_t i, const char* val, size_t len) -> void {
std::string x(val, val + len);
if (present.find(x) != present.end()) {
Expand All @@ -328,7 +328,7 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
ritsuko::hdf5::load_1d_string_dataset(
fhandle,
1,
/* buffer_size = */ 10000,
buffer_size,
[&](size_t, const char* val, size_t len) -> void {
std::string x(val, val + len);
if (x == "date") {
Expand All @@ -345,35 +345,35 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
auto sptr = Provisioner::new_String(len, named, is_scalar, format);
output.reset(sptr);
if (format == StringVector::NONE) {
parse_string_like(dhandle, sptr, [](const std::string&) -> void {});
parse_string_like(dhandle, sptr, [](const std::string&) -> void {}, buffer_size);

} else if (format == StringVector::DATE) {
parse_string_like(dhandle, sptr, [&](const std::string& x) -> void {
if (!ritsuko::is_date(x.c_str(), x.size())) {
throw std::runtime_error("dates should follow YYYY-MM-DD formatting");
}
});
}, buffer_size);

} else if (format == StringVector::DATETIME) {
parse_string_like(dhandle, sptr, [&](const std::string& x) -> void {
if (!ritsuko::is_rfc3339(x.c_str(), x.size())) {
throw std::runtime_error("date-times should follow the Internet Date/Time format");
}
});
}, buffer_size);
}

} else if (vector_type == "number") {
auto dptr = Provisioner::new_Number(len, named, is_scalar);
output.reset(dptr);
parse_numbers(dhandle, dptr, [](double) -> void {}, version);
parse_numbers(dhandle, dptr, [](double) -> void {}, version, buffer_size);

} else {
throw std::runtime_error("unknown vector type '" + vector_type + "'");
}

if (named) {
auto vptr = static_cast<Vector*>(output.get());
extract_names(handle, vptr);
extract_names(handle, vptr, buffer_size);
}

} else if (object_type == "nothing") {
Expand Down Expand Up @@ -411,12 +411,28 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
* @endcond
*/

/**
 * @brief Optional settings that control how an HDF5 group is parsed.
 */
struct Options {
    /**
     * Number of elements to hold in the buffer when reading data from an HDF5 dataset.
     */
    hsize_t buffer_size{10000};

    /**
     * If true, an error is thrown when the top-level R object is not an R list.
     */
    bool strict_list{true};
};

/**
* @tparam Provisioner A class namespace defining static methods for creating new `Base` objects.
* @tparam Externals Class describing how to resolve external references for type `EXTERNAL`.
*
* @param handle Handle for a HDF5 group corresponding to the list.
* @param ext Instance of an external reference resolver class.
* @param options Optional parameters.
*
* @return A `ParsedList` containing a pointer to the root `Base` object.
* Depending on `Provisioner`, this may contain references to all nested objects.
Expand Down Expand Up @@ -456,7 +472,7 @@ std::shared_ptr<Base> parse_inner(const H5::Group& handle, Externals& ext, const
* - `size_t size()`, which returns the number of available external references.
*/
template<class Provisioner, class Externals>
ParsedList parse(const H5::Group& handle, Externals ext) {
ParsedList parse(const H5::Group& handle, Externals ext, Options options = Options()) {
Version version;
if (handle.attrExists("uzuki_version")) {
auto ver_str = ritsuko::hdf5::load_scalar_string_attribute(handle, "uzuki_version");
Expand All @@ -466,8 +482,13 @@ ParsedList parse(const H5::Group& handle, Externals ext) {
}

ExternalTracker etrack(std::move(ext));
auto ptr = parse_inner<Provisioner>(handle, etrack, version);
auto ptr = parse_inner<Provisioner>(handle, etrack, version, options.buffer_size);

if (options.strict_list && ptr->type() != LIST) {
throw std::runtime_error("top-level object should represent an R list");
}
etrack.validate();

return ParsedList(std::move(ptr), std::move(version));
}

Expand All @@ -478,15 +499,16 @@ ParsedList parse(const H5::Group& handle, Externals ext) {
* @tparam Provisioner A class namespace defining static methods for creating new `Base` objects.
*
* @param handle Handle for a HDF5 group corresponding to the list.
* @param options Optional parameters.
*
* @return A `ParsedList` containing a pointer to the root `Base` object.
* Depending on `Provisioner`, this may contain references to all nested objects.
*
* Any invalid representations in `handle` will cause an error to be thrown.
*/
template<class Provisioner>
ParsedList parse(const H5::Group& handle) {
return parse<Provisioner>(handle, uzuki2::DummyExternals(0));
ParsedList parse(const H5::Group& handle, Options options = Options()) {
return parse<Provisioner>(handle, uzuki2::DummyExternals(0), std::move(options));
}

/**
Expand All @@ -498,16 +520,17 @@ ParsedList parse(const H5::Group& handle) {
* @param file Path to a HDF5 file.
* @param name Name of the HDF5 group containing the list in `file`.
* @param ext Instance of an external reference resolver class.
* @param options Optional parameters.
*
* @return A `ParsedList` containing a pointer to the root `Base` object.
* Depending on `Provisioner`, this may contain references to all nested objects.
*
* Any invalid representations in the HDF5 group `name` of `file` will cause an error to be thrown.
*/
template<class Provisioner, class Externals>
ParsedList parse(const std::string& file, const std::string& name, Externals ext) {
ParsedList parse(const std::string& file, const std::string& name, Externals ext, Options options = Options()) {
H5::H5File handle(file, H5F_ACC_RDONLY);
return parse<Provisioner>(handle.openGroup(name), std::move(ext));
return parse<Provisioner>(handle.openGroup(name), std::move(ext), std::move(options));
}

/**
Expand All @@ -518,16 +541,17 @@ ParsedList parse(const std::string& file, const std::string& name, Externals ext
*
* @param file Path to a HDF5 file.
* @param name Name of the HDF5 group containing the list in `file`.
* @param options Optional parameters.
*
* @return A `ParsedList` containing a pointer to the root `Base` object.
* Depending on `Provisioner`, this may contain references to all nested objects.
*
* Any invalid representations in the HDF5 group `name` of `file` will cause an error to be thrown.
*/
template<class Provisioner>
ParsedList parse(const std::string& file, const std::string& name) {
ParsedList parse(const std::string& file, const std::string& name, Options options = Options()) {
H5::H5File handle(file, H5F_ACC_RDONLY);
return parse<Provisioner>(handle.openGroup(name), uzuki2::DummyExternals(0));
return parse<Provisioner>(handle.openGroup(name), uzuki2::DummyExternals(0), std::move(options));
}

/**
Expand All @@ -538,10 +562,11 @@ ParsedList parse(const std::string& file, const std::string& name) {
* @param name Name of the HDF5 group corresponding to `handle`.
* Only used for error messages.
* @param num_external Expected number of external references.
* @param options Optional parameters.
*/
inline void validate(const H5::Group& handle, int num_external = 0) {
inline void validate(const H5::Group& handle, int num_external = 0, Options options = Options()) {
DummyExternals ext(num_external);
parse<DummyProvisioner>(handle, ext);
parse<DummyProvisioner>(handle, ext, std::move(options));
return;
}

Expand All @@ -552,10 +577,11 @@ inline void validate(const H5::Group& handle, int num_external = 0) {
* @param file Path to a HDF5 file.
* @param name Name of the HDF5 group containing the list in `file`.
* @param num_external Expected number of external references.
* @param options Optional parameters.
*/
inline void validate(const std::string& file, const std::string& name, int num_external = 0) {
inline void validate(const std::string& file, const std::string& name, int num_external = 0, Options options = Options()) {
DummyExternals ext(num_external);
parse<DummyProvisioner>(file, name, ext);
parse<DummyProvisioner>(file, name, ext, std::move(options));
return;
}

Expand Down
10 changes: 10 additions & 0 deletions include/uzuki2/parse_json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ struct Options {
* If true, an extra thread is used to avoid blocking I/O operations.
*/
bool parallel = false;

/**
* Whether to throw an error if the top-level R object is not an R list.
*/
bool strict_list = true;
};

/**
Expand Down Expand Up @@ -443,7 +448,12 @@ ParsedList parse(byteme::Reader& reader, Externals ext, Options options = Option

ExternalTracker etrack(std::move(ext));
auto output = parse_object<Provisioner>(contents.get(), etrack, "", version);

if (options.strict_list && output->type() != LIST) {
throw std::runtime_error("top-level object should represent an R list");
}
etrack.validate();

return ParsedList(std::move(output), std::move(version));
}

Expand Down
14 changes: 10 additions & 4 deletions tests/src/external.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

TEST(Hdf5ExternalTest, SimpleLoading) {
auto path = "TEST-external.h5";
uzuki2::hdf5::Options opt;
opt.strict_list = false;

// Simple stuff works correctly.
{
Expand All @@ -17,7 +19,7 @@ TEST(Hdf5ExternalTest, SimpleLoading) {
}
{
DefaultExternals ext(1);
auto parsed = uzuki2::hdf5::parse<DefaultProvisioner>(path, "foo", ext);
auto parsed = uzuki2::hdf5::parse<DefaultProvisioner>(path, "foo", ext, opt);
EXPECT_EQ(parsed->type(), uzuki2::EXTERNAL);

auto stuff = static_cast<const DefaultExternal*>(parsed.get());
Expand All @@ -36,7 +38,7 @@ TEST(Hdf5ExternalTest, SimpleLoading) {
}
{
DefaultExternals ext(2);
auto parsed = uzuki2::hdf5::parse<DefaultProvisioner>(path, "foo", ext);
auto parsed = uzuki2::hdf5::parse<DefaultProvisioner>(path, "foo", ext, opt);
EXPECT_EQ(parsed->type(), uzuki2::LIST);
auto list = static_cast<const DefaultList*>(parsed.get());

Expand All @@ -50,9 +52,11 @@ TEST(Hdf5ExternalTest, SimpleLoading) {

void expect_hdf5_external_error(std::string path, std::string name, std::string msg, int num_expected) {
H5::H5File file(path, H5F_ACC_RDONLY);
uzuki2::hdf5::Options opt;
opt.strict_list = false;
EXPECT_ANY_THROW({
try {
uzuki2::hdf5::validate(file.openGroup(name), num_expected);
uzuki2::hdf5::validate(file.openGroup(name), num_expected, std::move(opt));
} catch (std::exception& e) {
EXPECT_THAT(e.what(), ::testing::HasSubstr(msg));
throw;
Expand Down Expand Up @@ -105,7 +109,9 @@ TEST(Hdf5ExternalTest, CheckErrors) {

auto load_json_with_externals(std::string x, int num_externals) {
DefaultExternals ext(num_externals);
return uzuki2::json::parse_buffer<DefaultProvisioner>(reinterpret_cast<const unsigned char*>(x.c_str()), x.size(), ext);
uzuki2::json::Options opt;
opt.strict_list = false;
return uzuki2::json::parse_buffer<DefaultProvisioner>(reinterpret_cast<const unsigned char*>(x.c_str()), x.size(), ext, std::move(opt));
}

TEST(JsonExternalTest, SimpleLoading) {
Expand Down
Loading

0 comments on commit ad65ef8

Please sign in to comment.