diff --git a/include/mockturtle/io/serialize.hpp b/include/mockturtle/io/serialize.hpp index 5b0c984bb..b02a1a5b7 100644 --- a/include/mockturtle/io/serialize.hpp +++ b/include/mockturtle/io/serialize.hpp @@ -42,6 +42,7 @@ #include "../networks/aig.hpp" #include +#include #include namespace mockturtle @@ -292,6 +293,17 @@ struct serializer } /* namespace detail */ +/*! \brief Serializes a combinational AIG network to a archive, returning false on failure + * + * \param aig Combinational AIG network + * \param os Output archive + */ +inline bool serialize_network_fallible( aig_network const& aig, phmap::BinaryOutputArchive& os ) +{ + detail::serializer _serializer; + return _serializer( os, *aig._storage ); +} + /*! \brief Serializes a combinational AIG network to a archive * * \param aig Combinational AIG network @@ -299,8 +311,7 @@ struct serializer */ inline void serialize_network( aig_network const& aig, phmap::BinaryOutputArchive& os ) { - detail::serializer _serializer; - bool const okay = _serializer( os, *aig._storage ); + bool const okay = serialize_network_fallible( aig, os ); (void)okay; assert( okay && "failed to serialize the network onto stream" ); } @@ -316,12 +327,12 @@ inline void serialize_network( aig_network const& aig, std::string const& filena serialize_network( aig, ar_out ); } -/*! \brief Deserializes a combinational AIG network from a input archive +/*! \brief Deserializes a combinational AIG network from a input archive, returning nullopt on failure * * \param ar_input Input archive * \return Deserialized AIG network */ -inline aig_network deserialize_network( phmap::BinaryInputArchive& ar_input ) +inline std::optional deserialize_network_fallible( phmap::BinaryInputArchive& ar_input ) { detail::serializer _serializer; auto storage = std::make_shared(); @@ -330,10 +341,25 @@ inline aig_network deserialize_network( phmap::BinaryInputArchive& ar_input ) storage->outputs.clear(); storage->hash.clear(); - bool const okay = _serializer( ar_input, storage.get() ); - (void)okay; - assert( okay && "failed to deserialize the network onto stream" ); - return aig_network{ storage }; + if ( _serializer( ar_input, storage.get() ) ) + { + return aig_network{ storage }; + } + + return std::nullopt; +} + +/*! \brief Deserializes a combinational AIG network from a input archive + * + * \param ar_input Input archive + * \return Deserialized AIG network + */ +inline aig_network deserialize_network( phmap::BinaryInputArchive& ar_input ) +{ + auto result = deserialize_network_fallible( ar_input ); + (void)result.has_value(); + assert( result.has_value() && "failed to deserialize the network onto stream" ); + return *result; } /*! \brief Deserializes a combinational AIG network from a file diff --git a/lib/parallel_hashmap/parallel_hashmap/phmap_dump.h b/lib/parallel_hashmap/parallel_hashmap/phmap_dump.h index 57904dc12..50e698c8c 100644 --- a/lib/parallel_hashmap/parallel_hashmap/phmap_dump.h +++ b/lib/parallel_hashmap/parallel_hashmap/phmap_dump.h @@ -21,6 +21,7 @@ #include #include +#include #include #include "phmap.h" namespace phmap @@ -179,11 +180,18 @@ bool parallel_hash_set::load(InputArch // ------------------------------------------------------------------------ class BinaryOutputArchive { public: - BinaryOutputArchive(const char *file_path) { + BinaryOutputArchive(const char *file_path, + size_t bytes_remaining = std::numeric_limits::max()) + : bytes_remaining_(bytes_remaining) { ofs_.open(file_path, std::ios_base::binary); } bool dump(const char *p, size_t sz) { + if ( sz > bytes_remaining_ ) { + bytes_remaining_ = 0; + return false; + } + bytes_remaining_ -= sz; ofs_.write(p, sz); return ofs_.good(); } @@ -191,8 +199,7 @@ class BinaryOutputArchive { template typename std::enable_if::value, bool>::type dump(const V& v) { - ofs_.write(reinterpret_cast(&v), sizeof(V)); - return ofs_.good(); + return dump(reinterpret_cast(&v), sizeof(V)); } bool close() { @@ -202,16 +209,24 @@ class BinaryOutputArchive { private: std::ofstream ofs_; + size_t bytes_remaining_; }; class BinaryInputArchive { public: - BinaryInputArchive(const char * file_path) { + BinaryInputArchive(const char * file_path, + size_t bytes_remaining = std::numeric_limits::max()) + : bytes_remaining_(bytes_remaining) { ifs_.open(file_path, std::ios_base::binary); } bool load(char* p, size_t sz) { + if ( sz > bytes_remaining_ ) { + bytes_remaining_ = 0; + return false; + } + bytes_remaining_ -= sz; ifs_.read(p, sz); return ifs_.good(); } @@ -219,12 +234,12 @@ class BinaryInputArchive { template typename std::enable_if::value, bool>::type load(V* v) { - ifs_.read(reinterpret_cast(v), sizeof(V)); - return ifs_.good(); + return load(reinterpret_cast(v), sizeof(V)); } private: std::ifstream ifs_; + size_t bytes_remaining_; }; } // namespace phmap diff --git a/test/io/serialize.cpp b/test/io/serialize.cpp index bf6f68d4c..3405b90fa 100644 --- a/test/io/serialize.cpp +++ b/test/io/serialize.cpp @@ -1,9 +1,19 @@ #include +#include + #include using namespace mockturtle; +#if __GNUC__ == 7 +namespace fs = std::experimental::filesystem::v1; +#else +namespace fs = std::filesystem; +#endif + +static constexpr char file_name[] = "aig.dmp" ; + TEST_CASE( "serialize aig_network into a file", "[serialize]" ) { aig_network aig; @@ -19,10 +29,10 @@ TEST_CASE( "serialize aig_network into a file", "[serialize]" ) aig.create_po( f5 ); /* serialize */ - serialize_network( aig, "aig.dmp" ); + serialize_network( aig, file_name ); /* deserialize */ - aig_network aig2 = deserialize_network( "aig.dmp" ); + aig_network aig2 = deserialize_network( file_name ); CHECK( aig.size() == aig2.size() ); CHECK( aig.num_cis() == aig2.num_cis() ); @@ -46,3 +56,69 @@ TEST_CASE( "serialize aig_network into a file", "[serialize]" ) CHECK( aig2._storage->nodes[f5.index].children[0u].index == f4.index ); CHECK( aig2._storage->nodes[f5.index].children[1u].index == f3.index ); } + +static aig_network create_network() +{ + aig_network aig; + + const auto a = aig.create_pi(); + const auto b = aig.create_pi(); + + const auto f1 = aig.create_nand( a, b ); + const auto f3 = aig.create_nand( b, f1 ); + const auto f4 = aig.create_nand( a, f1 ); + const auto f5 = aig.create_nand( f4, f3 ); + aig.create_po( f5 ); + + return aig; +} + +// These numbers were chosen to get 100% coverage of the `return false` +// error paths in `serialize.hpp`. +// +// To find a value that gives coverage of a particular `return false` +// statement, change the loops below to iterate from 0 to 1000000, +// configure with `-DCMAKE_BUILD_TYPE=DEBUG` and run in a debugger +// with a breakpoint set at the line of interest. When the breakpoint +// is hit, get the value of `size` from its stack frame and add it +// to this list. +static constexpr int truncate_sizes[] = +{ + 0, 8, 16, 32, 40, 344, 352, 368, 376, 384, 672120 +}; + +TEST_CASE( "write errors are propagated", "[serialize]" ) +{ + aig_network aig = create_network(); + + serialize_network( aig, file_name ); + size_t file_size = fs::file_size( file_name ); + INFO("File size " << file_size); + + for ( size_t size : truncate_sizes ) { + if ( size >= file_size ) + { + break; + } + phmap::BinaryOutputArchive output ( file_name, size ); + CHECK_FALSE( serialize_network_fallible( aig, output ) ); + } +} + +TEST_CASE( "read errors are propagated", "[serialize]" ) +{ + aig_network aig = create_network(); + + serialize_network( aig, file_name ); + size_t file_size = fs::file_size( file_name ); + INFO("File size " << file_size); + + for ( size_t size : truncate_sizes ) { + if ( size >= file_size ) + { + break; + } + phmap::BinaryInputArchive input ( file_name, size ); + CHECK_FALSE( deserialize_network_fallible( input ).has_value() ); + } +}