diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2e69fa82b6..d9a5f7d8c2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -33,6 +33,11 @@ on: type: string default: '' required: false + crypto_provider: + description: 'cryptographic provider to use' + type: string + default: 'GnuTLS' + required: false jobs: test: @@ -83,14 +88,15 @@ jobs: if ${{ inputs.enable-ccache }}; then MAYBE_CCACHE_OPT="--ccache" fi - ./configure.py \ - --c++-standard ${{ inputs.standard }} \ - --compiler ${{ inputs.compiler }} \ - --c-compiler $CC \ - --mode ${{ inputs.mode }} \ - $MAYBE_CCACHE_OPT \ - ${{ inputs.options }} \ - ${{ inputs.enables }} + ./configure.py \ + --c++-standard ${{ inputs.standard }} \ + --compiler ${{ inputs.compiler }} \ + --c-compiler $CC \ + --mode ${{ inputs.mode }} \ + $MAYBE_CCACHE_OPT \ + ${{ inputs.options }} \ + ${{ inputs.enables }} \ + --crypto-provider ${{ inputs.crypto_provider }} - name: Build run: cmake --build build/${{inputs.mode}} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 91e23d8e3c..386b6d0dab 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,7 +11,7 @@ concurrency: jobs: regular_test: - name: "Test (${{ matrix.compiler }}, C++${{ matrix.standard}}, ${{ matrix.mode }})" + name: "Test (${{ matrix.compiler }}, C++${{ matrix.standard}}, ${{ matrix.mode }}, ${{ matrix.crypto_provider }})" uses: ./.github/workflows/test.yaml strategy: fail-fast: false @@ -19,12 +19,14 @@ jobs: compiler: [clang++, g++] standard: [20, 23] mode: [dev, debug, release] + crypto_provider: [GnuTLS, OpenSSL] with: compiler: ${{ matrix.compiler }} standard: ${{ matrix.standard }} mode: ${{ matrix.mode }} enables: ${{ matrix.enables }} options: ${{ matrix.options }} + crypto_provider: ${{ matrix.crypto_provider }} build_with_dpdk: name: "Test with DPDK enabled" uses: ./.github/workflows/test.yaml @@ -36,8 +38,8 @@ jobs: mode: release enables: --enable-dpdk options: --cook dpdk - build_with_cxx_modules: - name: "Test with C++20 modules enabled" + build_with_cxx_modules_gnutls: + name: "Test with C++20 modules enabled (GnuTLS)" uses: ./.github/workflows/test.yaml strategy: fail-fast: false @@ -47,3 +49,16 @@ jobs: mode: debug enables: --enable-cxx-modules enable-ccache: false + crypto_provider: GnuTLS + build_with_cxx_modules_openssl: + name: "Test with C++20 modules enabled (OpenSSL)" + uses: ./.github/workflows/test.yaml + strategy: + fail-fast: false + with: + compiler: clang++ + standard: 23 + mode: debug + enables: --enable-cxx-modules + enable-ccache: false + crypto_provider: OpenSSL diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f2f7b6ce3..157ec0f6be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,16 @@ if (NOT Seastar_SCHEDULING_GROUPS_COUNT MATCHES "^[1-9][0-9]*") message(FATAL_ERROR "Seastar_SCHEDULING_GROUPS_COUNT must be a positive number (${Seastar_SCHEDULING_GROUPS_COUNT})") endif () +option (Seastar_USE_OPENSSL + "Use OpenSSL rather than GnuTLS for cryptographic operations, including TLS" + OFF) + +if (Seastar_USE_OPENSSL) + set(Seastar_USE_GNUTLS OFF) +else() + set(Seastar_USE_GNUTLS ON) +endif () + # # Add a dev build type. # @@ -753,13 +763,16 @@ add_library (seastar src/net/native-stack-impl.hh src/net/native-stack.cc src/net/net.cc + $<$:src/net/ossl.cc> src/net/packet.cc src/net/posix-stack.cc src/net/proxy.cc src/net/socket_address.cc src/net/stack.cc src/net/tcp.cc - src/net/tls.cc + $<$:src/net/tls.cc> + src/net/tls-impl.cc + src/net/tls-impl.hh src/net/udp.cc src/net/unix_address.cc src/net/virtio.cc @@ -778,6 +791,9 @@ add_library (seastar src/util/tmp_file.cc src/util/short_streams.cc src/websocket/server.cc + src/websocket/base64.hh + $<$:src/websocket/base64-gnutls.cc> + $<$:src/websocket/base64-openssl.cc> ) add_library (Seastar::seastar ALIAS seastar) @@ -856,7 +872,9 @@ target_link_libraries (seastar SourceLocation::source_location PRIVATE ${CMAKE_DL_LIBS} - GnuTLS::gnutls + $<$:GnuTLS::gnutls> + $<$:OpenSSL::SSL> + $<$:OpenSSL::Crypto> StdAtomic::atomic lksctp-tools::lksctp-tools protobuf::libprotobuf @@ -913,6 +931,8 @@ include (CTest) target_compile_definitions(seastar PUBLIC SEASTAR_API_LEVEL=${Seastar_API_LEVEL} + $<$:SEASTAR_USE_OPENSSL> + $<$:SEASTAR_USE_GNUTLS> $<$:SEASTAR_BUILD_SHARED_LIBS>) target_compile_features(seastar diff --git a/cmake/SeastarDependencies.cmake b/cmake/SeastarDependencies.cmake index fea6b3d61e..5695c42d7f 100644 --- a/cmake/SeastarDependencies.cmake +++ b/cmake/SeastarDependencies.cmake @@ -91,6 +91,7 @@ macro (seastar_find_dependencies) GnuTLS LibUring LinuxMembarrier + OpenSSL # Protobuf is searched manually. Sanitizers SourceLocation @@ -124,8 +125,12 @@ macro (seastar_find_dependencies) VERSION 8.1.1) seastar_set_dep_args (lz4 REQUIRED VERSION 1.7.3) + seastar_set_dep_args (OpenSSL + VERSION 3.0.0 + OPTION ${Seastar_USE_OPENSSL}) seastar_set_dep_args (GnuTLS REQUIRED - VERSION 3.3.26) + VERSION 3.3.26 + OPTION ${Seastar_USE_GNUTLS}) seastar_set_dep_args (LibUring VERSION 2.0 OPTION ${Seastar_IO_URING}) diff --git a/configure.py b/configure.py index c0b284bc17..68ccdebfa1 100755 --- a/configure.py +++ b/configure.py @@ -89,6 +89,9 @@ def standard_supported(standard, compiler='g++'): arg_parser.add_argument('--verbose', dest='verbose', action='store_true', help='Make configure output more verbose.') arg_parser.add_argument('--scheduling-groups-count', action='store', dest='scheduling_groups_count', default='16', help='Number of available scheduling groups in the reactor') +arg_parser.add_argument('--crypto-provider', dest='crypto_provider', choices=seastar_cmake.SUPPORTED_CRYPTO_PROVIDERS, + default='GnuTLS', help='The cryptographic provider ot use') +arg_parser.add_argument('--openssl-root-dir', dest='openssl_root_dir', help="Root directory for OpenSSL library") add_tristate( arg_parser, @@ -191,6 +194,7 @@ def configure_mode(mode): '-DBUILD_SHARED_LIBS={}'.format('yes' if mode in ('debug', 'dev') else 'no'), '-DSeastar_API_LEVEL={}'.format(args.api_level), '-DSeastar_SCHEDULING_GROUPS_COUNT={}'.format(args.scheduling_groups_count), + '-DSeastar_USE_OPENSSL={}'.format('yes' if args.crypto_provider == 'OpenSSL' else 'no'), tr(args.exclude_tests, 'EXCLUDE_TESTS_FROM_ALL'), tr(args.exclude_apps, 'EXCLUDE_APPS_FROM_ALL'), tr(args.exclude_demos, 'EXCLUDE_DEMOS_FROM_ALL'), @@ -211,6 +215,9 @@ def configure_mode(mode): tr(args.debug_shared_ptr, 'DEBUG_SHARED_PTR', value_when_none='default'), ] + if args.openssl_root_dir is not None: + TRANSLATED_ARGS.appen(f'-DOPENSSL_ROOT_DIR={args.openssl_root_dir}') + ingredients_to_cook = set(args.cook) if args.dpdk: diff --git a/include/seastar/net/tcp.hh b/include/seastar/net/tcp.hh index e2f5a567c8..7e4c4bbd49 100644 --- a/include/seastar/net/tcp.hh +++ b/include/seastar/net/tcp.hh @@ -30,8 +30,12 @@ #include #include #include +#ifdef SEASTAR_USE_OPENSSL +#include +#else #include #endif +#endif #include #include #include @@ -42,6 +46,7 @@ #include #include #include +#include #include namespace seastar { @@ -2084,6 +2089,36 @@ tcp_seq tcp::tcb::get_isn() { // ISN = M + F(localip, localport, remoteip, remoteport, secretkey) // M is the 4 microsecond timer using namespace std::chrono; +#ifdef SEASTAR_USE_OPENSSL + uint32_t hash[8]; + hash[0] = _local_ip.ip; + hash[1] = _foreign_ip.ip; + hash[2] = (_local_port << 16) + _foreign_port; + unsigned int hash_size = sizeof(hash); + + // Why SHA-256 for OpenSSL vs MD5? + // MD5 may be disabled if OpenSSL is in FIPS mode, also some bench testing + // has shown that the SHA-256 performance is equivalent or better than MD5 + // as SHA256 is hardware accelerated on most modern CPU architectures + auto md_ptr = EVP_MD_fetch(nullptr, "SHA256", nullptr); + assert(md_ptr); + auto free_md_ptr = defer([&]() noexcept { EVP_MD_free(md_ptr); }); + assert(hash_size == static_cast(EVP_MD_get_size(md_ptr))); + auto md_ctx = EVP_MD_CTX_new(); + assert(md_ctx); + auto free_md_ctx = defer([&]() noexcept { EVP_MD_CTX_free(md_ctx); }); + auto res = EVP_DigestInit(md_ctx, md_ptr); + assert(1 == res); + res = EVP_DigestUpdate( + md_ctx, hash, 3 * sizeof(hash[0])); + assert(1 == res); + res = EVP_DigestUpdate( + md_ctx, _isn_secret.key, sizeof(_isn_secret.key)); + assert(1 == res); + res = EVP_DigestFinal_ex( + md_ctx, reinterpret_cast(hash), &hash_size); + assert(1 == res); +#else uint32_t hash[4]; hash[0] = _local_ip.ip; hash[1] = _foreign_ip.ip; @@ -2096,6 +2131,7 @@ tcp_seq tcp::tcb::get_isn() { // reuse "hash" for the output of digest assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_MD5)); gnutls_hash_deinit(md5_hash_handle, hash); +#endif auto seq = hash[0]; auto m = duration_cast(clock_type::now().time_since_epoch()); seq += m.count() / 4; diff --git a/include/seastar/net/tls.hh b/include/seastar/net/tls.hh index 704a50d5d0..356e4c1e87 100644 --- a/include/seastar/net/tls.hh +++ b/include/seastar/net/tls.hh @@ -115,6 +115,16 @@ namespace tls { shared_ptr _impl; }; + enum class tls_version { + tlsv1_0, + tlsv1_1, + tlsv1_2, + tlsv1_3 + }; + + std::string_view format_as(tls_version); + std::ostream& operator<<(std::ostream&, const tls_version&); + class abstract_credentials { protected: abstract_credentials() = default; @@ -157,6 +167,18 @@ namespace tls { */ using dn_callback = noncopyable_function; + enum class client_auth { + NONE, REQUEST, REQUIRE + }; + + /** + * Session resumption support. + * We only support TLS1.3 session tickets. + */ + enum class session_resume_mode { + NONE, TLS13_SESSION_TICKET + }; + /** * Holds certificates and keys. * @@ -190,6 +212,7 @@ namespace tls { // TODO add methods for certificate verification +#ifdef SEASTAR_USE_GNUTLS /** * TLS handshake priority string. See gnutls docs and syntax at * https://gnutls.org/manual/html_node/Priority-Strings.html @@ -197,6 +220,45 @@ namespace tls { * Allows specifying order and allowance for handshake alg. */ void set_priority_string(const sstring&); +#endif + +#ifdef SEASTAR_USE_OPENSSL + /** + * Used to set the cipher string for TLS versions 1.2 and below + * + * See https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_cipher_list.html + * for documentation on the format of the cipher list string + */ + void set_cipher_string(const sstring&); + + /** + * Used to set the cipher suites to use for TLSv1.3 + * + * See https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_ciphersuites.html + * for documentation on the format of the ciphersuites string. + */ + void set_ciphersuites(const sstring&); + + /** + * Call this when you want to enable server precedence when + * negotitating the TLS handshake. Client precedence is on + * by default. + */ + void enable_server_precedence(); + /** + * @brief Set the minimum tls version for this connection + * + * If unset, will default to the minimum of the underlying + * implementation + */ + void set_minimum_tls_version(tls_version); + /** + * @brief Set the maximum tls version for this connection + * + * If unset, will default to the maximum of the underly implementation + */ + void set_maximum_tls_version(tls_version); +#endif /** * Register a callback for receiving Distinguished Name (DN) information @@ -235,6 +297,12 @@ namespace tls { template friend class reloadable_credentials; shared_ptr _impl; + + // The following methods are provided so classes that inherit from + // certificate_credentials can access the underly implementation + void enable_load_system_trust(); + void set_client_auth(client_auth); + void set_session_resume_mode(session_resume_mode); }; /** Exception thrown on certificate validation error */ @@ -243,18 +311,6 @@ namespace tls { using runtime_error::runtime_error; }; - enum class client_auth { - NONE, REQUEST, REQUIRE - }; - - /** - * Session resumption support. - * We only support TLS1.3 session tickets. - */ - enum class session_resume_mode { - NONE, TLS13_SESSION_TICKET - }; - /** * Extending certificates and keys for server usage. * More probably goes in here... @@ -312,9 +368,20 @@ namespace tls { future<> set_system_trust(); void set_client_auth(client_auth); - void set_priority_string(const sstring&); void set_session_resume_mode(session_resume_mode); +#ifdef SEASTAR_USE_GNUTLS + void set_priority_string(const sstring&); +#endif + +#ifdef SEASTAR_USE_OPENSSL + void set_cipher_string(const sstring&); + void set_ciphersuites(const sstring&); + void enable_server_precedence(); + void set_minimum_tls_version(tls_version); + void set_maximum_tls_version(tls_version); +#endif + void apply_to(certificate_credentials&) const; shared_ptr build_certificate_credentials() const; @@ -331,6 +398,11 @@ namespace tls { client_auth _client_auth = client_auth::NONE; session_resume_mode _session_resume_mode = session_resume_mode::NONE; sstring _priority; + sstring _cipher_string; + sstring _ciphersuites; + bool _enable_server_precedence = false; + std::optional _min_tls_version; + std::optional _max_tls_version; }; using session_data = std::vector; @@ -582,3 +654,10 @@ template <> struct fmt::formatter : fmt::formatt return fmt::format_to(ctx.out(), "{}={}", name.type, name.value); } }; + +template <> struct fmt::formatter : fmt::formatter { + template + auto format(seastar::tls::tls_version version, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{}", format_as(version)); + } +}; diff --git a/install-dependencies.sh b/install-dependencies.sh index 3c913f8152..92ecdd050f 100755 --- a/install-dependencies.sh +++ b/install-dependencies.sh @@ -44,6 +44,7 @@ debian_packages=( libpciaccess-dev libprotobuf-dev libsctp-dev + libssl-dev libtool liburing-dev libxml2-dev @@ -96,6 +97,7 @@ redhat_packages=( meson numactl-devel openssl + openssl-devel protobuf-compiler protobuf-devel python3 @@ -227,6 +229,7 @@ opensuse_packages=( meson ninja openssl + openssl-devel protobuf-devel python3-PyYAML ragel diff --git a/seastar_cmake.py b/seastar_cmake.py index 105c01c50c..e9718fc6db 100644 --- a/seastar_cmake.py +++ b/seastar_cmake.py @@ -25,6 +25,8 @@ COOKING_BASIC_ARGS = ['./cooking.sh'] +SUPPORTED_CRYPTO_PROVIDERS = ['GnuTLS', 'OpenSSL'] + def build_path(mode, build_root): """Return the absolute path to the build directory for the given mode, i.e., seastar_dir//""" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46e7a13daf..333aed872c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -66,7 +66,9 @@ target_sources (seastar-module net/packet.cc net/inet_address.cc net/socket_address.cc - net/tls.cc + $<$:net/tls.cc> + $<$:net/ossl.cc> + net/tls-impl.cc net/virtio.cc http/client.cc http/common.cc @@ -95,6 +97,8 @@ target_compile_definitions (seastar-module $<$:SEASTAR_SSTRING> SEASTAR_API_LEVEL=${Seastar_API_LEVEL} SEASTAR_SCHEDULING_GROUPS_COUNT=${Seastar_SCHEDULING_GROUPS_COUNT} + $<$:SEASTAR_USE_OPENSSL> + $<$:SEASTAR_USE_GNUTLS> PRIVATE SEASTAR_MODULE ${Seastar_PRIVATE_COMPILE_DEFINITIONS}) @@ -120,7 +124,9 @@ target_link_libraries (seastar-module SourceLocation::source_location PRIVATE ${CMAKE_DL_LIBS} - GnuTLS::gnutls + $<$:GnuTLS::gnutls> + $<$:OpenSSL::SSL> + $<$:OpenSSL::Crypto> StdAtomic::atomic lksctp-tools::lksctp-tools rt::rt diff --git a/src/net/ossl.cc b/src/net/ossl.cc new file mode 100644 index 0000000000..fb5ee1e91d --- /dev/null +++ b/src/net/ossl.cc @@ -0,0 +1,2151 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2015 Cloudius Systems + * Copyright 2024 Redpanda Data + */ + +// The data flow in the openssl-based seastar tls implementation looks like the following. The session's put/get() +// methods delegate to openssl functions, which perform the openssl handshake, encryption and decryption, and then they +// write the raw data into a custom BIO object. This custom BIO object is implemented in this file, it is specific to +// seastar and it delegates the write/read methods to the seastar::tls::session's data_sink/data_source (which are the +// interfaces around the underlying seastar connected socket). +// +// +----------------------------+ +-------------------------------------+ +-------------+ +----------------------------+ +------+ +// | | | | | | | | | | +// -------> seastar::tls::session +------> SSL_{do_handshake|write_ex|read_ex} +-----> custom BIO +----> data_sink/data_source +---> OS | +// | | | | | | | (seastar socket) | | | +// +----------------------------+ +-------------------------------------+ +-------------+ +----------------------------+ +------+ + +#ifdef SEASTAR_MODULE +module; +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef SEASTAR_MODULE +module seastar; +#else +#include +#include +#include +#include +#include +#include +#include +#include + +#include "net/tls-impl.hh" +#endif + +template <> struct fmt::formatter : fmt::ostream_formatter {}; + +namespace seastar { + +enum class ossl_errc : int{}; + +} + +namespace std { + +template<> +struct is_error_code_enum : true_type {}; + +} + +template<> +struct fmt::formatter : public fmt::formatter { + auto format(seastar::ossl_errc error, fmt::format_context& ctx) const -> decltype(ctx.out()) { + constexpr size_t error_buf_size = 256; + // Buffer passed to ERR_error_string must be at least 256 bytes large + // https://www.openssl.org/docs/man3.0/man3/ERR_error_string_n.html + std::array buf{}; + ERR_error_string_n( + static_cast(error), buf.data(), buf.size()); + // ERR_error_string_n does include the terminating null character + return fmt::format_to(ctx.out(), "{}", buf.data()); + } +}; + +namespace seastar { + +int tls_version_to_ossl(tls::tls_version version) { + switch(version) { + case tls::tls_version::tlsv1_0: + return TLS1_VERSION; + case tls::tls_version::tlsv1_1: + return TLS1_1_VERSION; + case tls::tls_version::tlsv1_2: + return TLS1_2_VERSION; + case tls::tls_version::tlsv1_3: + return TLS1_3_VERSION; + } + + __builtin_unreachable(); +} + +class ossl_error_category : public std::error_category { +public: + constexpr ossl_error_category() noexcept : std::error_category{} {} + const char* name() const noexcept override { + return "OpenSSL"; + } + std::string message(int error) const override { + return fmt::format("{}", static_cast(error)); + } +}; + +const std::error_category& tls::error_category() { + static const ossl_error_category ec; + return ec; +} + +std::error_code make_error_code(ossl_errc e) { + return std::error_code(static_cast(e), tls::error_category()); +} + +std::system_error make_ossl_error(const std::string & msg) { + std::vector error_codes; + for (auto code = ERR_get_error(); code != 0; code = ERR_get_error()) { + error_codes.push_back(static_cast(code)); + } + if (error_codes.empty()) { + return std::system_error{ + static_cast(ERR_PACK(ERR_LIB_USER, 0, ERR_R_OPERATION_FAIL)), + tls::error_category(), + msg}; + } else { + auto err_code = static_cast(error_codes.front()); + if (ERR_LIB_SYS == ERR_GET_LIB(err_code)) { + // If the error code belongs to ERR_LIB_SYS, then the error is a system error + // Extract the errno using ERR_GET_REASON and throw a std::generic_category + return std::system_error( + ERR_GET_REASON(err_code), + std::generic_category(), + fmt::format("{}: {}", msg, error_codes)); + } + return std::system_error( + static_cast(err_code), + tls::error_category(), + fmt::format("{}: {}", msg, error_codes)); + } +} + +template +sstring asn1_str_to_str(T* asn1) { + const auto len = ASN1_STRING_length(asn1); + return sstring(reinterpret_cast(ASN1_STRING_get0_data(asn1)), len); +}; + +template +struct ssl_deleter { + void operator()(T* ptr) { fn(ptr); } +}; + +// Must define this method as sk_X509_pop_free is a macro +void X509_pop_free(STACK_OF(X509)* ca) { + sk_X509_pop_free(ca, X509_free); +} + +void X509_INFO_pop_free(STACK_OF(X509_INFO)* infos) { + sk_X509_INFO_pop_free(infos, X509_INFO_free); +} + +void GENERAL_NAME_pop_free(GENERAL_NAMES* gns) { + sk_GENERAL_NAME_pop_free(gns, GENERAL_NAME_free); +} + +template +using ssl_handle = std::unique_ptr>; + +using bio_method_ptr = ssl_handle; +using bio_ptr = ssl_handle; +using evp_pkey_ptr = ssl_handle; +using x509_ptr = ssl_handle; +using x509_crl_ptr = ssl_handle; +using x509_store_ptr = ssl_handle; +using x509_store_ctx_ptr = ssl_handle; +using x509_chain_ptr = ssl_handle; +using x509_infos_ptr = ssl_handle; +using general_names_ptr = ssl_handle; +using pkcs12 = ssl_handle; +using ssl_ctx_ptr = ssl_handle; +using ssl_ptr = ssl_handle; +using x509_verify_param_ptr = ssl_handle; +using ssl_session_ptr = ssl_handle; + +/** + * The purpose of this structure is to hold the AES and HMAC keys used to encrypt/decrypt TLS 1.3 + * session tickets given to and presented by, the TLS client. + */ +struct session_ticket_keys { + std::array key_name; + std::array aes_key; + std::array hmac_key; + + void generate_keys() { + generate_key(key_name); + generate_key(aes_key); + generate_key(hmac_key); + } + + ~session_ticket_keys() { + clear_key(aes_key); + clear_key(hmac_key); + } + +private: + /** + * This will zeroize the key contents by writing 0xff, then 0x00, + * and finally 0xff over the memory that held the key + * @tparam N Size of the array + * @param key The key to clear + */ + template + static void clear_key(std::array& key) { + // Zeroize sensitive key data + OPENSSL_cleanse(key.data(), N); + } + + /** + * Generates a key + * @tparam N The size of the key + * @param key The key to generate + * @throws ossl_error if unable to generate random data + */ + template + static void generate_key(std::array& key) { + if (RAND_priv_bytes(key.data(), N) <= 0) { + throw make_ossl_error("Failed to generate key"); + } + } +}; + +// sufficiently large enough to avoid collision with OpenSSL BIO controls +#define BIO_C_SET_POINTER 1000 +// Index into ex data for SSL structure to fetch a pointer to session +#define SSL_EX_DATA_SESSION 0 + +BIO_METHOD* get_method(); + +void tls::credentials_builder::set_cipher_string(const sstring& cipher_string) { + _cipher_string = cipher_string; +} + +void tls::credentials_builder::set_ciphersuites(const sstring& ciphersuites) { + _ciphersuites = ciphersuites; +} + +void tls::credentials_builder::enable_server_precedence() { + _enable_server_precedence = true; +} + +void tls::credentials_builder::set_minimum_tls_version(tls_version version) { + _min_tls_version.emplace(version); +} + +void tls::credentials_builder::set_maximum_tls_version(tls_version version) { + _max_tls_version.emplace(version); +} + +/// TODO: Implement the DH params impl struct +/// +class tls::dh_params::impl { +public: + explicit impl(level) {} + impl(const blob&, x509_crt_format){} + + const EVP_PKEY* get() const { return _pkey.get(); } + + explicit operator const EVP_PKEY*() const { return _pkey.get(); } + +private: + evp_pkey_ptr _pkey; +}; + +tls::dh_params::dh_params(level lvl) : _impl(std::make_unique(lvl)) +{} + +tls::dh_params::dh_params(const blob& b, x509_crt_format fmt) + : _impl(std::make_unique(b, fmt)) { +} + +// TODO(rob) some small amount of code duplication here +tls::dh_params::~dh_params() = default; + +tls::dh_params::dh_params(dh_params&&) noexcept = default; +tls::dh_params& tls::dh_params::operator=(dh_params&&) noexcept = default; + +class tls::certificate_credentials::impl { + struct certkey_pair { + x509_ptr cert; + evp_pkey_ptr key; + explicit operator bool() const noexcept { + return cert != nullptr && key != nullptr; + } + }; + + static const int credential_store_idx = 0; + +public: + // This callback is designed to intercept the verification process and to implement an additional + // check, returning 0 or -1 will force verification to fail. + // + // However it has been implemented in this case soley to cache the last observed certificate so + // that it may be inspected during the session::verify() method, if desired. + // + static int verify_callback(int preverify_ok, X509_STORE_CTX* store_ctx) { + // Grab the 'this' pointer from the stores generic data cache, it should always exist + auto store = X509_STORE_CTX_get0_store(store_ctx); + auto credential_impl = static_cast(X509_STORE_get_ex_data(store, credential_store_idx)); + assert(credential_impl != nullptr); + // Store a pointer to the current connection certificate within the impl instance + auto cert = X509_STORE_CTX_get_current_cert(store_ctx); + X509_up_ref(cert); + credential_impl->_last_cert = x509_ptr(cert); + return preverify_ok; + } + + impl() : _creds([] { + auto store = X509_STORE_new(); + if(store == nullptr) { + throw std::bad_alloc(); + } + X509_STORE_set_verify_cb(store, verify_callback); + return store; + }()) { + // The static verify_callback above will use the stored pointer to 'this' to store the last + // observed x509 certificate + [[maybe_unused]] auto res = + X509_STORE_set_ex_data(_creds.get(), credential_store_idx, this); + assert(res == 1); + } + + + // Parses a PEM certificate file that may contain more then one entry, calls the callback provided + // passing the associated X509_INFO* argument. The parameter is not retained so the caller must retain + // the item before the end of the function call. + template + static void iterate_pem_certs(const bio_ptr& cert_bio, LoadFunc fn) { + auto infos = x509_infos_ptr(PEM_X509_INFO_read_bio(cert_bio.get(), nullptr, nullptr, nullptr)); + auto num_elements = sk_X509_INFO_num(infos.get()); + if (num_elements <= 0) { + throw make_ossl_error("Failed to parse PEM cert"); + } + for (auto i=0; i < num_elements; i++) { + auto object = sk_X509_INFO_value(infos.get(), i); + fn(object); + } + } + + static x509_ptr parse_x509_cert(const blob& b, x509_crt_format fmt) { + bio_ptr cert_bio(BIO_new_mem_buf(b.begin(), b.size())); + x509_ptr cert; + switch(fmt) { + case tls::x509_crt_format::PEM: + cert = x509_ptr(PEM_read_bio_X509(cert_bio.get(), nullptr, nullptr, nullptr)); + break; + case tls::x509_crt_format::DER: + cert = x509_ptr(d2i_X509_bio(cert_bio.get(), nullptr)); + break; + } + if (!cert) { + throw make_ossl_error("Failed to parse x509 certificate"); + } + return cert; + } + + void set_x509_trust(const blob& b, x509_crt_format fmt) { + bio_ptr cert_bio(BIO_new_mem_buf(b.begin(), b.size())); + x509_ptr cert; + switch(fmt) { + case tls::x509_crt_format::PEM: + iterate_pem_certs(cert_bio, [this](X509_INFO* info){ + if (!info->x509) { + throw make_ossl_error("Failed to parse x509 cert"); + } + X509_STORE_add_cert(*this, info->x509); + }); + break; + case tls::x509_crt_format::DER: + cert = x509_ptr(d2i_X509_bio(cert_bio.get(), nullptr)); + if (!cert) { + throw make_ossl_error("Failed to parse x509 certificate"); + } + X509_STORE_add_cert(*this, cert.get()); + break; + } + } + + void set_x509_crl(const blob& b, x509_crt_format fmt) { + bio_ptr cert_bio(BIO_new_mem_buf(b.begin(), b.size())); + x509_crl_ptr crl; + switch(fmt) { + case x509_crt_format::PEM: + iterate_pem_certs(cert_bio, [this](X509_INFO* info) { + if (!info->crl) { + throw make_ossl_error("Failed to parse CRL"); + } + X509_STORE_add_crl(*this, info->crl); + }); + break; + case x509_crt_format::DER: + crl = x509_crl_ptr(d2i_X509_CRL_bio(cert_bio.get(), nullptr)); + if (!crl) { + throw make_ossl_error("Failed to parse x509 crl"); + } + X509_STORE_add_crl(*this, crl.get()); + break; + } + + enable_crl_checking(); + } + + void set_x509_key(const blob& cert, const blob& key, x509_crt_format fmt) { + x509_ptr x509_cert{nullptr}; + bio_ptr key_bio(BIO_new_mem_buf(key.begin(), key.size())); + evp_pkey_ptr pkey; + switch(fmt) { + case x509_crt_format::PEM: + pkey = evp_pkey_ptr(PEM_read_bio_PrivateKey(key_bio.get(), nullptr, nullptr, nullptr)); + // The provided `cert` blob may contain more than one cert. We need to be prepared + // for this situation. So we will parse through the blob using `iterate_pem_certs`. + // The first cert encountered will be assigned to x509_cert and all subsequent certs + // will be added to the X509_STORE's trusted certificates + iterate_pem_certs(bio_ptr{BIO_new_mem_buf(cert.begin(), cert.size())}, [this, &x509_cert](X509_INFO* info) { + if (!info->x509) { + throw make_ossl_error("Failed to parse X.509 certificate in loading key/cert chain"); + } + if (!x509_cert) { + x509_cert = x509_ptr{info->x509}; + // By setting x509 to nullptr, the sk_X509_INFO_pop_free function will not + // call X509_free on it. We have 'transfered' ownership above to the + // x509_cert X509 ptr + info->x509 = nullptr; + } else { + X509_STORE_add_cert(*this, info->x509); + } + }); + break; + case x509_crt_format::DER: + pkey = evp_pkey_ptr(d2i_PrivateKey_bio(key_bio.get(), nullptr)); + // We don't handle a chain of certs when encoded in DER + x509_cert = parse_x509_cert(cert, fmt); + break; + default: + __builtin_unreachable(); + } + if (!pkey) { + throw make_ossl_error("Error attempting to parse private key"); + } + if (!X509_check_private_key(x509_cert.get(), pkey.get())) { + throw make_ossl_error("Failed to verify cert/key pair"); + } + _cert_and_key = certkey_pair{.cert = std::move(x509_cert), .key = std::move(pkey)}; + } + + void set_simple_pkcs12(const blob& b, x509_crt_format, const sstring& password) { + // Load the PKCS12 file + bio_ptr bio(BIO_new_mem_buf(b.begin(), b.size())); + if (auto p12 = pkcs12(d2i_PKCS12_bio(bio.get(), nullptr))) { + // Extract the certificate and private key from PKCS12, using provided password + EVP_PKEY *pkey = nullptr; + X509 *cert = nullptr; + STACK_OF(X509) *ca = nullptr; + if (!PKCS12_parse(p12.get(), password.c_str(), &pkey, &cert, &ca)) { + throw make_ossl_error("Failed to extract cert key pair from pkcs12 file"); + } + // Ensure signature validation checks pass before continuing + if (!X509_check_private_key(cert, pkey)) { + X509_free(cert); + EVP_PKEY_free(pkey); + throw make_ossl_error("Failed to verify cert/key pair"); + } + _cert_and_key = certkey_pair{.cert = x509_ptr(cert), .key = evp_pkey_ptr(pkey)}; + + // Iterate through all elements in the certificate chain, adding them to the store + auto ca_ptr = x509_chain_ptr(ca); + if (ca_ptr) { + auto num_elements = sk_X509_num(ca_ptr.get()); + while (num_elements > 0) { + auto e = sk_X509_pop(ca_ptr.get()); + X509_STORE_add_cert(*this, e); + // store retains certificate + X509_free(e); + num_elements -= 1; + } + } + } else { + throw make_ossl_error("Failed to parse pkcs12 file"); + } + } + + void enable_crl_checking() { + if (!std::exchange(_crl_check_flag_set, true)) { + x509_verify_param_ptr x509_vfy(X509_VERIFY_PARAM_new()); + + if (1 != X509_VERIFY_PARAM_set_flags( + x509_vfy.get(), X509_V_FLAG_CRL_CHECK|X509_V_FLAG_CRL_CHECK_ALL)) { + throw make_ossl_error( + "Failed to set X509_V_FLAG_CRL_CHECK flag"); + } + + if (1 != X509_STORE_set1_param(*this, x509_vfy.get())) { + throw make_ossl_error( + "Failed to set verification parameters on X509 store"); + } + } + } + + void dh_params(const tls::dh_params&) {} + + void set_client_auth(client_auth ca) { + _client_auth = ca; + } + client_auth get_client_auth() const { + return _client_auth; + } + void set_session_resume_mode(session_resume_mode m) { + _session_resume_mode = m; + if (m != session_resume_mode::NONE) { + _session_ticket_keys.generate_keys(); + } + } + + session_resume_mode get_session_resume_mode() { + return _session_resume_mode; + } + + const session_ticket_keys & get_session_ticket_keys() const { + return _session_ticket_keys; + } + + void set_dn_verification_callback(dn_callback cb) { + _dn_callback = std::move(cb); + } + + void set_enable_certificate_verification(bool enable) { + _enable_certificate_verification = enable; + } + + void set_cipher_string(const sstring& cipher_string) { + _cipher_string = cipher_string; + } + + void set_ciphersuites(const sstring& ciphersuites) { + _ciphersuites = ciphersuites; + } + + void enable_server_precedence() { + _enable_server_precedence = true; + } + + void set_minimum_tls_version(tls_version version) { + _min_tls_version.emplace(version); + } + + void set_maximum_tls_version(tls_version version) { + _max_tls_version.emplace(version); + } + + const sstring& get_cipher_string() const noexcept { + return _cipher_string; + } + + const sstring& get_ciphersuites() const noexcept { + return _ciphersuites; + } + + bool is_server_precedence_enabled() { + return _enable_server_precedence; + } + + const std::optional& minimum_tls_version() const noexcept { + return _min_tls_version; + } + + const std::optional& maximum_tls_version() const noexcept { + return _max_tls_version; + } + + // Returns the certificate of last attempted verification attempt, if there was no attempt, + // this will not be updated and will remain stale + const x509_ptr& get_last_cert() const { return _last_cert; } + + operator X509_STORE*() const { return _creds.get(); } + + const certkey_pair& get_certkey_pair() const { + return _cert_and_key; + } + +private: + friend class certificate_credentials; + friend class credentials_builder; + friend class session; + + void set_load_system_trust(bool trust) { + _load_system_trust = trust; + } + + bool need_load_system_trust() const { + return _load_system_trust; + } + + certkey_pair _cert_and_key; + session_ticket_keys _session_ticket_keys; + x509_ptr _last_cert; + x509_store_ptr _creds; + dn_callback _dn_callback; + std::optional _min_tls_version; + std::optional _max_tls_version; + sstring _cipher_string; + sstring _ciphersuites; + + client_auth _client_auth = client_auth::NONE; + session_resume_mode _session_resume_mode = session_resume_mode::NONE; + bool _load_system_trust = false; + bool _enable_server_precedence = false; + bool _crl_check_flag_set = false; + bool _enable_certificate_verification = true; +}; + +tls::certificate_credentials::certificate_credentials() + : _impl(make_shared()) { +} + +tls::certificate_credentials::~certificate_credentials() { +} + +tls::certificate_credentials::certificate_credentials( + certificate_credentials&&) noexcept = default; +tls::certificate_credentials& tls::certificate_credentials::operator=( + certificate_credentials&&) noexcept = default; + +void tls::certificate_credentials::set_x509_trust(const blob& b, + x509_crt_format fmt) { + _impl->set_x509_trust(b, fmt); +} + +void tls::certificate_credentials::set_x509_crl(const blob& b, + x509_crt_format fmt) { + _impl->set_x509_crl(b, fmt); + +} +void tls::certificate_credentials::set_x509_key(const blob& cert, + const blob& key, x509_crt_format fmt) { + _impl->set_x509_key(cert, key, fmt); +} + +void tls::certificate_credentials::set_simple_pkcs12(const blob& b, + x509_crt_format fmt, const sstring& password) { + _impl->set_simple_pkcs12(b, fmt, password); +} + +future<> tls::certificate_credentials::set_system_trust() { + _impl->_load_system_trust = true; + return make_ready_future<>(); +} + +void tls::certificate_credentials::set_cipher_string(const sstring& cipher_string) { + _impl->set_cipher_string(cipher_string); +} + +void tls::certificate_credentials::set_ciphersuites(const sstring& ciphersuites) { + _impl->set_ciphersuites(ciphersuites); +} + +void tls::certificate_credentials::enable_server_precedence() { + _impl->enable_server_precedence(); +} + +void tls::certificate_credentials::set_minimum_tls_version(tls_version version) { + _impl->set_minimum_tls_version(version); +} + +void tls::certificate_credentials::set_maximum_tls_version(tls_version version) { + _impl->set_maximum_tls_version(version); +} + +void tls::certificate_credentials::set_dn_verification_callback(dn_callback cb) { + _impl->set_dn_verification_callback(std::move(cb)); +} + +void tls::certificate_credentials::set_enable_certificate_verification(bool enable) { + _impl->set_enable_certificate_verification(enable); +} + +void tls::certificate_credentials::enable_load_system_trust() { + _impl->_load_system_trust = true; +} + +void tls::certificate_credentials::set_client_auth(client_auth ca) { + _impl->set_client_auth(ca); +} + +void tls::certificate_credentials::set_session_resume_mode(session_resume_mode m) { + _impl->set_session_resume_mode(m); +} + +tls::server_credentials::server_credentials() + : server_credentials(dh_params{}) +{} + +tls::server_credentials::server_credentials(shared_ptr dh) + : server_credentials(*dh) +{} + +tls::server_credentials::server_credentials(const dh_params& dh) { + _impl->dh_params(dh); +} + +tls::server_credentials::server_credentials(server_credentials&&) noexcept = default; +tls::server_credentials& tls::server_credentials::operator=( + server_credentials&&) noexcept = default; + +void tls::server_credentials::set_client_auth(client_auth ca) { + _impl->set_client_auth(ca); +} + +namespace tls { + +int session_ticket_cb(SSL * s, unsigned char key_name[16], + unsigned char iv[EVP_MAX_IV_LENGTH], + EVP_CIPHER_CTX * ctx, EVP_MAC_CTX *hctx, int enc); + +/** + * Session wraps an OpenSSL SSL session and context, + * and is the actual conduit for an TLS/SSL data flow. + * + * We use a connected_socket and its sink/source + * for IO. Note that we need to keep ownership + * of these, since we handle handshake etc. + * + * The implmentation below relies on OpenSSL, for the gnutls implementation + * see tls.cc and the CMake option 'Seastar_WITH_OSSL' + */ +class session : public enable_shared_from_this, public session_impl { +public: + using buf_type = temporary_buffer; + using frag_iter = net::fragment*; + + session(session_type t, shared_ptr creds, + std::unique_ptr sock, tls_options options = {}) + : _sock(std::move(sock)) + , _local_address(fmt::format("{}", _sock->local_address())) + , _remote_address(fmt::format("{}", _sock->remote_address())) + , _creds(creds->_impl) + , _in(_sock->source()) + , _out(_sock->sink()) + , _in_sem(1) + , _out_sem(1) + , _options(std::move(options)) + , _output_pending(make_ready_future<>()) + , _ctx(make_ssl_context(t)) + , _ssl([this]() { + auto ssl = SSL_new(_ctx.get()); + if (!ssl) { + throw make_ossl_error("Failed to create SSL session"); + } + return ssl; + }()) + , _type(t) { + if (1 != SSL_set_ex_data(_ssl.get(), SSL_EX_DATA_SESSION, this)) { + throw make_ossl_error("Failed to set EX data for SSL session"); + } + bio_ptr in_bio(BIO_new(get_method())); + bio_ptr out_bio(BIO_new(get_method())); + if (!in_bio || !out_bio) { + throw std::runtime_error("Failed to create BIOs"); + } + if (1 != BIO_ctrl(in_bio.get(), BIO_C_SET_POINTER, 0, this)) { + throw make_ossl_error("Failed to set bio ptr to in bio"); + } + if (1 != BIO_ctrl(out_bio.get(), BIO_C_SET_POINTER, 0, this)) { + throw make_ossl_error("Failed to set bio ptr to out bio"); + } + // SSL_set_bio transfers ownership of the read and write bios to the SSL + // instance + SSL_set_bio(_ssl.get(), in_bio.release(), out_bio.release()); + + if (_type == session_type::SERVER) { + SSL_set_accept_state(_ssl.get()); + } else { + if (!_options.server_name.empty()) { + SSL_set_tlsext_host_name( + _ssl.get(), _options.server_name.c_str()); + } + SSL_set_connect_state(_ssl.get()); + } + + if (_type == session_type::CLIENT && !_options.session_resume_data.empty()) { + auto data_ptr = std::as_const(_options.session_resume_data).data(); + long data_size = _options.session_resume_data.size(); + auto sess = ssl_session_ptr(d2i_SSL_SESSION(nullptr, &data_ptr, data_size)); + if (!sess) { + throw make_ossl_error("Failed to decode SSL_SESSION data for session resumption"); + } + if (1 != SSL_set_session(_ssl.get(), sess.get())) { + throw make_ossl_error("Failed to set SSL_SESSION on SSL for session resumption"); + } + } + _options.session_resume_data.clear(); + } + + session(session_type t, shared_ptr creds, + connected_socket sock, + tls_options options = {}) + : session(t, std::move(creds), net::get_impl::get(std::move(sock)), options) {} + + ~session() { + assert(_output_pending.available()); + } + + friend std::ostream & operator<<(std::ostream &os, const session & session) { + fmt::print(os, "{}:{}:{} -", + session.get_type_string(), + session.local_address(), + session.remote_address()); + return os; + } + + const char * get_type_string() const { + return _type == session_type::CLIENT ? "Client": "Server"; + } + + // This function waits for the _output_pending future to resolve + // If an error occurs, it is saved off into _error and returned + future<> wait_for_output() { + tls_log.trace("{} wait_for_output", *this); + return std::exchange(_output_pending, make_ready_future()) + .handle_exception([this](auto ep) { + tls_log.debug("{} wait_for_output error: {}", *this, ep); + _error = ep; + return make_exception_future(ep); + }); + } + + template T> + future<> + handle_output_error(T err) { + _error = std::make_exception_ptr(err); + return wait_for_output().then_wrapped([this, err](auto f) { + try { + f.get(); + // output was ok/done, just generate error exception + return make_exception_future(_error); + } catch(...) { + std::throw_with_nested(err); + } + }); + } + + // Helper function for handling the SSL errors in do_put + future handle_do_put_ssl_err(const int ssl_err) { + switch(ssl_err) { + case SSL_ERROR_ZERO_RETURN: + // Indicates a hang up somewhere + // Mark _eof and stop iteratio + _eof = true; + return make_ready_future(stop_iteration::yes); + case SSL_ERROR_NONE: + // Should not have been reached in this situation + // Continue iteration + return make_ready_future(stop_iteration::no); + case SSL_ERROR_SYSCALL: + { + auto err = make_ossl_error("System error encountered during SSL write"); + return handle_output_error(std::move(err)).then([] { + return stop_iteration::yes; + }); + } + case SSL_ERROR_SSL: { + auto ec = ERR_GET_REASON(ERR_peek_error()); + if (ec == SSL_R_UNEXPECTED_EOF_WHILE_READING) { + // Probably shouldn't have during a write, but + // let's handle this gracefully + ERR_clear_error(); + _eof = true; + return make_ready_future(stop_iteration::yes); + } + auto err = make_ossl_error( + "Error occurred during SSL write"); + return handle_output_error(std::move(err)).then([] { + return stop_iteration::yes; + }); + } + default: + { + // Some other unhandled situation + auto err = std::runtime_error( + "Unknown error encountered during SSL write"); + return handle_output_error(std::move(err)).then([] { + return stop_iteration::yes; + }); + } + } + } + + // Called post locking of the _out_sem + // This function takes and holds the sempahore units for _out_sem and + // will attempt to send the provided packet. If a renegotiation is needed + // any unprocessed part of the packet is returned. + future do_put(net::packet p) { + tls_log.trace("{} do_put", *this); + if (!connected()) { + tls_log.debug("{} do_put: not connected", *this); + return make_ready_future(std::move(p)); + } + assert(_output_pending.available()); + return do_with(std::move(p), + [this](net::packet& p) { + // This do_until runs until either a renegotiation occurs or the packet is empty + return do_until( + [this, &p] { return eof() || !connected() || p.len() == 0;}, + [this, &p]() mutable { + std::string_view frag_view = + {p.fragments().begin()->base, p.fragments().begin()->size}; + return repeat([this, frag_view, &p]() mutable { + if (frag_view.empty()) { + return make_ready_future(stop_iteration::yes); + } + size_t bytes_written = 0; + auto write_rc = SSL_write_ex( + _ssl.get(), frag_view.data(), frag_view.size(), &bytes_written); + tls_log.trace("{} do_put: SSL_write_ex: {}", *this, write_rc); + if (write_rc != 1) { + const auto ssl_err = SSL_get_error(_ssl.get(), write_rc); + tls_log.trace("{} do_put: SSL_get_error: {}", *this, ssl_err); + if (ssl_err == SSL_ERROR_WANT_WRITE) { + return wait_for_output().then([] { + return stop_iteration::no; + }); + } else if (!connected() || ssl_err == SSL_ERROR_WANT_READ) { + ERR_clear_error(); + return make_ready_future(stop_iteration::yes); + } + return handle_do_put_ssl_err(ssl_err); + } else { + tls_log.trace("{} do_put: bytes_written: {}", *this, bytes_written); + frag_view.remove_prefix(bytes_written); + p.trim_front(bytes_written); + return wait_for_output().then([] { + return stop_iteration::no; + }); + } + }); + } + ).then([this, &p] { + tls_log.trace("{} do_put: returning packet of size: {}", *this, p.len()); + return std::move(p); + }); + } + ); + } + + // Used to push unencrypted data through OpenSSL, which will + // encrypt it and then place it into the output bio. + future<> put(net::packet p) override { + tls_log.trace("{} put", *this); + constexpr size_t openssl_max_record_size = 16 * 1024; + if (_error) { + return make_exception_future(_error); + } + if (_shutdown) { + return make_exception_future<>( + std::system_error(EPIPE, std::system_category())); + } + if (!connected()) { + tls_log.trace("{} put: not connected, performing handshake", *this); + return handshake().then( + [this, p = std::move(p)]() mutable { return put(std::move(p)); }); + } + + // We want to make sure that we write to the underlying bio with as large + // packets as possible. This is because eventually this translates to a + // sendmsg syscall. Further it results in larger TLS records which makes + // encryption/decryption faster. Hence to avoid cases where we would do + // an extra syscall for something like a 100 bytes header we linearize the + // packet if it's below the max TLS record size. + if (p.nr_frags() > 1 && p.len() <= openssl_max_record_size) { + p.linearize(); + } + return with_semaphore(_out_sem, 1, [this, p = std::move(p)]() mutable { + return do_put(std::move(p)); + }).then([this](net::packet p) { + if (eof() || p.len() == 0) { + tls_log.trace("{} put: eof: {}, p.len(): {}", *this, eof(), p.len()); + return make_ready_future(); + } else { + tls_log.trace("{} put: not completed packet sending, re-doing handshake", *this); + return handshake().then([this, p = std::move(p)]() mutable { + return put(std::move(p)); + }); + } + }); + } + + // Called after locking the _in_sem and _out_sem semaphores. + // This function will walk through the handshake with a remote peer + // If EOF is encountered, ENOTCONN is thrown + future<> do_handshake() { + tls_log.trace("{} do_handshake", *this); + if (eof()) { + tls_log.trace("{} do_handshake: eof encountered", *this); + // if we have experienced and eof, set the error and return + // GnuTLS will probably return GNUTLS_E_PREMATURE_TERMINATION + // from gnutls_handshake in this situation. + _error = std::make_exception_ptr(std::system_error( + ENOTCONN, + std::system_category(), + "EOF encountered during handshake")); + return make_exception_future(_error); + } else if (connected()) { + tls_log.trace("{} do_handshake: already connected", *this); + return make_ready_future<>(); + } + return do_until( + [this] { return connected() || eof(); }, + [this] { + try { + auto n = SSL_do_handshake(_ssl.get()); + tls_log.trace("{} do_handshake: SSL_do_handshake: {}", *this, n); + if (n <= 0) { + auto ssl_error = SSL_get_error(_ssl.get(), n); + tls_log.trace("{} do_handshake: SSL_get_error: {}", *this, ssl_error); + switch(ssl_error) { + case SSL_ERROR_NONE: + // probably shouldn't have gotten here + break; + case SSL_ERROR_ZERO_RETURN: + // peer has closed + _eof = true; + break; + case SSL_ERROR_WANT_WRITE: + return wait_for_output(); + case SSL_ERROR_WANT_READ: + return wait_for_output().then([this] { + return wait_for_input(); + }); + case SSL_ERROR_SYSCALL: + { + auto err = make_ossl_error("System error during handshake"); + return handle_output_error(std::move(err)); + } + case SSL_ERROR_SSL: + { + auto ec = ERR_GET_REASON(ERR_peek_error()); + tls_log.debug("{} do_handshake: ERR_GET_REASON: {}", *this, ec); + switch (ec) { + case SSL_R_UNEXPECTED_EOF_WHILE_READING: + // well in this situation, the remote end closed + ERR_clear_error(); + _eof = true; + return make_ready_future<>(); + case SSL_R_PEER_DID_NOT_RETURN_A_CERTIFICATE: + case SSL_R_CERTIFICATE_VERIFY_FAILED: + case SSL_R_NO_CERTIFICATES_RETURNED: + ERR_clear_error(); + verify(); + // may throw, otherwise fall through + [[fallthrough]]; + default: + auto err = make_ossl_error("Failed to establish SSL handshake"); + return handle_output_error(std::move(err)); + } + break; + } + default: + auto err = std::runtime_error( + "Unknown error encountered during handshake"); + return handle_output_error(std::move(err)); + } + } else { + if (_type == session_type::CLIENT + || _creds->get_client_auth() != client_auth::NONE) { + verify(); + } + return wait_for_output(); + } + } catch(...) { + return make_exception_future<>(std::current_exception()); + } + return make_ready_future<>(); + } + ); + } + + // This function will attempt to pull data off of the _in stream + // if there isn't already data needing to be processed first. + future<> wait_for_input() { + tls_log.trace("{} wait_for_input", *this); + // If we already have data, then it needs to be processed + if (!_input.empty()) { + tls_log.trace("{} wait_for_input: input not empty", *this); + return make_ready_future(); + } + return _in.get() + .then([this](buf_type buf) { + // Set EOF if it's empty + tls_log.debug("{} wait_for_input: buffer {}empty", *this, buf.empty() ? "is ": ""); + _eof |= buf.empty(); + _input = std::move(buf); + }) + .handle_exception([this](auto ep) { + tls_log.debug("{} wait_for_input: exception: {}", *this, ep); + _error = ep; + return make_exception_future(ep); + }); + } + + // Called after locking the _in_sem semaphore + // This function attempts to pull unencrypted data off of the + // SSL session using SSL_read. If ther eis no data, then + // we will call perform_pull and wait for data to arrive. + future do_get() { + tls_log.trace("{} do_get", *this); + // Data is available to be pulled of the SSL session if there is pending + // data on the SSL session or there is data in the in_bio() which SSL reads + // from + auto data_to_pull = (BIO_ctrl_pending(in_bio()) + SSL_pending(_ssl.get())) > 0; + auto f = make_ready_future<>(); + if (!data_to_pull) { + tls_log.trace("{} do_get: no data to pull, waiting for input", *this); + // If nothing is in the SSL buffers then we may have to wait for + // data to come in + f = wait_for_input(); + } + return f.then([this] { + if (eof()) { + return make_ready_future(); + } + auto avail = BIO_ctrl_pending(in_bio()) + SSL_pending(_ssl.get()); + tls_log.trace("{} do_get: available: {}", *this, avail); + buf_type buf(avail); + size_t bytes_read = 0; + auto read_result = SSL_read_ex( + _ssl.get(), buf.get_write(), avail, &bytes_read); + tls_log.trace("{} do_get: SSL_read_ex: {}", *this, read_result); + tls_log.trace("{} do_get: SSL_read_ex bytes_ready: {}", *this, bytes_read); + if (read_result != 1) { + const auto ssl_err = SSL_get_error(_ssl.get(), read_result); + tls_log.trace("{} do_get: SSL_get_error: {}", *this, ssl_err); + switch (ssl_err) { + case SSL_ERROR_ZERO_RETURN: + // Remote end has closed + _eof = true; + [[fallthrough]]; + case SSL_ERROR_NONE: + // well we shouldn't be here at all + return make_ready_future(); + case SSL_ERROR_WANT_WRITE: + return wait_for_output().then([this] { return do_get(); }); + case SSL_ERROR_WANT_READ: + // This may be caused by a renegotiation request, in this situation + // return an empty buffer (the get() function will initiate a handshake) + return make_ready_future(); + case SSL_ERROR_SYSCALL: + if (ERR_peek_error() == 0) { + // SSL_get_error + // (https://www.openssl.org/docs/man3.0/man3/SSL_get_error.html) + // states that on OpenSSL versions prior to 3.0, an + // SSL_ERROR_SYSCALL with nothing on the stack and errno + // == 0 indicates EOF but future versions should report + // SSL_ERROR_SSL with SSL_R_UNEXPECTED_EOF_WHILE_READING + // on the stack. However we are seeing situations on + // OpenSSL versions 3.0.9 and 3.0.14 where SSL_ERROR_SYSCALL + // is returned and errno == 0 and the stack is empty. + // We will treat this as EOF + _eof = true; + return make_ready_future(); + } + _error = std::make_exception_ptr( + make_ossl_error("System error during SSL read")); + return make_exception_future(_error); + case SSL_ERROR_SSL: + { + auto ec = ERR_GET_REASON(ERR_peek_error()); + if (ec == SSL_R_UNEXPECTED_EOF_WHILE_READING) { + // in this situation, the remote end hung up + ERR_clear_error(); + _eof = true; + return make_ready_future(); + } + _error = std::make_exception_ptr( + make_ossl_error( + "Failure during processing SSL read")); + return make_exception_future(_error); + } + default: + _error = std::make_exception_ptr(std::runtime_error( + "Unexpected error condition during SSL read")); + return make_exception_future(_error); + } + } else { + buf.trim(bytes_read); + return make_ready_future(std::move(buf)); + } + }); + } + + // Called by user applications to pull data off of the TLS session + future get() override { + tls_log.trace("{} get", *this); + if (_error) { + return make_exception_future(_error); + } + if (_shutdown || eof()) { + return make_ready_future(buf_type()); + } + if (!connected()) { + tls_log.trace("{} get: not connected, performing handshake", *this); + return handshake().then(std::bind(&session::get, this)); + } + return with_semaphore(_in_sem, 1, std::bind(&session::do_get, this)) + .then([this](buf_type buf) { + if (buf.empty() && !eof()) { + tls_log.trace("{} get: buffer empty and not eof, performing handshake", *this); + return handshake().then(std::bind(&session::get, this)); + } + tls_log.trace("{} get: returning buffer of size {}", *this, buf.size()); + return make_ready_future(std::move(buf)); + }); + } + + // Performs shutdown + future<> do_shutdown() { + tls_log.trace("{} do_shutdown", *this); + if (_error || !connected()) { + tls_log.trace("{} do_shutdown: error exists or not connected", *this); + return make_ready_future(); + } + + auto res = SSL_shutdown(_ssl.get()); + tls_log.trace("{} do_shutdown: SSL_shutdown: {}", *this, res); + if (res == 1) { + return wait_for_output(); + } else if (res == 0) { + return yield().then([this] { return do_shutdown(); }); + } else { + auto ssl_err = SSL_get_error(_ssl.get(), res); + tls_log.trace("{} do_shutdown: SSL_get_error: {}", *this, ssl_err); + switch (ssl_err) { + case SSL_ERROR_NONE: + // this is weird, yield and try again + return yield().then([this] { return do_shutdown(); }); + case SSL_ERROR_ZERO_RETURN: + // Looks like the other end is done, so let's just assume we're + // done as well + return wait_for_output(); + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return wait_for_output().then([this, ssl_err] { + // In neither case do we actually want to pull data off of the socket (yet) + // If we initiate the shutdown, then we just send the shutdown alert and wait + // for EOF (outside of this function) + if (ssl_err == SSL_ERROR_WANT_READ) { + return make_ready_future(); + } else { + return do_shutdown(); + } + }); + case SSL_ERROR_SYSCALL: + { + auto err = make_ossl_error("System error during shutdown"); + return handle_output_error(std::move(err)); + } + case SSL_ERROR_SSL: + { + if (ERR_GET_REASON(ERR_peek_error()) == SSL_R_APPLICATION_DATA_AFTER_CLOSE_NOTIFY) { + // This may have resulted in a race condition where we receive a packet immediately after + // sending out the close notify alert. In this situation, retry shutdown silently + ERR_clear_error(); + return yield().then([this] { return do_shutdown(); }); + } + auto err = make_ossl_error("Error occurred during SSL shutdown"); + return handle_output_error(std::move(err)); + } + default: + { + auto err = std::runtime_error( + "Unknown error occurred during SSL shutdown"); + return handle_output_error(std::move(err)); + } + } + } + } + + void verify() { + tls_log.trace("{} verify", *this); + if (!_creds->_enable_certificate_verification) { + tls_log.debug("{} verify: certificate verification disabled, skipping", *this); + return; + } + // A success return code (0) does not signify if a cert was presented or not, that + // must be explicitly queried via SSL_get_peer_certificate + auto res = SSL_get_verify_result(_ssl.get()); + tls_log.trace("{} verify: SSL_get_verify_result: {}", *this, res); + if (res != X509_V_OK) { + auto stat_cstr(X509_verify_cert_error_string(res)); + auto dn = extract_dn_information(); + if (dn) { + sstring stat_str{stat_cstr}; + boost::algorithm::trim(stat_str); + throw verification_error(fmt::format( + R"|({} (Issuer=["{}"], Subject=["{}"]))|", + stat_str, + dn->issuer, + dn->subject)); + } + throw verification_error(stat_cstr); + } else if (SSL_get0_peer_certificate(_ssl.get()) == nullptr) { + tls_log.trace("{} verify: No peer certificate", *this); + // If a peer certificate was not presented, + // SSL_get_verify_result will return X509_V_OK: + // https://www.openssl.org/docs/man3.0/man3/SSL_get_verify_result.html + if ( + _type == session_type::SERVER + && _creds->get_client_auth() == client_auth::REQUIRE) { + throw verification_error("no certificate presented by peer"); + } + return; + } + + if (_creds->_dn_callback) { + auto dn = extract_dn_information(); + assert(dn.has_value()); + _creds->_dn_callback( + _type, std::move(dn->subject), std::move(dn->issuer)); + } + } + + bool eof() const { + return _eof; + } + + bool connected() const { + return SSL_is_init_finished(_ssl.get()); + } + + // This function waits for eof() to occur on the input stream + // Unless wait_for_eof_on_shutdown is false + future<> wait_for_eof() { + tls_log.trace("{} wait_for_eof", *this); + if (!_options.wait_for_eof_on_shutdown) { + // Seastar option to allow users to just bypass EOF waiting + return make_ready_future(); + } + return with_semaphore(_in_sem, 1, [this] { + if (_error || !connected()) { + return make_ready_future(); + } + return do_until( + [this] { return eof(); }, + [this] { return do_get().discard_result(); }); + }).finally([this] { + tls_log.trace("{} wait_for_eof: complete", *this); + }); + } + + // This function is called to kick off the handshake. It will obtain + // locks on the _in_sem and _out_sem semaphores and start the handshake. + future<> handshake() { + tls_log.trace("{} handshake", *this); + if (_creds->need_load_system_trust()) { + if (!SSL_CTX_set_default_verify_paths(_ctx.get())) { + throw make_ossl_error( + "Could not load system trust"); + } + _creds->set_load_system_trust(false); + } + + return with_semaphore(_in_sem, 1, [this] { + return with_semaphore(_out_sem, 1, [this] { + return do_handshake().handle_exception([this](auto ep) { + if (!_error) { + _error = ep; + } + return make_exception_future<>(_error); + }); + }); + }); + } + + future<> shutdown() { + tls_log.trace("{} shutdown", *this); + // first, make sure any pending write is done. + // bye handshake is a flush operation, but this + // allows us to not pay extra attention to output state + // + // we only send a simple "bye" alert packet. Then we + // read from input until we see EOF. Any other reader + // before us will get it instead of us, and mark _eof = true + // in which case we will be no-op. This is performed all + // within do_shutdown + return with_semaphore(_out_sem, 1, + std::bind(&session::do_shutdown, this)).then( + std::bind(&session::wait_for_eof, this)).finally([me = shared_from_this()] {}); + // note moved finally clause above. It is theorethically possible + // that we could complete do_shutdown just before the close calls + // below, get pre-empted, have "close()" finish, get freed, and + // then call wait_for_eof on stale pointer. + } + + void close() noexcept override { + tls_log.trace("{} close", *this); + // only do once. + if (!std::exchange(_shutdown, true)) { + tls_log.trace("{} close: performing shutdown", *this); + // running in background. try to bye-handshake us nicely, but after 10s we forcefully close. + engine().run_in_background(with_timeout( + timer<>::clock::now() + std::chrono::seconds(10), shutdown()) + .finally([this] { + _eof = true; + return _in.close(); + }).finally([this] { + return _out.close(); + }).finally([this] { + // make sure to wait for handshake attempt to leave semaphores. Must be in same order as + // handshake aqcuire, because in worst case, we get here while a reader is attempting + // re-handshake. + return with_semaphore(_in_sem, 1, [this] { + return with_semaphore(_out_sem, 1, [] { }); + }); + }).handle_exception([me = shared_from_this()](std::exception_ptr){ + }).discard_result()); + } + } + // helper for sink + future<> flush() noexcept override { + return with_semaphore(_out_sem, 1, [this] { return _out.flush(); }); + } + + seastar::net::connected_socket_impl& socket() const override { + return *_sock; + } + + future> get_distinguished_name() override { + using result_t = std::optional; + if (_error) { + return make_exception_future(_error); + } + if (_shutdown) { + return make_exception_future( + std::system_error(ENOTCONN, std::system_category())); + } + if (!connected()) { + return handshake().then( + [this]() mutable { return get_distinguished_name(); }); + } + result_t dn = extract_dn_information(); + return make_ready_future(std::move(dn)); + } + + future> get_alt_name_information( + std::unordered_set types) override { + using result_t = std::vector; + + if (_error) { + return make_exception_future(_error); + } + if (_shutdown) { + return make_exception_future( + std::system_error(ENOTCONN, std::system_category())); + } + if (!connected()) { + return handshake().then([this, types = std::move(types)]() mutable { + return get_alt_name_information(std::move(types)); + }); + } + + const auto& peer_cert = get_peer_certificate(); + if (!peer_cert) { + return make_ready_future(); + } + return make_ready_future( + do_get_alt_name_information(peer_cert, types)); + } + + template + auto state_checked_access(Func&& f, Args&& ...args) { + using future_type = typename futurize>::type; + using result_t = typename future_type::value_type; + if (_error) { + return make_exception_future(_error); + } + if (_shutdown) { + return make_exception_future(std::system_error(ENOTCONN, std::system_category())); + } + if (!connected()) { + return handshake().then([this, f = std::move(f), ...args = std::forward(args)]() mutable { + return session::state_checked_access(std::move(f), std::forward(args)...); + }); + } + return futurize_invoke(f, std::forward(args)...); + } + + future is_resumed() override { + return state_checked_access([this] { + return SSL_session_reused(_ssl.get()) == 1; + }); + } + + future get_session_resume_data() override { + return state_checked_access([this] { + // get0 does not increment reference counter so no clean up necessary + auto sess = SSL_get0_session(_ssl.get()); + if (!sess || 0 == SSL_SESSION_is_resumable(sess)) { + return session_data{}; + } + auto len = i2d_SSL_SESSION(sess, nullptr); + if (len == 0) { + return session_data{}; + } + session_data data(len); + auto data_ptr = data.data(); + i2d_SSL_SESSION(sess, &data_ptr); + return data; + }); + } + + const sstring& local_address() const noexcept { + return _local_address; + } + + const sstring& remote_address() const noexcept { + return _remote_address; + } + +private: + std::vector do_get_alt_name_information(const x509_ptr &peer_cert, + const std::unordered_set &types) const { + int ext_idx = X509_get_ext_by_NID( + peer_cert.get(), NID_subject_alt_name, -1); + if (ext_idx < 0) { + return {}; + } + auto ext = X509_get_ext(peer_cert.get(), ext_idx); + if (!ext) { + return {}; + } + auto names = general_names_ptr(static_cast(X509V3_EXT_d2i(ext))); + if (!names) { + return {}; + } + int num_names = sk_GENERAL_NAME_num(names.get()); + std::vector alt_names; + alt_names.reserve(num_names); + + for (auto i = 0; i < num_names; i++) { + GENERAL_NAME* name = sk_GENERAL_NAME_value(names.get(), i); + if (auto known_t = field_to_san_type(name)) { + if (types.empty() || types.count(known_t->type)) { + alt_names.push_back(std::move(*known_t)); + } + } + } + return alt_names; + } + + std::optional field_to_san_type(GENERAL_NAME* name) const { + subject_alt_name san; + switch(name->type) { + case GEN_IPADD: + { + san.type = subject_alt_name_type::ipaddress; + const auto* data = ASN1_STRING_get0_data(name->d.iPAddress); + const auto size = ASN1_STRING_length(name->d.iPAddress); + if (size == sizeof(::in_addr)) { + ::in_addr addr; + memcpy(&addr, data, size); + san.value = net::inet_address(addr); + } else if (size == sizeof(::in6_addr)) { + ::in6_addr addr; + memcpy(&addr, data, size); + san.value = net::inet_address(addr); + } else { + throw std::runtime_error(fmt::format("Unexpected size: {} for ipaddress alt name value", size)); + } + break; + } + case GEN_EMAIL: + { + san.type = subject_alt_name_type::rfc822name; + san.value = asn1_str_to_str(name->d.rfc822Name); + break; + } + case GEN_URI: + { + san.type = subject_alt_name_type::uri; + san.value = asn1_str_to_str(name->d.uniformResourceIdentifier); + break; + } + case GEN_DNS: + { + san.type = subject_alt_name_type::dnsname; + san.value = asn1_str_to_str(name->d.dNSName); + break; + } + case GEN_OTHERNAME: + { + san.type = subject_alt_name_type::othername; + san.value = asn1_str_to_str(name->d.dNSName); + break; + } + case GEN_DIRNAME: + { + san.type = subject_alt_name_type::dn; + auto dirname = get_dn_string(name->d.directoryName); + if (!dirname) { + throw std::runtime_error("Expected non null value for SAN dirname"); + } + san.value = std::move(*dirname); + break; + } + default: + return std::nullopt; + } + return san; + } + + const x509_ptr& get_peer_certificate() const { + return _creds->get_last_cert(); + } + + std::optional extract_dn_information() const { + const auto& peer_cert = get_peer_certificate(); + if (!peer_cert) { + return std::nullopt; + } + auto subject = get_dn_string(X509_get_subject_name(peer_cert.get())); + auto issuer = get_dn_string(X509_get_issuer_name(peer_cert.get())); + if (!subject || !issuer) { + throw make_ossl_error( + "error while extracting certificate DN strings"); + } + return session_dn{ + .subject = std::move(*subject), .issuer = std::move(*issuer)}; + } + + ssl_ctx_ptr make_ssl_context(session_type type) { + auto ssl_ctx = ssl_ctx_ptr(SSL_CTX_new(TLS_method())); + if (!ssl_ctx) { + throw make_ossl_error( + "Failed to initialize SSL context"); + } + const auto& ck_pair = _creds->get_certkey_pair(); + if (type == session_type::SERVER) { + if (!ck_pair) { + throw make_ossl_error( + "Cannot start session without cert/key pair for server"); + } + switch (_creds->get_client_auth()) { + case client_auth::NONE: + default: + SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_NONE, nullptr); + break; + case client_auth::REQUEST: + SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_PEER, nullptr); + break; + case client_auth::REQUIRE: + SSL_CTX_set_verify( + ssl_ctx.get(), + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + nullptr); + break; + } + + auto options = SSL_OP_ALL | SSL_OP_ALLOW_CLIENT_RENEGOTIATION; + if (_creds->is_server_precedence_enabled()) { + options |= SSL_OP_CIPHER_SERVER_PREFERENCE; + } + + SSL_CTX_set_options(ssl_ctx.get(), options); + + switch(_creds->get_session_resume_mode()) { + case session_resume_mode::NONE: + SSL_CTX_set_session_cache_mode(ssl_ctx.get(), SSL_SESS_CACHE_OFF); + break; + case session_resume_mode::TLS13_SESSION_TICKET: + // By default, SSL contexts have server size cache enabled + if (1 != SSL_CTX_set_tlsext_ticket_key_evp_cb(ssl_ctx.get(), &session_ticket_cb)) { + throw make_ossl_error("Failed to set session ticket callback function"); + } + break; + } + } else { + if (_creds->is_server_precedence_enabled()) { + SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_CIPHER_SERVER_PREFERENCE); + } + } + + auto& min_tls_version = _creds->minimum_tls_version(); + auto& max_tls_version = _creds->maximum_tls_version(); + + if (min_tls_version.has_value()) { + if (!SSL_CTX_set_min_proto_version(ssl_ctx.get(), + tls_version_to_ossl(*min_tls_version))) { + throw make_ossl_error( + fmt::format("Failed to set minimum TLS version to {}", + *min_tls_version)); + } + } + + if (max_tls_version.has_value()) { + if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), + tls_version_to_ossl(*max_tls_version))) { + throw make_ossl_error( + fmt::format("Failed to set maximum TLS version to {}", + *max_tls_version)); + } + } + + auto get_security_level = [&ssl_ctx]() { + auto min_version = SSL_CTX_get_min_proto_version(ssl_ctx.get()); + // If 3.0.0 <= OpenSSL Version < 3.1.0 then: + // TLS1.0 is disabled at level 3 and TLS1.1 is disabled at level4 + // If 3.1.0 <= OpenSSL Version, then TLS1.0 and 1.2 are disabled at level 1 + #if OPENSSL_VERSION_NUMBER >= 0x30000000 && OPENSSL_VERSION_NUMBER < 0x30100000 + switch(min_version) { + case SSL3_VERSION: + case TLS1_VERSION: + return 2; + case TLS1_1VERSION: + case DTLS1_VERSION: + return 3; + default: + return 4; + } + #elif OPENSSL_VERSION_NUMBER >= 0x30100000 + switch(min_version) { + case SSL3_VERSION: + case TLS1_VERSION: + case TLS1_1_VERSION: + case DTLS1_VERSION: + return 0; + default: + return 1; + } + #else + #error "Unsupported OpenSSL Version" + #endif + }; + + SSL_CTX_set_security_level(ssl_ctx.get(), get_security_level()); + + // Servers must supply both certificate and key, clients may + // optionally use these + if (ck_pair) { + if (!SSL_CTX_use_cert_and_key( + ssl_ctx.get(), + ck_pair.cert.get(), + ck_pair.key.get(), + nullptr, + 1)) { + throw make_ossl_error( + "Failed to load cert/key pair"); + } + } + // Increments the reference count of *_creds, now should have a + // total ref count of two, will be deallocated when both OpenSSL and + // the certificate_manager call X509_STORE_free + SSL_CTX_set1_cert_store(ssl_ctx.get(), *_creds); + + if (!_creds->get_cipher_string().empty()) { + if (SSL_CTX_set_cipher_list(ssl_ctx.get(), + _creds->get_cipher_string().c_str()) != 1) { + throw make_ossl_error( + fmt::format( + "Failed to set cipher string '{}'", _creds->get_cipher_string())); + } + } + + if (!_creds->get_ciphersuites().empty()) { + if (SSL_CTX_set_ciphersuites(ssl_ctx.get(), _creds->get_ciphersuites().c_str()) != 1) { + throw make_ossl_error( + fmt::format( + "Failed to set ciphersuites '{}'", _creds->get_ciphersuites())); + } + } + + return ssl_ctx; + } + + static std::optional get_dn_string(X509_NAME* name) { + auto out = bio_ptr(BIO_new(BIO_s_mem())); + if (-1 == X509_NAME_print_ex(out.get(), name, 0, ASN1_STRFLGS_RFC2253 | XN_FLAG_SEP_COMMA_PLUS | + XN_FLAG_FN_SN | XN_FLAG_DUMP_UNKNOWN_FIELDS)) { + return std::nullopt; + } + char* bio_ptr = nullptr; + auto len = BIO_get_mem_data(out.get(), &bio_ptr); + if (len < 0) { + throw make_ossl_error("Failed to allocate DN string"); + } + return sstring(bio_ptr, len); + } + + size_t in_avail() const { return _input.size(); } + + BIO* in_bio() { return SSL_get_rbio(_ssl.get()); } + BIO* out_bio() { return SSL_get_wbio(_ssl.get()); } + +private: + std::unique_ptr _sock; + sstring _local_address; + sstring _remote_address; + shared_ptr _creds; + data_source _in; + data_sink _out; + std::exception_ptr _error; + + semaphore _in_sem; + semaphore _out_sem; + tls_options _options; + + future<> _output_pending; + buf_type _input; + ssl_ctx_ptr _ctx; + ssl_ptr _ssl; + session_type _type; + bool _eof = false; + bool _shutdown = false; + + friend int bio_write_ex(BIO* b, const char * data, size_t dlen, size_t * written); + friend int bio_read_ex(BIO* b, char * data, size_t dlen, size_t *readbytes); + friend long bio_ctrl(BIO * b, int ctrl, long num, void * data); + friend int session_ticket_cb(SSL*, unsigned char[16], unsigned char[EVP_MAX_IV_LENGTH], + EVP_CIPHER_CTX*, EVP_MAC_CTX*, int); +}; + +// The following callback function is used whenever session tickets are generated or received by +// the TLS server. If TLS session resumption is enabled, then an AES and HMAC key are +// generated and stored within the certificate_credentials (which is stored within the TLS session). +// The call back uses these keys to initialize the encryption and MAC operations for both encryption (enc = 1) +// and decryption (enc = 0). Because the key lives with the certificate_credentials which is passed +// to every instance of an SSL session, the same key can be used over and over again to encrypt/decrypt +// session tickets across multiple instances of server sessions. For more information see: +// https://docs.openssl.org/3.0/man3/SSL_CTX_set_tlsext_ticket_key_cb/ +int session_ticket_cb(SSL * s, unsigned char key_name[16], + unsigned char iv[EVP_MAX_IV_LENGTH], + EVP_CIPHER_CTX * ctx, EVP_MAC_CTX *hctx, int enc) { + auto * sess = static_cast(SSL_get_ex_data(s, SSL_EX_DATA_SESSION)); + std::span key_name_span(key_name, 16); + const auto & gen_key_name = sess->_creds->get_session_ticket_keys().key_name; + const auto & aes_key = sess->_creds->get_session_ticket_keys().aes_key; + auto hmac_key_ptr = sess->_creds->get_session_ticket_keys().hmac_key.data(); + auto hmac_key_size = sess->_creds->get_session_ticket_keys().hmac_key.size(); + OSSL_PARAM params[3]; + params[0] = OSSL_PARAM_construct_octet_string(OSSL_MAC_PARAM_KEY, + const_cast(hmac_key_ptr), + hmac_key_size); + params[1] = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST, + const_cast("sha256"), 0); + params[2] = OSSL_PARAM_construct_end(); + + if (enc) { + if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) <= 0) { + return -1; + } + + std::copy(gen_key_name.begin(), gen_key_name.end(), key_name_span.begin()); + + if (EVP_EncryptInit_ex2(ctx, EVP_aes_256_cbc(), aes_key.data(), iv, nullptr) == 0) { + return -1; + } + + if (EVP_MAC_CTX_set_params(hctx, params) == 0) { + return -1; + } + + return 1; + } else { + if (!std::equal(key_name_span.begin(), key_name_span.end(), gen_key_name.begin())) { + return 0; + } + if (EVP_MAC_CTX_set_params(hctx, params) == 0) { + return -1; + } + + if (EVP_DecryptInit_ex2(ctx, EVP_aes_256_cbc(), aes_key.data(), iv, nullptr) == 0) { + return -1; + } + return 1; + } +} + + +tls::session* unwrap_bio_ptr(void * ptr) { + return static_cast(ptr); +} + +tls::session* unwrap_bio_ptr(BIO * b) { + return unwrap_bio_ptr(BIO_get_data(b)); +} + +/// The 'ioctl' for BIO +long bio_ctrl(BIO * b, int ctrl, long num, void * data) { + if (BIO_get_init(b) <= 0 && ctrl != BIO_C_SET_POINTER) { + return 0; + } + + auto session = unwrap_bio_ptr(b); + + switch(ctrl) { + case BIO_C_SET_POINTER: + if (BIO_get_init(b) <= 0) { + BIO_set_data(b, data); + BIO_set_init(b, 1); + return 1; + } else { + return 0; + } + case BIO_CTRL_GET_CLOSE: + return BIO_get_shutdown(b); + case BIO_CTRL_SET_CLOSE: + BIO_set_shutdown(b, static_cast(num)); + break; + + case BIO_CTRL_DUP: + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_EOF: + return BIO_test_flags(b, BIO_FLAGS_IN_EOF) != 0; + case BIO_CTRL_PENDING: + return static_cast(session->_input.size()); + case BIO_CTRL_WPENDING: + return session->_output_pending.available() ? 0 : 1; + default: + return 0; + } + + return 0; +} + +/// This is called when the BIO is created +/// +/// It is important for this to be set, even if it doesn't do anything +/// because the BIO init flag won't get automatically set to '1'. +int bio_create(BIO*) { + return 1; +} + +/// Handles writes to the BIO +/// +/// This function will attempt to call _out.put() and store the future in +/// _output_pending. If _output_pending has not yet resolved, return '0' +/// and set the retry write flag. +int bio_write_ex(BIO* b, const char * data, size_t dlen, size_t * written) { + auto session = unwrap_bio_ptr(b); + tls_log.trace("{} bio_write_ex: dlen {}", *session, dlen); + BIO_clear_retry_flags(b); + + if (!session->_output_pending.available()) { + tls_log.trace("{} bio_write_ex: nothing pending in output", *session); + BIO_set_retry_write(b); + return 0; + } + + try { + size_t n; + + if (!session->_output_pending.failed()) { + scattered_message msg; + msg.append(std::string_view(data, dlen)); + n = msg.size(); + session->_output_pending = session->_out.put(std::move(msg).release()); + tls_log.trace("{} bio_write_ex: Appended {} bytes to output pending", *session, n); + } + + if (session->_output_pending.failed()) { + tls_log.debug("{} bio_write_ex: output pending has error", *session); + std::rethrow_exception(session->_output_pending.get_exception()); + } + + if (written != nullptr) { + *written = n; + } + + return 1; + } catch(const std::system_error & e) { + tls_log.debug("{} bio_write_ex: system error occurred: {}", *session, e.what()); + ERR_raise_data(ERR_LIB_SYS, e.code().value(), e.what()); + session->_output_pending = make_exception_future<>(std::current_exception()); + } catch(...) { + tls_log.debug("{} bio_write_ex: unknown error occurred", *session); + ERR_raise(ERR_LIB_SYS, EIO); + session->_output_pending = make_exception_future<>(std::current_exception()); + } + + return 0; +} + +/// Handles reading data from the BIO +/// +/// This will check to see if EOF has been reached and set the EOF +/// flag if EOF has been reached. It will set the retry read flag +/// if no data is available, otherwise it will copy data off of +/// the _input buffer and return it to the caller. +int bio_read_ex(BIO* b, char * data, size_t dlen, size_t *readbytes) { + auto session = unwrap_bio_ptr(b); + tls_log.trace("{} bio_read_ex: dlen: {}", *session, dlen); + BIO_clear_retry_flags(b); + if (session->eof()) { + tls_log.trace("{} bio_read_ex: eof", *session); + BIO_set_flags(b, BIO_FLAGS_IN_EOF); + return 0; + } + + if (session->_input.empty()) { + tls_log.trace("{} bio_read_ex: input empty", *session); + BIO_set_retry_read(b); + return 0; + } + + auto n = std::min(dlen, session->_input.size()); + memcpy(data, session->_input.get(), n); + session->_input.trim_front(n); + if (readbytes != nullptr) { + *readbytes = n; + } + + tls_log.trace("{} bio_read_ex: read {} bytes from input", *session, n); + return 1; +} + +/// This function creates the custom BIO method +bio_method_ptr create_bio_method() { + auto new_index = BIO_get_new_index(); + if (new_index == -1) { + throw make_ossl_error("Failed to obtain new BIO index"); + } + bio_method_ptr meth(BIO_meth_new(new_index, "SS-OSSL")); + if (!meth) { + throw make_ossl_error("Failed to create new BIO method"); + } + + if (1 != BIO_meth_set_create(meth.get(), bio_create)) { + throw make_ossl_error("Failed to set the BIO creation method"); + } + + if (1 != BIO_meth_set_ctrl(meth.get(), bio_ctrl)) { + throw make_ossl_error("Failed to set BIO control method"); + } + + if (1 != BIO_meth_set_write_ex(meth.get(), bio_write_ex)) { + throw make_ossl_error("Failed to set BIO write_ex method"); + } + + if (1 != BIO_meth_set_read_ex(meth.get(), bio_read_ex)) { + throw make_ossl_error("Failed to set BIO read_ex method"); + } + + return meth; +} + +} // namespace tls + +BIO_METHOD* get_method() { + static thread_local bio_method_ptr method_ptr = [] { + return tls::create_bio_method(); + }(); + + return method_ptr.get(); +} + +future tls::wrap_client(shared_ptr cred, connected_socket&& s, sstring name) { + tls_options options{.server_name = std::move(name)}; + return wrap_client(std::move(cred), std::move(s), std::move(options)); +} + +future tls::wrap_client(shared_ptr cred, connected_socket&& s, tls_options options) { + session_ref sess(seastar::make_shared(session_type::CLIENT, std::move(cred), std::move(s), options)); + connected_socket sock(std::make_unique(std::move(sess))); + return make_ready_future(std::move(sock)); +} + +future tls::wrap_server(shared_ptr cred, connected_socket&& s) { + session_ref sess(seastar::make_shared(session_type::SERVER, std::move(cred), std::move(s))); + connected_socket sock(std::make_unique(std::move(sess))); + return make_ready_future(std::move(sock)); +} + +} // namespace seastar + +const int seastar::tls::ERROR_UNKNOWN_COMPRESSION_ALGORITHM = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNSUPPORTED_COMPRESSION_ALGORITHM); +const int seastar::tls::ERROR_UNKNOWN_CIPHER_TYPE = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNKNOWN_CIPHER_TYPE); +const int seastar::tls::ERROR_INVALID_SESSION = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_INVALID_SESSION_ID); +const int seastar::tls::ERROR_UNEXPECTED_HANDSHAKE_PACKET = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNEXPECTED_RECORD); +const int seastar::tls::ERROR_UNKNOWN_CIPHER_SUITE = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNSUPPORTED_PROTOCOL); +const int seastar::tls::ERROR_UNKNOWN_ALGORITHM = ERR_PACK( + ERR_LIB_RSA, 0, RSA_R_UNKNOWN_ALGORITHM_TYPE); +const int seastar::tls::ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_NO_SUITABLE_SIGNATURE_ALGORITHM); +const int seastar::tls::ERROR_SAFE_RENEGOTIATION_FAILED = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_RENEGOTIATION_MISMATCH); +const int seastar::tls::ERROR_UNSAFE_RENEGOTIATION_DENIED = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNSAFE_LEGACY_RENEGOTIATION_DISABLED); +const int seastar::tls::ERROR_UNKNOWN_SRP_USERNAME = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_INVALID_SRP_USERNAME); +const int seastar::tls::ERROR_PREMATURE_TERMINATION = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNEXPECTED_EOF_WHILE_READING); +// System errors are not ERR_PACK'ed like other errors but instead +// are OR'ed with ((unsigned int)INT_MAX + 1) +const int seastar::tls::ERROR_PUSH = int(ERR_SYSTEM_FLAG | EPIPE); +const int seastar::tls::ERROR_PULL = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_READ_BIO_NOT_SET); +const int seastar::tls::ERROR_UNEXPECTED_PACKET = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNEXPECTED_MESSAGE); +const int seastar::tls::ERROR_UNSUPPORTED_VERSION = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_UNSUPPORTED_SSL_VERSION); +const int seastar::tls::ERROR_NO_CIPHER_SUITES = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_NO_CIPHERS_AVAILABLE); +const int seastar::tls::ERROR_DECRYPTION_FAILED = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_DECRYPTION_FAILED); +const int seastar::tls::ERROR_MAC_VERIFY_FAILED = ERR_PACK( + ERR_LIB_SSL, 0, SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC); diff --git a/src/net/tls-impl.cc b/src/net/tls-impl.cc new file mode 100644 index 0000000000..d167cc472e --- /dev/null +++ b/src/net/tls-impl.cc @@ -0,0 +1,764 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2015 Cloudius Systems + */ + +#ifdef SEASTAR_MODULE +module; +#endif + +#include +#include +#include + +#include + +#ifdef SEASTAR_USE_GNUTLS +#include +#endif + +#include +#include +#include + +#include +#include + +#ifdef SEASTAR_MODULE +module seastar; +#else +#include +#include +#include +#include +#include + +#include "net/tls-impl.hh" +#endif + +namespace seastar { + +logger tls::tls_log("seastar-tls"); + +std::unique_ptr net::get_impl::get(connected_socket s) { + return std::move(s._csi); +} + +net::connected_socket_impl* net::get_impl::maybe_get_ptr(connected_socket& s) { + if (s._csi) { + return s._csi.get(); + } + return nullptr; +} + +tls::session_ref::~session_ref() { + // This is not super pretty. But we take some care to only own sessions + // through session_ref, and we need to initiate shutdown on "last owner", + // since we cannot revive the session in destructor. + if (_session && _session.use_count() == 1) { + _session->close(); + } +} + +// Helper +struct file_info { + sstring filename; + std::chrono::system_clock::time_point modified; +}; + +struct file_result { + temporary_buffer buf; + file_info file; + operator temporary_buffer&&() && { + return std::move(buf); + } +}; + +static future read_fully(const sstring& name, const sstring& what) { + return open_file_dma(name, open_flags::ro).then([name = name](file f) mutable { + return do_with(std::move(f), [name = std::move(name)](file& f) mutable { + return f.stat().then([&f, name = std::move(name)](struct stat s) mutable { + return f.dma_read_bulk(0, s.st_size).then([s, name = std::move(name)](temporary_buffer buf) mutable { + return file_result{ std::move(buf), file_info{ + std::move(name), std::chrono::system_clock::from_time_t(s.st_mtim.tv_sec) + + std::chrono::duration_cast(std::chrono::nanoseconds(s.st_mtim.tv_nsec)) + } }; + }); + }).finally([&f]() { + return f.close(); + }); + }); + }).handle_exception([name = name, what = what](std::exception_ptr ep) -> future { + try { + std::rethrow_exception(std::move(ep)); + } catch (...) { + std::throw_with_nested(std::runtime_error(sstring("Could not read ") + what + " " + name)); + } + }); +} + +future tls::dh_params::from_file( + const sstring& filename, x509_crt_format fmt) { + return read_fully(filename, "dh parameters").then([fmt](temporary_buffer buf) { + return make_ready_future(dh_params(blob(buf.get()), fmt)); + }); +} + +future<> tls::abstract_credentials::set_x509_trust_file( + const sstring& cafile, x509_crt_format fmt) { + return read_fully(cafile, "trust file").then([this, fmt](temporary_buffer buf) { + set_x509_trust(blob(buf.get(), buf.size()), fmt); + }); +} + +future<> tls::abstract_credentials::set_x509_crl_file( + const sstring& crlfile, x509_crt_format fmt) { + return read_fully(crlfile, "crl file").then([this, fmt](temporary_buffer buf) { + set_x509_crl(blob(buf.get(), buf.size()), fmt); + }); +} + +future<> tls::abstract_credentials::set_x509_key_file( + const sstring& cf, const sstring& kf, x509_crt_format fmt) { + return read_fully(cf, "certificate file").then([this, fmt, kf = kf](temporary_buffer buf) { + return read_fully(kf, "key file").then([this, fmt, buf = std::move(buf)](temporary_buffer buf2) { + set_x509_key(blob(buf.get(), buf.size()), blob(buf2.get(), buf2.size()), fmt); + }); + }); +} + +future<> tls::abstract_credentials::set_simple_pkcs12_file( + const sstring& pkcs12file, x509_crt_format fmt, + const sstring& password) { + return read_fully(pkcs12file, "pkcs12 file").then([this, fmt, password = password](temporary_buffer buf) { + set_simple_pkcs12(blob(buf.get(), buf.size()), fmt, password); + }); +} + +static const sstring dh_level_key = "dh_level"; +static const sstring x509_trust_key = "x509_trust"; +static const sstring x509_crl_key = "x509_crl"; +static const sstring x509_key_key = "x509_key"; +static const sstring pkcs12_key = "pkcs12"; +static const sstring system_trust = "system_trust"; + +using buffer_type = std::basic_string>; + +struct x509_simple { + buffer_type data; + tls::x509_crt_format format; + file_info file; +}; + +struct x509_key { + buffer_type cert; + buffer_type key; + tls::x509_crt_format format; + file_info cert_file; + file_info key_file; +}; + +struct pkcs12_simple { + buffer_type data; + tls::x509_crt_format format; + sstring password; + file_info file; +}; + +void tls::credentials_builder::set_dh_level(dh_params::level level) { + _blobs.emplace(dh_level_key, level); +} + +void tls::credentials_builder::set_x509_trust(const blob& b, x509_crt_format fmt) { + _blobs.emplace(x509_trust_key, x509_simple{ std::string(b), fmt }); +} + +void tls::credentials_builder::set_x509_crl(const blob& b, x509_crt_format fmt) { + _blobs.emplace(x509_crl_key, x509_simple{ std::string(b), fmt }); +} + +void tls::credentials_builder::set_x509_key(const blob& cert, const blob& key, x509_crt_format fmt) { + _blobs.emplace(x509_key_key, x509_key { std::string(cert), std::string(key), fmt }); +} + +void tls::credentials_builder::set_simple_pkcs12(const blob& b, x509_crt_format fmt, const sstring& password) { + _blobs.emplace(pkcs12_key, pkcs12_simple{std::string(b), fmt, password }); +} + +static buffer_type to_buffer(const temporary_buffer& buf) { + return buffer_type(buf.get(), buf.get() + buf.size()); +} + +future<> tls::credentials_builder::set_x509_trust_file(const sstring& cafile, x509_crt_format fmt) { + return read_fully(cafile, "trust file").then([this, fmt](file_result f) { + _blobs.emplace(x509_trust_key, x509_simple{ to_buffer(f.buf), fmt, std::move(f.file) }); + }); +} + +future<> tls::credentials_builder::set_x509_crl_file(const sstring& crlfile, x509_crt_format fmt) { + return read_fully(crlfile, "crl file").then([this, fmt](file_result f) { + _blobs.emplace(x509_crl_key, x509_simple{ to_buffer(f.buf), fmt, std::move(f.file) }); + }); +} + +future<> tls::credentials_builder::set_x509_key_file(const sstring& cf, const sstring& kf, x509_crt_format fmt) { + return read_fully(cf, "certificate file").then([this, fmt, kf = kf](file_result cf) { + return read_fully(kf, "key file").then([this, fmt, cf = std::move(cf)](file_result kf) { + _blobs.emplace(x509_key_key, x509_key{ to_buffer(cf.buf), to_buffer(kf.buf), fmt, std::move(cf.file), std::move(kf.file) }); + }); + }); +} + +future<> tls::credentials_builder::set_simple_pkcs12_file(const sstring& pkcs12file, x509_crt_format fmt, const sstring& password) { + return read_fully(pkcs12file, "pkcs12 file").then([this, fmt, password = password](file_result f) { + _blobs.emplace(pkcs12_key, pkcs12_simple{ to_buffer(f.buf), fmt, password, std::move(f.file) }); + }); +} + +future<> tls::credentials_builder::set_system_trust() { + // TODO / Caveat: + // We cannot actually issue a loading of system trust here, + // because we have no actual tls context. + // And we probably _don't want to get into the guessing game + // of where the system trust cert chains are, since this is + // super distro dependent, and usually compiled into the library. + // Pretent it is raining, and just set a flag. + // Leave the function returning future, so if we change our + // minds and want to do explicit loading, we can... + _blobs.emplace(system_trust, true); + return make_ready_future(); +} + +void tls::credentials_builder::set_client_auth(client_auth auth) { + _client_auth = auth; +} + +void tls::credentials_builder::set_session_resume_mode(session_resume_mode m) { + _session_resume_mode = m; +} + +template +static void visit_blobs(Blobs& blobs, Visitor&& visitor) { + auto visit = [&](const sstring& key, auto* vt) { + auto tr = blobs.equal_range(key); + for (auto& p : boost::make_iterator_range(tr.first, tr.second)) { + auto* v = std::any_cast>(&p.second); + visitor(key, *v); + } + }; + visit(x509_trust_key, static_cast(nullptr)); + visit(x509_crl_key, static_cast(nullptr)); + visit(x509_key_key, static_cast(nullptr)); + visit(pkcs12_key, static_cast(nullptr)); +} + +void tls::credentials_builder::apply_to(certificate_credentials& creds) const { + // Could potentially be templated down, but why bother... + visit_blobs(_blobs, make_visitor( + [&](const sstring& key, const x509_simple& info) { + if (key == x509_trust_key) { + creds.set_x509_trust(info.data, info.format); + } else if (key == x509_crl_key) { + creds.set_x509_crl(info.data, info.format); + } + }, + [&](const sstring&, const x509_key& info) { + creds.set_x509_key(info.cert, info.key, info.format); + }, + [&](const sstring&, const pkcs12_simple& info) { + creds.set_simple_pkcs12(info.data, info.format, info.password); + } + )); + + // TODO / Caveat: + // We cannot do this immediately, because we are not a continuation, and + // potentially blocking calls are a no-no. + // Doing this detached would be indeterministic, so set a flag in + // credentials, and do actual loading in first handshake (see session) + if (_blobs.count(system_trust)) { + creds.enable_load_system_trust(); + } + +#ifdef SEASTAR_USE_GNUTLS + if (!_priority.empty()) { + creds.set_priority_string(_priority); + } +#endif + +#ifdef SEASTAR_USE_OPENSSL + if (!_cipher_string.empty()) { + creds.set_cipher_string(_cipher_string); + } + + if (!_ciphersuites.empty()) { + creds.set_ciphersuites(_ciphersuites); + } + + if (_enable_server_precedence) { + creds.enable_server_precedence(); + } + + if (_min_tls_version.has_value()) { + creds.set_minimum_tls_version(*_min_tls_version); + } + + if (_max_tls_version.has_value()) { + creds.set_maximum_tls_version(*_max_tls_version); + } +#endif + + creds.set_client_auth(_client_auth); + creds.set_session_resume_mode(_session_resume_mode); +} + +shared_ptr tls::credentials_builder::build_certificate_credentials() const { + auto creds = make_shared(); + apply_to(*creds); + return creds; +} + +shared_ptr tls::credentials_builder::build_server_credentials() const { + auto i = _blobs.find(dh_level_key); + if (i == _blobs.end()) { +#if GNUTLS_VERSION_NUMBER < 0x030600 && SEASTAR_USE_GNUTLS + throw std::invalid_argument("No DH level set"); +#else + auto creds = make_shared(); + apply_to(*creds); + return creds; +#endif + } + auto creds = make_shared(dh_params(std::any_cast(i->second))); + apply_to(*creds); + return creds; +} + +using namespace std::chrono_literals; + +class tls::reloadable_credentials_base { +public: + using delay_type = std::chrono::milliseconds; + static inline constexpr delay_type default_tolerance = 500ms; + + class reloading_builder + : public credentials_builder + , public enable_shared_from_this + { + public: + using time_point = std::chrono::system_clock::time_point; + + reloading_builder(credentials_builder b, reload_callback cb, reloadable_credentials_base* creds, delay_type delay) + : credentials_builder(std::move(b)) + , _cb(std::move(cb)) + , _creds(creds) + , _delay(delay) + {} + future<> init() { + std::vector> futures; + visit_blobs(_blobs, make_visitor( + [&](const sstring&, const x509_simple& info) { + _all_files.emplace(info.file.filename); + }, + [&](const sstring&, const x509_key& info) { + _all_files.emplace(info.cert_file.filename); + _all_files.emplace(info.key_file.filename); + }, + [&](const sstring&, const pkcs12_simple& info) { + _all_files.emplace(info.file.filename); + } + )); + return parallel_for_each(_all_files, [this](auto& f) { + if (!f.empty()) { + return add_watch(f).discard_result(); + } + return make_ready_future<>(); + }).finally([me = shared_from_this()] {}); + } + void start() { + // run the loop in a thread. makes code almost readable. + (void)async(std::bind(&reloading_builder::run, this)).finally([me = shared_from_this()] {}); + } + void run() { + while (_creds) { + try { + auto events = _fsn.wait().get(); + if (events.empty() && _creds == nullptr) { + return; + } + rebuild(events); + _timer.cancel(); + } catch (...) { + if (!_timer.armed()) { + _timer.set_callback([this, ep = std::current_exception()]() mutable { + do_callback(std::move(ep)); + }); + _timer.arm(_delay); + } + } + } + } + void detach() { + _creds = nullptr; + _cb = {}; + _fsn.shutdown(); + _timer.cancel(); + } + + private: + using fsnotifier = experimental::fsnotifier; + + // called from seastar::thread + void rebuild(const std::vector& events) { + for (auto& e : events) { + // don't use at. We could be getting two events for + // same watch (mod + delete), but we only need to care + // about one... + auto i = _watches.find(e.id); + if (i != _watches.end()) { + auto& filename = i->second.second; + // only add actual file watches to + // query set. If this was a directory + // watch, the file should already be + // in there. + if (_all_files.count(filename)) { + _files[filename] = e.mask; + } + _watches.erase(i); + } + } + auto num_changed = 0; + + auto maybe_reload = [&](const sstring& filename, buffer_type& dst) { + if (filename.empty() || !_files.count(filename)) { + return; + } + // #756 + // first, add a watch to nearest parent dir we + // can find. If user deleted folders, we could end + // up looking at modifications to root. + // The idea is that should adding a watch to actual file + // fail (deleted file/folder), we wait for changes to closest + // parent. When this happens, we will retry all files + // that have not been successfully replaced (and maybe more), + // repeating the process. At some point, we hopefully + // get new, current data. + add_dir_watch(filename); + // #756 add watch _first_. File could change while we are + // reading this. + try { + add_watch(filename).get(); + } catch (...) { + // let's just assume if this happens, it's because the file or folder was deleted. + // just ignore for now, and hope the dir watch will tell us when it is back... + return; + } + temporary_buffer buf = read_fully(filename, "reloading").get(); + dst = to_buffer(buf); + ++num_changed; + }; + visit_blobs(_blobs, make_visitor( + [&](const sstring&, x509_simple& info) { + maybe_reload(info.file.filename, info.data); + }, + [&](const sstring&, x509_key& info) { + maybe_reload(info.cert_file.filename, info.cert); + maybe_reload(info.key_file.filename, info.key); + }, + [&](const sstring&, pkcs12_simple& info) { + maybe_reload(info.file.filename, info.data); + } + )); + // only try this if anything was in fact successfully loaded. + // if files were missing, or pairs incomplete, we can just skip. + if (num_changed == 0) { + return; + } + try { + if (_creds) { + _creds->rebuild(*this); + } + } catch (...) { + if (std::any_of(_files.begin(), _files.end(), [](auto& p) { return p.second == fsnotifier::flags::ignored; })) { + // if any file in the reload set was deleted - i.e. we have not seen a "closed" yet - assume + // this is a spurious reload and we'd better wait for next event - hopefully a "closed" - + // and try again + return; + } + throw; + } + // if we got here, all files loaded, all watches were created, + // and gnutls was ok with the content. success. + do_callback(); + on_success(); + } + void on_success() { + _files.clear(); + // remove all directory watches, since we've successfully + // reloaded -> the file watches themselves should suffice now + auto i = _watches.begin(); + auto e = _watches.end(); + while (i != e) { + if (!_all_files.count(i->second.second)) { + i = _watches.erase(i); + continue; + } + ++i; + } + } + void do_callback(std::exception_ptr ep = {}) { + if (_cb && !_files.empty()) { + _cb(boost::copy_range>(_files | boost::adaptors::map_keys), std::move(ep)); + } + } + // called from seastar::thread + fsnotifier::watch_token add_dir_watch(const sstring& filename) { + auto dir = std::filesystem::path(filename).parent_path(); + for (;;) { + try { + return add_watch(dir.native(), fsnotifier::flags::create_child | fsnotifier::flags::move).get(); + } catch (...) { + auto parent = dir.parent_path(); + if (parent.empty() || dir == parent) { + throw; + } + dir = std::move(parent); + continue; + } + } + } + future add_watch(const sstring& filename, fsnotifier::flags flags = fsnotifier::flags::close_write|fsnotifier::flags::delete_self) { + return _fsn.create_watch(filename, flags).then([this, filename = filename](fsnotifier::watch w) { + auto t = w.token(); + // we might create multiple watches for same token in case of dirs, avoid deleting previously + // created one + if (_watches.count(t)) { + w.release(); + } else { + _watches.emplace(t, std::make_pair(std::move(w), filename)); + } + return t; + }); + } + + reload_callback _cb; + reloadable_credentials_base* _creds; + fsnotifier _fsn; + std::unordered_map> _watches; + std::unordered_map _files; + std::unordered_set _all_files; + timer<> _timer; + delay_type _delay; + }; + reloadable_credentials_base(credentials_builder builder, reload_callback cb, delay_type delay = default_tolerance) + : _builder(seastar::make_shared(std::move(builder), std::move(cb), this, delay)) + { + _builder->start(); + } + future<> init() { + return _builder->init(); + } + virtual ~reloadable_credentials_base() { + _builder->detach(); + } + virtual void rebuild(const credentials_builder&) = 0; + virtual const tls::certificate_credentials& as_certificate_credentials() const noexcept = 0; +private: + shared_ptr _builder; +}; + +template +class tls::reloadable_credentials : public Base, public tls::reloadable_credentials_base { +public: + reloadable_credentials(credentials_builder builder, reload_callback cb, Base b, delay_type delay = default_tolerance) + : Base(std::move(b)) + , tls::reloadable_credentials_base(std::move(builder), std::move(cb), delay) + {} + void rebuild(const credentials_builder&) override; + const tls::certificate_credentials& as_certificate_credentials() const noexcept override; + +}; + +template<> +void tls::reloadable_credentials::rebuild(const credentials_builder& builder) { + auto tmp = builder.build_certificate_credentials(); + this->_impl = std::move(tmp->_impl); +} + +template <> +const tls::certificate_credentials& tls::reloadable_credentials::as_certificate_credentials() const noexcept { + return *this; +} + +template<> +void tls::reloadable_credentials::rebuild(const credentials_builder& builder) { + auto tmp = builder.build_server_credentials(); + this->_impl = std::move(tmp->_impl); +} + +template <> +const tls::certificate_credentials& tls::reloadable_credentials::as_certificate_credentials() const noexcept{ + return *this; +} + +future> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional tolerance) const { + auto creds = seastar::make_shared>(*this, std::move(cb), std::move(*build_certificate_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance)); + return creds->init().then([creds] { + return make_ready_future>(creds); + }); +} + + +future> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional tolerance) const { + auto creds = seastar::make_shared>(*this, std::move(cb), std::move(*build_server_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance)); + return creds->init().then([creds] { + return make_ready_future>(creds); + }); +} + +data_source tls::tls_connected_socket_impl::source() { + return data_source(std::make_unique(_session)); +} + +data_sink tls::tls_connected_socket_impl::sink() { + return data_sink(std::make_unique(_session)); +} + +future tls::connect(shared_ptr cred, socket_address sa, sstring name) { + tls_options options{.server_name = std::move(name)}; + return connect(std::move(cred), std::move(sa), std::move(options)); +} + +future tls::connect(shared_ptr cred, socket_address sa, socket_address local, sstring name) { + tls_options options{.server_name = std::move(name)}; + return connect(std::move(cred), std::move(sa), std::move(local), std::move(options)); +} + +future tls::connect(shared_ptr cred, socket_address sa, tls_options options) { + return engine().connect(sa).then([cred = std::move(cred), options = std::move(options)](connected_socket s) mutable { + return wrap_client(std::move(cred), std::move(s), std::move(options)); + }); +} + +future tls::connect(shared_ptr cred, socket_address sa, socket_address local, tls_options options) { + return engine().connect(sa, local).then([cred = std::move(cred), options = std::move(options)](connected_socket s) mutable { + return wrap_client(std::move(cred), std::move(s), std::move(options)); + }); +} + +socket tls::socket(shared_ptr cred, sstring name) { + tls_options options{.server_name = std::move(name)}; + return tls::socket(std::move(cred), std::move(options)); +} + +socket tls::socket(shared_ptr cred, tls_options options) { + return ::seastar::socket(std::make_unique(std::move(cred), std::move(options))); +} + +server_socket tls::listen(shared_ptr creds, socket_address sa, listen_options opts) { + return listen(std::move(creds), seastar::listen(sa, opts)); +} + +server_socket tls::listen(shared_ptr creds, server_socket ss) { + server_socket ssls(std::make_unique(creds, std::move(ss))); + return server_socket(std::move(ssls)); +} + +static tls::tls_connected_socket_impl* get_tls_socket(connected_socket& socket) { + auto impl = net::get_impl::maybe_get_ptr(socket); + if (impl == nullptr) { + // the socket is not yet created or moved from + throw std::system_error(ENOTCONN, std::system_category()); + } + auto tls_impl = dynamic_cast(impl); + if (!tls_impl) { + // bad cast here means that we're dealing with wrong socket type + throw std::invalid_argument("Not a TLS socket"); + } + return tls_impl; +} + +future> tls::get_dn_information(connected_socket& socket) { + return get_tls_socket(socket)->get_distinguished_name(); +} + +future> tls::get_alt_name_information(connected_socket& socket, std::unordered_set types) { + return get_tls_socket(socket)->get_alt_name_information(std::move(types)); +} + +future tls::check_session_is_resumed(connected_socket& socket) { + return get_tls_socket(socket)->check_session_is_resumed(); +} + +future tls::get_session_resume_data(connected_socket& socket) { + return get_tls_socket(socket)->get_session_resume_data(); +} + +std::string_view tls::format_as(subject_alt_name_type type) { + switch (type) { + case subject_alt_name_type::dnsname: + return "DNS"; + case subject_alt_name_type::rfc822name: + return "EMAIL"; + case subject_alt_name_type::uri: + return "URI"; + case subject_alt_name_type::ipaddress: + return "IP"; + case subject_alt_name_type::othername: + return "OTHERNAME"; + case subject_alt_name_type::dn: + return "DIRNAME"; + default: + return "UNKNOWN"; + } +} + +std::string_view tls::format_as(tls_version version) { + switch(version) { + case tls::tls_version::tlsv1_0: + return "TLSv1.0"; + case tls::tls_version::tlsv1_1: + return "TLSv1.1"; + case tls::tls_version::tlsv1_2: + return "TLSv1.2"; + case tls::tls_version::tlsv1_3: + return "TLSv1.3"; + } + + __builtin_unreachable(); +} + +std::ostream& tls::operator<<(std::ostream& os, const tls_version & version) { + return os << format_as(version); +} + +std::ostream& tls::operator<<(std::ostream& os, subject_alt_name_type type) { + return os << format_as(type); +} + +std::ostream& tls::operator<<(std::ostream& os, const subject_alt_name::value_type& v) { + fmt::print(os, "{}", v); + return os; +} + +std::ostream& tls::operator<<(std::ostream& os, const subject_alt_name& a) { + fmt::print(os, "{}", a); + return os; +} + +} diff --git a/src/net/tls-impl.hh b/src/net/tls-impl.hh new file mode 100644 index 0000000000..e0ec99c4c3 --- /dev/null +++ b/src/net/tls-impl.hh @@ -0,0 +1,224 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2015 Cloudius Systems + */ +#pragma once + +#include +#include +#include +#include +#include +namespace seastar { + +class net::get_impl { +public: + static std::unique_ptr get(connected_socket s); + + static connected_socket_impl* maybe_get_ptr(connected_socket& s); +}; + +namespace tls { + +extern logger tls_log; + +class session_impl { +public: + virtual future<> put(net::packet) = 0; + virtual future<> flush() noexcept = 0; + virtual future> get() = 0; + virtual void close() = 0; + virtual future> get_distinguished_name() = 0; + virtual seastar::net::connected_socket_impl & socket() const = 0; + virtual future> get_alt_name_information(std::unordered_set) = 0; + virtual future is_resumed() = 0; + virtual future get_session_resume_data() = 0; +}; + +struct session_ref { + session_ref() = default; + session_ref(shared_ptr session) + : _session(std::move(session)) { + } + session_ref(session_ref&&) = default; + session_ref(const session_ref&) = default; + ~session_ref(); + + session_ref& operator=(session_ref&&) = default; + session_ref& operator=(const session_ref&) = default; + + shared_ptr _session; +}; + +class tls_connected_socket_impl : public net::connected_socket_impl, public session_ref { +public: + tls_connected_socket_impl(session_ref&& sess) + : session_ref(std::move(sess)) + {} + + class source_impl; + class sink_impl; + + using net::connected_socket_impl::source; + data_source source() override; + data_sink sink() override; + + void shutdown_input() override { + _session->close(); + } + void shutdown_output() override { + _session->close(); + } + void set_nodelay(bool nodelay) override { + _session->socket().set_nodelay(nodelay); + } + bool get_nodelay() const override { + return _session->socket().get_nodelay(); + } + void set_keepalive(bool keepalive) override { + _session->socket().set_keepalive(keepalive); + } + bool get_keepalive() const override { + return _session->socket().get_keepalive(); + } + void set_keepalive_parameters(const net::keepalive_params& p) override { + _session->socket().set_keepalive_parameters(p); + } + net::keepalive_params get_keepalive_parameters() const override { + return _session->socket().get_keepalive_parameters(); + } + void set_sockopt(int level, int optname, const void* data, size_t len) override { + _session->socket().set_sockopt(level, optname, data, len); + } + int get_sockopt(int level, int optname, void* data, size_t len) const override { + return _session->socket().get_sockopt(level, optname, data, len); + } + socket_address local_address() const noexcept override { + return _session->socket().local_address(); + } + socket_address remote_address() const noexcept override { + return _session->socket().remote_address(); + } + future> get_distinguished_name() { + return _session->get_distinguished_name(); + } + future> get_alt_name_information(std::unordered_set types) { + return _session->get_alt_name_information(std::move(types)); + } + future<> wait_input_shutdown() override { + return _session->socket().wait_input_shutdown(); + } + future check_session_is_resumed() { + return _session->is_resumed(); + } + future get_session_resume_data() { + return _session->get_session_resume_data(); + } +}; + + +class tls_connected_socket_impl::source_impl: public data_source_impl, public session_ref { +public: + using session_ref::session_ref; +private: + future> get() override { + return _session->get(); + } + future<> close() override { + _session->close(); + return make_ready_future<>(); + } +}; + +// Note: source/sink, and by extension, the in/out streams +// produced, cannot exist outside the direct life span of +// the connected_socket itself. This is consistent with +// other sockets in seastar, though I am than less fond of it... +class tls_connected_socket_impl::sink_impl: public data_sink_impl, public session_ref { +public: + using session_ref::session_ref; +private: + future<> flush() override { + return _session->flush(); + } + using data_sink_impl::put; + future<> put(net::packet p) override { + return _session->put(std::move(p)); + } + future<> close() override { + _session->close(); + return make_ready_future<>(); + } +}; + +class server_session : public net::server_socket_impl { +public: + server_session(shared_ptr creds, server_socket sock) + : _creds(std::move(creds)), _sock(std::move(sock)) { + } + future accept() override { + // We're not actually doing anything very SSL until we get + // an actual connection. Then we create a "server" session + // and wrap it up after handshaking. + return _sock.accept().then([this](accept_result ar) { + return wrap_server(_creds, std::move(ar.connection)).then([addr = std::move(ar.remote_address)](connected_socket s) { + return make_ready_future(accept_result{std::move(s), addr}); + }); + }); + } + void abort_accept() override { + _sock.abort_accept(); + } + socket_address local_address() const override { + return _sock.local_address(); + } +private: + + shared_ptr _creds; + server_socket _sock; +}; + +class tls_socket_impl : public net::socket_impl { + shared_ptr _cred; + tls_options _options; + ::seastar::socket _socket; +public: + tls_socket_impl(shared_ptr cred, tls_options options) + : _cred(cred), _options(std::move(options)), _socket(make_socket()) { + } + virtual future connect(socket_address sa, socket_address local, transport proto = transport::TCP) override { + return _socket.connect(sa, local, proto).then([cred = std::move(_cred), options = std::move(_options)](connected_socket s) mutable { + return wrap_client(cred, std::move(s), std::move(options)); + }); + } + void set_reuseaddr(bool reuseaddr) override { + _socket.set_reuseaddr(reuseaddr); + } + bool get_reuseaddr() const override { + return _socket.get_reuseaddr(); + } + virtual void shutdown() override { + _socket.shutdown(); + } +}; + +} // namespace tls + + +} diff --git a/src/net/tls.cc b/src/net/tls.cc index f1e3410d53..18e6c11207 100644 --- a/src/net/tls.cc +++ b/src/net/tls.cc @@ -58,27 +58,15 @@ module seastar; #include #include #include +#include #include #include #include +#include "net/tls-impl.hh" #endif namespace seastar { -class net::get_impl { -public: - static std::unique_ptr get(connected_socket s) { - return std::move(s._csi); - } - - static connected_socket_impl* maybe_get_ptr(connected_socket& s) { - if (s._csi) { - return s._csi.get(); - } - return nullptr; - } -}; - class blob_wrapper: public gnutls_datum_t { public: blob_wrapper(const tls::blob& in) @@ -110,43 +98,6 @@ class gnutlsobj { } }; -// Helper -struct file_info { - sstring filename; - std::chrono::system_clock::time_point modified; -}; - -struct file_result { - temporary_buffer buf; - file_info file; - operator temporary_buffer&&() && { - return std::move(buf); - } -}; - -static future read_fully(const sstring& name, const sstring& what) { - return open_file_dma(name, open_flags::ro).then([name = name](file f) mutable { - return do_with(std::move(f), [name = std::move(name)](file& f) mutable { - return f.stat().then([&f, name = std::move(name)](struct stat s) mutable { - return f.dma_read_bulk(0, s.st_size).then([s, name = std::move(name)](temporary_buffer buf) mutable { - return file_result{ std::move(buf), file_info{ - std::move(name), std::chrono::system_clock::from_time_t(s.st_mtim.tv_sec) + - std::chrono::duration_cast(std::chrono::nanoseconds(s.st_mtim.tv_nsec)) - } }; - }); - }).finally([&f]() { - return f.close(); - }); - }); - }).handle_exception([name = name, what = what](std::exception_ptr ep) -> future { - try { - std::rethrow_exception(std::move(ep)); - } catch (...) { - std::throw_with_nested(std::runtime_error(sstring("Could not read ") + what + " " + name)); - } - }); -} - // Note: we are not using gnutls++ interfaces, mainly because we // want to keep _our_ interface reasonably non-gnutls (well...) // and once we get to this level, their abstractions don't help @@ -200,6 +151,10 @@ static auto get_gtls_string = [](auto func, auto... args) noexcept { } +void tls::credentials_builder::set_priority_string(const sstring& prio) { + _priority = prio; +} + class tls::dh_params::impl : gnutlsobj { static gnutls_sec_param_t to_gnutls_level(level l) { switch (l) { @@ -284,13 +239,6 @@ tls::dh_params::~dh_params() { tls::dh_params::dh_params(dh_params&&) noexcept = default; tls::dh_params& tls::dh_params::operator=(dh_params&&) noexcept = default; -future tls::dh_params::from_file( - const sstring& filename, x509_crt_format fmt) { - return read_fully(filename, "dh parameters").then([fmt](temporary_buffer buf) { - return make_ready_future(dh_params(blob(buf.get()), fmt)); - }); -} - class tls::x509_cert::impl : gnutlsobj { public: impl() @@ -327,13 +275,6 @@ tls::x509_cert::x509_cert(const blob& b, x509_crt_format fmt) : x509_cert(::seastar::make_shared(b, fmt)) { } -future tls::x509_cert::from_file( - const sstring& filename, x509_crt_format fmt) { - return read_fully(filename, "x509 certificate").then([fmt](temporary_buffer buf) { - return make_ready_future(x509_cert(blob(buf.get()), fmt)); - }); -} - // wrapper for gnutls_datum, with raii free struct gnutls_datum : public gnutls_datum_t { gnutls_datum() { @@ -470,7 +411,7 @@ class tls::certificate_credentials::impl: public gnutlsobj { } private: - friend class credentials_builder; + friend class certificate_credentials; friend class session; bool need_load_system_trust() const { @@ -529,37 +470,6 @@ void tls::certificate_credentials::set_simple_pkcs12(const blob& b, _impl->set_simple_pkcs12(b, fmt, password); } -future<> tls::abstract_credentials::set_x509_trust_file( - const sstring& cafile, x509_crt_format fmt) { - return read_fully(cafile, "trust file").then([this, fmt](temporary_buffer buf) { - set_x509_trust(blob(buf.get(), buf.size()), fmt); - }); -} - -future<> tls::abstract_credentials::set_x509_crl_file( - const sstring& crlfile, x509_crt_format fmt) { - return read_fully(crlfile, "crl file").then([this, fmt](temporary_buffer buf) { - set_x509_crl(blob(buf.get(), buf.size()), fmt); - }); -} - -future<> tls::abstract_credentials::set_x509_key_file( - const sstring& cf, const sstring& kf, x509_crt_format fmt) { - return read_fully(cf, "certificate file").then([this, fmt, kf = kf](temporary_buffer buf) { - return read_fully(kf, "key file").then([this, fmt, buf = std::move(buf)](temporary_buffer buf2) { - set_x509_key(blob(buf.get(), buf.size()), blob(buf2.get(), buf2.size()), fmt); - }); - }); -} - -future<> tls::abstract_credentials::set_simple_pkcs12_file( - const sstring& pkcs12file, x509_crt_format fmt, - const sstring& password) { - return read_fully(pkcs12file, "pkcs12 file").then([this, fmt, password = password](temporary_buffer buf) { - set_simple_pkcs12(blob(buf.get(), buf.size()), fmt, password); - }); -} - future<> tls::certificate_credentials::set_system_trust() { return _impl->set_system_trust(); } @@ -576,6 +486,18 @@ void tls::certificate_credentials::set_enable_certificate_verification(bool enab _impl->set_enable_certificate_verification(enable); } +void tls::certificate_credentials::enable_load_system_trust() { + _impl->_load_system_trust = true; +} + +void tls::certificate_credentials::set_client_auth(client_auth ca) { + _impl->set_client_auth(ca); +} + +void tls::certificate_credentials::set_session_resume_mode(session_resume_mode m) { + _impl->set_session_resume_mode(m); +} + tls::server_credentials::server_credentials() #if GNUTLS_VERSION_NUMBER < 0x030600 : server_credentials(dh_params{}) @@ -603,452 +525,6 @@ void tls::server_credentials::set_session_resume_mode(session_resume_mode m) { } -static const sstring dh_level_key = "dh_level"; -static const sstring x509_trust_key = "x509_trust"; -static const sstring x509_crl_key = "x509_crl"; -static const sstring x509_key_key = "x509_key"; -static const sstring pkcs12_key = "pkcs12"; -static const sstring system_trust = "system_trust"; - -using buffer_type = std::basic_string>; - -struct x509_simple { - buffer_type data; - tls::x509_crt_format format; - file_info file; -}; - -struct x509_key { - buffer_type cert; - buffer_type key; - tls::x509_crt_format format; - file_info cert_file; - file_info key_file; -}; - -struct pkcs12_simple { - buffer_type data; - tls::x509_crt_format format; - sstring password; - file_info file; -}; - -void tls::credentials_builder::set_dh_level(dh_params::level level) { - _blobs.emplace(dh_level_key, level); -} - -void tls::credentials_builder::set_x509_trust(const blob& b, x509_crt_format fmt) { - _blobs.emplace(x509_trust_key, x509_simple{ std::string(b), fmt }); -} - -void tls::credentials_builder::set_x509_crl(const blob& b, x509_crt_format fmt) { - _blobs.emplace(x509_crl_key, x509_simple{ std::string(b), fmt }); -} - -void tls::credentials_builder::set_x509_key(const blob& cert, const blob& key, x509_crt_format fmt) { - _blobs.emplace(x509_key_key, x509_key { std::string(cert), std::string(key), fmt }); -} - -void tls::credentials_builder::set_simple_pkcs12(const blob& b, x509_crt_format fmt, const sstring& password) { - _blobs.emplace(pkcs12_key, pkcs12_simple{std::string(b), fmt, password }); -} - -static buffer_type to_buffer(const temporary_buffer& buf) { - return buffer_type(buf.get(), buf.get() + buf.size()); -} - -future<> tls::credentials_builder::set_x509_trust_file(const sstring& cafile, x509_crt_format fmt) { - return read_fully(cafile, "trust file").then([this, fmt](file_result f) { - _blobs.emplace(x509_trust_key, x509_simple{ to_buffer(f.buf), fmt, std::move(f.file) }); - }); -} - -future<> tls::credentials_builder::set_x509_crl_file(const sstring& crlfile, x509_crt_format fmt) { - return read_fully(crlfile, "crl file").then([this, fmt](file_result f) { - _blobs.emplace(x509_crl_key, x509_simple{ to_buffer(f.buf), fmt, std::move(f.file) }); - }); -} - -future<> tls::credentials_builder::set_x509_key_file(const sstring& cf, const sstring& kf, x509_crt_format fmt) { - return read_fully(cf, "certificate file").then([this, fmt, kf = kf](file_result cf) { - return read_fully(kf, "key file").then([this, fmt, cf = std::move(cf)](file_result kf) { - _blobs.emplace(x509_key_key, x509_key{ to_buffer(cf.buf), to_buffer(kf.buf), fmt, std::move(cf.file), std::move(kf.file) }); - }); - }); -} - -future<> tls::credentials_builder::set_simple_pkcs12_file(const sstring& pkcs12file, x509_crt_format fmt, const sstring& password) { - return read_fully(pkcs12file, "pkcs12 file").then([this, fmt, password = password](file_result f) { - _blobs.emplace(pkcs12_key, pkcs12_simple{ to_buffer(f.buf), fmt, password, std::move(f.file) }); - }); -} - -future<> tls::credentials_builder::set_system_trust() { - // TODO / Caveat: - // We cannot actually issue a loading of system trust here, - // because we have no actual tls context. - // And we probably _don't want to get into the guessing game - // of where the system trust cert chains are, since this is - // super distro dependent, and usually compiled into the library. - // Pretent it is raining, and just set a flag. - // Leave the function returning future, so if we change our - // minds and want to do explicit loading, we can... - _blobs.emplace(system_trust, true); - return make_ready_future(); -} - -void tls::credentials_builder::set_client_auth(client_auth auth) { - _client_auth = auth; -} - -void tls::credentials_builder::set_priority_string(const sstring& prio) { - _priority = prio; -} - -void tls::credentials_builder::set_session_resume_mode(session_resume_mode m) { - _session_resume_mode = m; -} - -template -static void visit_blobs(Blobs& blobs, Visitor&& visitor) { - auto visit = [&](const sstring& key, auto* vt) { - auto tr = blobs.equal_range(key); - for (auto& p : boost::make_iterator_range(tr.first, tr.second)) { - auto* v = std::any_cast>(&p.second); - visitor(key, *v); - } - }; - visit(x509_trust_key, static_cast(nullptr)); - visit(x509_crl_key, static_cast(nullptr)); - visit(x509_key_key, static_cast(nullptr)); - visit(pkcs12_key, static_cast(nullptr)); -} - -void tls::credentials_builder::apply_to(certificate_credentials& creds) const { - // Could potentially be templated down, but why bother... - visit_blobs(_blobs, make_visitor( - [&](const sstring& key, const x509_simple& info) { - if (key == x509_trust_key) { - creds.set_x509_trust(info.data, info.format); - } else if (key == x509_crl_key) { - creds.set_x509_crl(info.data, info.format); - } - }, - [&](const sstring&, const x509_key& info) { - creds.set_x509_key(info.cert, info.key, info.format); - }, - [&](const sstring&, const pkcs12_simple& info) { - creds.set_simple_pkcs12(info.data, info.format, info.password); - } - )); - - // TODO / Caveat: - // We cannot do this immediately, because we are not a continuation, and - // potentially blocking calls are a no-no. - // Doing this detached would be indeterministic, so set a flag in - // credentials, and do actual loading in first handshake (see session) - if (_blobs.count(system_trust)) { - creds._impl->_load_system_trust = true; - } - - if (!_priority.empty()) { - creds.set_priority_string(_priority); - } - - creds._impl->set_client_auth(_client_auth); - // Note: this causes server session key rotation on cert reload - creds._impl->set_session_resume_mode(_session_resume_mode); -} - -shared_ptr tls::credentials_builder::build_certificate_credentials() const { - auto creds = make_shared(); - apply_to(*creds); - return creds; -} - -shared_ptr tls::credentials_builder::build_server_credentials() const { - auto i = _blobs.find(dh_level_key); - if (i == _blobs.end()) { -#if GNUTLS_VERSION_NUMBER < 0x030600 - throw std::invalid_argument("No DH level set"); -#else - auto creds = make_shared(); - apply_to(*creds); - return creds; -#endif - } - auto creds = make_shared(dh_params(std::any_cast(i->second))); - apply_to(*creds); - return creds; -} - -using namespace std::chrono_literals; - -class tls::reloadable_credentials_base { -public: - using delay_type = std::chrono::milliseconds; - static inline constexpr delay_type default_tolerance = 500ms; - - class reloading_builder - : public credentials_builder - , public enable_shared_from_this - { - public: - using time_point = std::chrono::system_clock::time_point; - - reloading_builder(credentials_builder b, reload_callback cb, reloadable_credentials_base* creds, delay_type delay) - : credentials_builder(std::move(b)) - , _cb(std::move(cb)) - , _creds(creds) - , _delay(delay) - {} - future<> init() { - std::vector> futures; - visit_blobs(_blobs, make_visitor( - [&](const sstring&, const x509_simple& info) { - _all_files.emplace(info.file.filename); - }, - [&](const sstring&, const x509_key& info) { - _all_files.emplace(info.cert_file.filename); - _all_files.emplace(info.key_file.filename); - }, - [&](const sstring&, const pkcs12_simple& info) { - _all_files.emplace(info.file.filename); - } - )); - return parallel_for_each(_all_files, [this](auto& f) { - if (!f.empty()) { - return add_watch(f).discard_result(); - } - return make_ready_future<>(); - }).finally([me = shared_from_this()] {}); - } - void start() { - // run the loop in a thread. makes code almost readable. - (void)async(std::bind(&reloading_builder::run, this)).finally([me = shared_from_this()] {}); - } - void run() { - while (_creds) { - try { - auto events = _fsn.wait().get(); - if (events.empty() && _creds == nullptr) { - return; - } - rebuild(events); - _timer.cancel(); - } catch (...) { - if (!_timer.armed()) { - _timer.set_callback([this, ep = std::current_exception()]() mutable { - do_callback(std::move(ep)); - }); - _timer.arm(_delay); - } - } - } - } - void detach() { - _creds = nullptr; - _cb = {}; - _fsn.shutdown(); - _timer.cancel(); - } - private: - using fsnotifier = experimental::fsnotifier; - - // called from seastar::thread - void rebuild(const std::vector& events) { - for (auto& e : events) { - // don't use at. We could be getting two events for - // same watch (mod + delete), but we only need to care - // about one... - auto i = _watches.find(e.id); - if (i != _watches.end()) { - auto& filename = i->second.second; - // only add actual file watches to - // query set. If this was a directory - // watch, the file should already be - // in there. - if (_all_files.count(filename)) { - _files[filename] = e.mask; - } - _watches.erase(i); - } - } - auto num_changed = 0; - - auto maybe_reload = [&](const sstring& filename, buffer_type& dst) { - if (filename.empty() || !_files.count(filename)) { - return; - } - // #756 - // first, add a watch to nearest parent dir we - // can find. If user deleted folders, we could end - // up looking at modifications to root. - // The idea is that should adding a watch to actual file - // fail (deleted file/folder), we wait for changes to closest - // parent. When this happens, we will retry all files - // that have not been successfully replaced (and maybe more), - // repeating the process. At some point, we hopefully - // get new, current data. - add_dir_watch(filename); - // #756 add watch _first_. File could change while we are - // reading this. - try { - add_watch(filename).get(); - } catch (...) { - // let's just assume if this happens, it's because the file or folder was deleted. - // just ignore for now, and hope the dir watch will tell us when it is back... - return; - } - temporary_buffer buf = read_fully(filename, "reloading").get(); - dst = to_buffer(buf); - ++num_changed; - }; - visit_blobs(_blobs, make_visitor( - [&](const sstring&, x509_simple& info) { - maybe_reload(info.file.filename, info.data); - }, - [&](const sstring&, x509_key& info) { - maybe_reload(info.cert_file.filename, info.cert); - maybe_reload(info.key_file.filename, info.key); - }, - [&](const sstring&, pkcs12_simple& info) { - maybe_reload(info.file.filename, info.data); - } - )); - // only try this if anything was in fact successfully loaded. - // if files were missing, or pairs incomplete, we can just skip. - if (num_changed == 0) { - return; - } - try { - if (_creds) { - _creds->rebuild(*this); - } - } catch (...) { - if (std::any_of(_files.begin(), _files.end(), [](auto& p) { return p.second == fsnotifier::flags::ignored; })) { - // if any file in the reload set was deleted - i.e. we have not seen a "closed" yet - assume - // this is a spurious reload and we'd better wait for next event - hopefully a "closed" - - // and try again - return; - } - throw; - } - // if we got here, all files loaded, all watches were created, - // and gnutls was ok with the content. success. - do_callback(); - on_success(); - } - void on_success() { - _files.clear(); - // remove all directory watches, since we've successfully - // reloaded -> the file watches themselves should suffice now - auto i = _watches.begin(); - auto e = _watches.end(); - while (i != e) { - if (!_all_files.count(i->second.second)) { - i = _watches.erase(i); - continue; - } - ++i; - } - } - void do_callback(std::exception_ptr ep = {}) { - if (_cb && !_files.empty()) { - _cb(boost::copy_range>(_files | boost::adaptors::map_keys), std::move(ep)); - } - } - // called from seastar::thread - fsnotifier::watch_token add_dir_watch(const sstring& filename) { - auto dir = std::filesystem::path(filename).parent_path(); - for (;;) { - try { - return add_watch(dir.native(), fsnotifier::flags::create_child | fsnotifier::flags::move).get(); - } catch (...) { - auto parent = dir.parent_path(); - if (parent.empty() || dir == parent) { - throw; - } - dir = std::move(parent); - continue; - } - } - } - future add_watch(const sstring& filename, fsnotifier::flags flags = fsnotifier::flags::close_write|fsnotifier::flags::delete_self) { - return _fsn.create_watch(filename, flags).then([this, filename = filename](fsnotifier::watch w) { - auto t = w.token(); - // we might create multiple watches for same token in case of dirs, avoid deleting previously - // created one - if (_watches.count(t)) { - w.release(); - } else { - _watches.emplace(t, std::make_pair(std::move(w), filename)); - } - return t; - }); - } - - reload_callback _cb; - reloadable_credentials_base* _creds; - fsnotifier _fsn; - std::unordered_map> _watches; - std::unordered_map _files; - std::unordered_set _all_files; - timer<> _timer; - delay_type _delay; - }; - reloadable_credentials_base(credentials_builder builder, reload_callback cb, delay_type delay = default_tolerance) - : _builder(seastar::make_shared(std::move(builder), std::move(cb), this, delay)) - { - _builder->start(); - } - future<> init() { - return _builder->init(); - } - virtual ~reloadable_credentials_base() { - _builder->detach(); - } - virtual void rebuild(const credentials_builder&) = 0; -private: - shared_ptr _builder; -}; - -template -class tls::reloadable_credentials : public Base, public tls::reloadable_credentials_base { -public: - reloadable_credentials(credentials_builder builder, reload_callback cb, Base b, delay_type delay = default_tolerance) - : Base(std::move(b)) - , tls::reloadable_credentials_base(std::move(builder), std::move(cb), delay) - {} - void rebuild(const credentials_builder&) override; -}; - -template<> -void tls::reloadable_credentials::rebuild(const credentials_builder& builder) { - auto tmp = builder.build_certificate_credentials(); - this->_impl = std::move(tmp->_impl); -} - -template<> -void tls::reloadable_credentials::rebuild(const credentials_builder& builder) { - auto tmp = builder.build_server_credentials(); - this->_impl = std::move(tmp->_impl); -} - -future> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional tolerance) const { - auto creds = seastar::make_shared>(*this, std::move(cb), std::move(*build_certificate_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance)); - return creds->init().then([creds] { - return make_ready_future>(creds); - }); -} - -future> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional tolerance) const { - auto creds = seastar::make_shared>(*this, std::move(cb), std::move(*build_server_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance)); - return creds->init().then([creds] { - return make_ready_future>(creds); - }); -} - namespace tls { /** @@ -1060,7 +536,7 @@ namespace tls { * of these, since we handle handshake etc. * */ -class session : public enable_lw_shared_from_this { +class session : public enable_shared_from_this, public session_impl { public: enum class type : uint32_t { @@ -1138,6 +614,10 @@ class session : public enable_lw_shared_from_this { assert(_output_pending.available()); } + const char * get_type_string() const { + return _type == type::CLIENT ? "Client": "Server"; + } + typedef temporary_buffer buf_type; sstring cert_status_to_string(gnutls_certificate_type_t type, unsigned int status) { @@ -1823,224 +1303,6 @@ class session : public enable_lw_shared_from_this { std::unique_ptr, void(*)(gnutls_session_t)> _session; }; -struct session::session_ref { - session_ref() = default; - session_ref(lw_shared_ptr session) - : _session(std::move(session)) { - } - session_ref(session_ref&&) = default; - session_ref(const session_ref&) = default; - ~session_ref() { - // This is not super pretty. But we take some care to only own sessions - // through session_ref, and we need to initiate shutdown on "last owner", - // since we cannot revive the session in destructor. - if (_session && _session.use_count() == 1) { - _session->close(); - } - } - - session_ref& operator=(session_ref&&) = default; - session_ref& operator=(const session_ref&) = default; - - lw_shared_ptr _session; -}; - -class tls_connected_socket_impl : public net::connected_socket_impl, public session::session_ref { -public: - tls_connected_socket_impl(session_ref&& sess) - : session_ref(std::move(sess)) - {} - - class source_impl; - class sink_impl; - - using net::connected_socket_impl::source; - data_source source() override; - data_sink sink() override; - - void shutdown_input() override { - _session->close(); - } - void shutdown_output() override { - _session->close(); - } - void set_nodelay(bool nodelay) override { - _session->socket().set_nodelay(nodelay); - } - bool get_nodelay() const override { - return _session->socket().get_nodelay(); - } - void set_keepalive(bool keepalive) override { - _session->socket().set_keepalive(keepalive); - } - bool get_keepalive() const override { - return _session->socket().get_keepalive(); - } - void set_keepalive_parameters(const net::keepalive_params& p) override { - _session->socket().set_keepalive_parameters(p); - } - net::keepalive_params get_keepalive_parameters() const override { - return _session->socket().get_keepalive_parameters(); - } - void set_sockopt(int level, int optname, const void* data, size_t len) override { - _session->socket().set_sockopt(level, optname, data, len); - } - int get_sockopt(int level, int optname, void* data, size_t len) const override { - return _session->socket().get_sockopt(level, optname, data, len); - } - socket_address local_address() const noexcept override { - return _session->socket().local_address(); - } - socket_address remote_address() const noexcept override { - return _session->socket().remote_address(); - } - future> get_distinguished_name() { - return _session->get_distinguished_name(); - } - future> get_alt_name_information(std::unordered_set types) { - return _session->get_alt_name_information(std::move(types)); - } - future<> wait_input_shutdown() override { - return _session->socket().wait_input_shutdown(); - } - future check_session_is_resumed() { - return _session->is_resumed(); - } - future get_session_resume_data() { - return _session->get_session_resume_data(); - } -}; - - -class tls_connected_socket_impl::source_impl: public data_source_impl, public session::session_ref { -public: - using session_ref::session_ref; -private: - future> get() override { - return _session->get(); - } - future<> close() override { - _session->close(); - return make_ready_future<>(); - } -}; - -// Note: source/sink, and by extension, the in/out streams -// produced, cannot exist outside the direct life span of -// the connected_socket itself. This is consistent with -// other sockets in seastar, though I am than less fond of it... -class tls_connected_socket_impl::sink_impl: public data_sink_impl, public session::session_ref { -public: - using session_ref::session_ref; -private: - future<> flush() override { - return _session->flush(); - } - using data_sink_impl::put; - future<> put(net::packet p) override { - return _session->put(std::move(p)); - } - future<> close() override { - _session->close(); - return make_ready_future<>(); - } - bool can_batch_flushes() const noexcept override { return true; } - void on_batch_flush_error() noexcept override { - _session->close(); - } -}; - -class server_session : public net::server_socket_impl { -public: - server_session(shared_ptr creds, server_socket sock) - : _creds(std::move(creds)), _sock(std::move(sock)) { - } - future accept() override { - // We're not actually doing anything very SSL until we get - // an actual connection. Then we create a "server" session - // and wrap it up after handshaking. - return _sock.accept().then([this](accept_result ar) { - return wrap_server(_creds, std::move(ar.connection)).then([addr = std::move(ar.remote_address)](connected_socket s) { - return make_ready_future(accept_result{std::move(s), addr}); - }); - }); - } - void abort_accept() override { - _sock.abort_accept(); - } - socket_address local_address() const override { - return _sock.local_address(); - } -private: - - shared_ptr _creds; - server_socket _sock; -}; - -class tls_socket_impl : public net::socket_impl { - shared_ptr _cred; - tls_options _options; - ::seastar::socket _socket; -public: - tls_socket_impl(shared_ptr cred, tls_options options) - : _cred(cred), _options(std::move(options)), _socket(make_socket()) { - } - virtual future connect(socket_address sa, socket_address local, transport proto = transport::TCP) override { - return _socket.connect(sa, local, proto).then([cred = std::move(_cred), options = std::move(_options)](connected_socket s) mutable { - return wrap_client(cred, std::move(s), std::move(options)); - }); - } - void set_reuseaddr(bool reuseaddr) override { - _socket.set_reuseaddr(reuseaddr); - } - bool get_reuseaddr() const override { - return _socket.get_reuseaddr(); - } - virtual void shutdown() override { - _socket.shutdown(); - } -}; - -} - -data_source tls::tls_connected_socket_impl::source() { - return data_source(std::make_unique(_session)); -} - -data_sink tls::tls_connected_socket_impl::sink() { - return data_sink(std::make_unique(_session)); -} - - -future tls::connect(shared_ptr cred, socket_address sa, sstring name) { - tls_options options{.server_name = std::move(name)}; - return connect(std::move(cred), std::move(sa), std::move(options)); -} - -future tls::connect(shared_ptr cred, socket_address sa, socket_address local, sstring name) { - tls_options options{.server_name = std::move(name)}; - return connect(std::move(cred), std::move(sa), std::move(local), std::move(options)); -} - -future tls::connect(shared_ptr cred, socket_address sa, tls_options options) { - return engine().connect(sa).then([cred = std::move(cred), options = std::move(options)](connected_socket s) mutable { - return wrap_client(std::move(cred), std::move(s), std::move(options)); - }); -} - -future tls::connect(shared_ptr cred, socket_address sa, socket_address local, tls_options options) { - return engine().connect(sa, local).then([cred = std::move(cred), options = std::move(options)](connected_socket s) mutable { - return wrap_client(std::move(cred), std::move(s), std::move(options)); - }); -} - -socket tls::socket(shared_ptr cred, sstring name) { - tls_options options{.server_name = std::move(name)}; - return tls::socket(std::move(cred), std::move(options)); -} - -socket tls::socket(shared_ptr cred, tls_options options) { - return ::seastar::socket(std::make_unique(std::move(cred), std::move(options))); } future tls::wrap_client(shared_ptr cred, connected_socket&& s, sstring name) { @@ -2049,90 +1311,17 @@ future tls::wrap_client(shared_ptr cr } future tls::wrap_client(shared_ptr cred, connected_socket&& s, tls_options options) { - session::session_ref sess(make_lw_shared(session::type::CLIENT, std::move(cred), std::move(s), options)); + session_ref sess(seastar::make_shared(session::type::CLIENT, std::move(cred), std::move(s), options)); connected_socket sock(std::make_unique(std::move(sess))); return make_ready_future(std::move(sock)); } future tls::wrap_server(shared_ptr cred, connected_socket&& s) { - session::session_ref sess(make_lw_shared(session::type::SERVER, std::move(cred), std::move(s))); + session_ref sess(seastar::make_shared(session::type::SERVER, std::move(cred), std::move(s))); connected_socket sock(std::make_unique(std::move(sess))); return make_ready_future(std::move(sock)); } -server_socket tls::listen(shared_ptr creds, socket_address sa, listen_options opts) { - return listen(std::move(creds), seastar::listen(sa, opts)); -} - -server_socket tls::listen(shared_ptr creds, server_socket ss) { - server_socket ssls(std::make_unique(creds, std::move(ss))); - return server_socket(std::move(ssls)); -} - -static tls::tls_connected_socket_impl* get_tls_socket(connected_socket& socket) { - auto impl = net::get_impl::maybe_get_ptr(socket); - if (impl == nullptr) { - // the socket is not yet created or moved from - throw std::system_error(ENOTCONN, std::system_category()); - } - auto tls_impl = dynamic_cast(impl); - if (!tls_impl) { - // bad cast here means that we're dealing with wrong socket type - throw std::invalid_argument("Not a TLS socket"); - } - return tls_impl; -} - -future> tls::get_dn_information(connected_socket& socket) { - return get_tls_socket(socket)->get_distinguished_name(); -} - -future> tls::get_alt_name_information(connected_socket& socket, std::unordered_set types) { - return get_tls_socket(socket)->get_alt_name_information(std::move(types)); -} - -future tls::check_session_is_resumed(connected_socket& socket) { - return get_tls_socket(socket)->check_session_is_resumed(); -} - -future tls::get_session_resume_data(connected_socket& socket) { - return get_tls_socket(socket)->get_session_resume_data(); -} - -std::string_view tls::format_as(subject_alt_name_type type) { - switch (type) { - case subject_alt_name_type::dnsname: - return "DNS"; - case subject_alt_name_type::rfc822name: - return "EMAIL"; - case subject_alt_name_type::uri: - return "URI"; - case subject_alt_name_type::ipaddress: - return "IP"; - case subject_alt_name_type::othername: - return "OTHERNAME"; - case subject_alt_name_type::dn: - return "DIRNAME"; - default: - return "UNKNOWN"; - } -} - -std::ostream& tls::operator<<(std::ostream& os, subject_alt_name_type type) { - return os << format_as(type); -} - -std::ostream& tls::operator<<(std::ostream& os, const subject_alt_name::value_type& v) { - fmt::print(os, "{}", v); - return os; -} - -std::ostream& tls::operator<<(std::ostream& os, const subject_alt_name& a) { - fmt::print(os, "{}", a); - return os; -} - - } const int seastar::tls::ERROR_UNKNOWN_COMPRESSION_ALGORITHM = GNUTLS_E_UNKNOWN_COMPRESSION_ALGORITHM; diff --git a/src/seastar.cc b/src/seastar.cc index 29295518f3..3275b48f08 100644 --- a/src/seastar.cc +++ b/src/seastar.cc @@ -120,7 +120,11 @@ module; #include #include #include +#ifdef SEASTAR_USE_OPENSSL +#include +#else #include +#endif #ifdef SEASTAR_HAVE_HWLOC #include #endif @@ -335,6 +339,7 @@ module : private; #include #include "net/native-stack-impl.hh" +#include "net/tls-impl.hh" #include #include diff --git a/src/websocket/base64-gnutls.cc b/src/websocket/base64-gnutls.cc new file mode 100644 index 0000000000..87ecb2d4b8 --- /dev/null +++ b/src/websocket/base64-gnutls.cc @@ -0,0 +1,51 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2015 Cloudius Systems + */ + +#include +#include + +#include "websocket/base64.hh" + +#include +#include + +namespace seastar::experimental::websocket { +std::string sha1_base64(std::string_view source) { + unsigned char hash[20]; + assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_SHA1)); + if (int ret = gnutls_hash_fast(GNUTLS_DIG_SHA1, source.data(), source.size(), hash); + ret != GNUTLS_E_SUCCESS) { + throw websocket::exception(fmt::format("gnutls_hash_fast: {}", gnutls_strerror(ret))); + } + gnutls_datum_t hash_data{ + .data = hash, + .size = sizeof(hash), + }; + gnutls_datum_t base64_encoded; + if (int ret = gnutls_base64_encode2(&hash_data, &base64_encoded); + ret != GNUTLS_E_SUCCESS) { + throw websocket::exception(fmt::format("gnutls_base64_encode2: {}", gnutls_strerror(ret))); + } + auto free_base64_encoded = defer([&] () noexcept { gnutls_free(base64_encoded.data); }); + // base64_encoded.data is "unsigned char *" + return std::string(reinterpret_cast(base64_encoded.data), base64_encoded.size); +} +} diff --git a/src/websocket/base64-openssl.cc b/src/websocket/base64-openssl.cc new file mode 100644 index 0000000000..f5e24719f2 --- /dev/null +++ b/src/websocket/base64-openssl.cc @@ -0,0 +1,58 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2015 Cloudius Systems + */ + +#include +#include +#include + +#include "websocket/base64.hh" + +#include + +namespace seastar::experimental::websocket { + +std::string sha1_base64(std::string_view source) { + unsigned char hash[20]; + + const auto encode_capacity = [](size_t input_size) { + return (((4 * input_size) / 3) + 3) & ~0x3U; + }; + unsigned int hash_size = sizeof(hash); + auto md_ptr = EVP_MD_fetch(nullptr, "SHA1", nullptr); + if (!md_ptr) { + throw websocket::exception("Failed to fetch SHA-1 algorithm from OpenSSL"); + } + + auto free_evp_md_ptr = + defer([&]() noexcept { EVP_MD_free(md_ptr); }); + + assert(hash_size == static_cast(EVP_MD_get_size(md_ptr))); + + if (1 != EVP_Digest(source.data(), source.size(), hash, &hash_size, md_ptr, nullptr)) { + throw websocket::exception("Failed to perform SHA-1 digest in OpenSSL"); + } + + auto base64_encoded = uninitialized_string(encode_capacity(hash_size)); + EVP_EncodeBlock(reinterpret_cast(base64_encoded.data()), hash, hash_size); + return base64_encoded; +} + +} \ No newline at end of file diff --git a/src/websocket/base64.hh b/src/websocket/base64.hh new file mode 100644 index 0000000000..0875581ad5 --- /dev/null +++ b/src/websocket/base64.hh @@ -0,0 +1,30 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2015 Cloudius Systems + */ +#pragma once + +#include +#include + +namespace seastar::experimental::websocket { + +std::string sha1_base64(std::string_view source); + +} \ No newline at end of file diff --git a/src/websocket/server.cc b/src/websocket/server.cc index f9c26bfe29..d9dd2b4736 100644 --- a/src/websocket/server.cc +++ b/src/websocket/server.cc @@ -27,8 +27,8 @@ #include #include #include -#include -#include + +#include "websocket/base64.hh" namespace seastar::experimental::websocket { @@ -124,27 +124,6 @@ future<> connection::process() { }); } -static std::string sha1_base64(std::string_view source) { - unsigned char hash[20]; - assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_SHA1)); - if (int ret = gnutls_hash_fast(GNUTLS_DIG_SHA1, source.data(), source.size(), hash); - ret != GNUTLS_E_SUCCESS) { - throw websocket::exception(fmt::format("gnutls_hash_fast: {}", gnutls_strerror(ret))); - } - gnutls_datum_t hash_data{ - .data = hash, - .size = sizeof(hash), - }; - gnutls_datum_t base64_encoded; - if (int ret = gnutls_base64_encode2(&hash_data, &base64_encoded); - ret != GNUTLS_E_SUCCESS) { - throw websocket::exception(fmt::format("gnutls_base64_encode2: {}", gnutls_strerror(ret))); - } - auto free_base64_encoded = defer([&] () noexcept { gnutls_free(base64_encoded.data); }); - // base64_encoded.data is "unsigned char *" - return std::string(reinterpret_cast(base64_encoded.data), base64_encoded.size); -} - future<> connection::read_http_upgrade_request() { _http_parser.init(); return _read_buf.consume(_http_parser).then([this] () mutable { diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 265babde45..7c2242e354 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -563,12 +563,11 @@ function(seastar_add_certgen name) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${CERT_CAROOT}" - COMMAND ${OPENSSL} req -x509 -new -nodes -key ${CERT_CAPRIVKEY} -days ${CERT_DAYS} -config ${CERT_NAME}.cfg -out ${CERT_CAROOT} + COMMAND ${OPENSSL} req -x509 -new -nodes -key ${CERT_CAPRIVKEY} -days ${CERT_DAYS} -config ${CERT_NAME}.cfg -out ${CERT_CAROOT} -extensions v3_ca DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/${CERT_CAPRIVKEY}" "${CMAKE_CURRENT_BINARY_DIR}/${CERT_NAME}.cfg" WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) - add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${CERT_CERT}" COMMAND ${OPENSSL} x509 -req -in ${CERT_REQ} -CA ${CERT_CAROOT} -CAkey ${CERT_CAPRIVKEY} -CAcreateserial -out ${CERT_CERT} -days ${CERT_DAYS} -extensions req_ext -extfile ${CERT_NAME}.cfg DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/${CERT_REQ}" "${CMAKE_CURRENT_BINARY_DIR}/${CERT_CAROOT}" "${CMAKE_CURRENT_BINARY_DIR}/${CERT_NAME}.cfg" @@ -651,7 +650,7 @@ function(seastar_gen_mtls_certs) seastar_sign_cert("mtls_server" CA ${cert_ca} SERIAL_NUM 1 - COMMON_NAME "redpanda.com") + COMMON_NAME "server.org") # client1 certificates seastar_sign_cert("mtls_client1" CA ${cert_ca} diff --git a/tests/unit/loopback_socket.hh b/tests/unit/loopback_socket.hh index 3b9df90188..d99e02ef0c 100644 --- a/tests/unit/loopback_socket.hh +++ b/tests/unit/loopback_socket.hh @@ -54,7 +54,7 @@ public: }; private: bool _aborted = false; - queue> _q{1}; + queue> _q{2}; loopback_error_injector* _error_injector; type _type; std::optional> _shutdown; diff --git a/tests/unit/tls_test.cc b/tests/unit/tls_test.cc index 431de2c73c..80d5a7651b 100644 --- a/tests/unit/tls_test.cc +++ b/tests/unit/tls_test.cc @@ -46,7 +46,9 @@ #include "loopback_socket.hh" #include "tmpdir.hh" +#ifdef SEASTAR_USE_GNUTLS #include +#endif #if 0 @@ -162,6 +164,16 @@ SEASTAR_TEST_CASE(test_x509_client_with_builder_system_trust_multiple) { } SEASTAR_TEST_CASE(test_x509_client_with_system_trust_and_priority_strings) { +#ifdef SEASTAR_USE_OPENSSL + static std::vector prios( { + "PSK-CHACHA20-POLY1305", + "DHE-PSK-AES128-GCM-SHA256", + "ECDHE-RSA-AES128-GCM-SHA256", + "RSA-PSK-AES128-CBC-SHA", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "AES128-GCM-SHA256" + }); +#elif defined(SEASTAR_USE_GNUTLS) static std::vector prios( { "NORMAL:+ARCFOUR-128", // means normal ciphers plus ARCFOUR-128. "SECURE128:-VERS-SSL3.0:+COMP-DEFLATE", // means that only secure ciphers are enabled, SSL3.0 is disabled, and libz compression enabled. @@ -173,22 +185,48 @@ SEASTAR_TEST_CASE(test_x509_client_with_system_trust_and_priority_strings) { "SECURE128:-VERS-TLS1.0:+COMP-DEFLATE", "SECURE128:+SECURE192:-VERS-TLS-ALL:+VERS-TLS1.2" }); +#else +#error "Unknown cryptographic provider" +#endif return do_for_each(prios, [](const sstring & prio) { tls::credentials_builder b; (void)b.set_system_trust(); +#ifdef SEASTAR_USE_OPENSSL + b.set_cipher_string(prio); +#elif defined(SEASTAR_USE_GNUTLS) b.set_priority_string(prio); +#else +#error "Unknown cryptographic provider" +#endif return connect_to_ssl_google(b.build_certificate_credentials()); }); } SEASTAR_TEST_CASE(test_x509_client_with_system_trust_and_priority_strings_fail) { +#ifdef SEASTAR_USE_OPENSSL + static std::vector prios( { + "RSA-MD5-AES256-CBC-SHA" + }); +#elif defined(SEASTAR_USE_GNUTLS) static std::vector prios( { "NONE", "NONE:+CURVE-SECP256R1" }); +#else +#error "Unknown cryptographic provider" +#endif return do_for_each(prios, [](const sstring & prio) { tls::credentials_builder b; (void)b.set_system_trust(); + +#ifdef SEASTAR_USE_OPENSSL + b.set_cipher_string(prio); + b.set_minimum_tls_version(tls::tls_version::tlsv1_0); + b.set_maximum_tls_version(tls::tls_version::tlsv1_1); +#elif defined(SEASTAR_USE_GNUTLS) b.set_priority_string(prio); +#else +#error "Unknown cryptographic provider" +#endif try { return connect_to_ssl_google(b.build_certificate_credentials()).then([] { BOOST_FAIL("Expected exception"); @@ -299,6 +337,16 @@ SEASTAR_THREAD_TEST_CASE(test_x509_client_with_builder_multiple) { } SEASTAR_THREAD_TEST_CASE(test_x509_client_with_priority_strings) { +#ifdef SEASTAR_USE_OPENSSL + static std::vector prios( { + "PSK-CHACHA20-POLY1305", + "DHE-PSK-AES128-GCM-SHA256", + "ECDHE-RSA-AES128-GCM-SHA256", + "RSA-PSK-AES128-CBC-SHA", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "AES128-GCM-SHA256" + }); +#elif defined(SEASTAR_USE_GNUTLS) static std::vector prios( { "NORMAL:+ARCFOUR-128", // means normal ciphers plus ARCFOUR-128. "SECURE128:-VERS-SSL3.0:+COMP-DEFLATE", // means that only secure ciphers are enabled, SSL3.0 is disabled, and libz compression enabled. @@ -310,26 +358,126 @@ SEASTAR_THREAD_TEST_CASE(test_x509_client_with_priority_strings) { "SECURE128:-VERS-TLS1.0:+COMP-DEFLATE", "SECURE128:+SECURE192:-VERS-TLS-ALL:+VERS-TLS1.2" }); +#else +#error "Unknown cryptographic provider" +#endif tls::credentials_builder b; https_server server; b.set_x509_trust_file(server.cert(), tls::x509_crt_format::PEM).get(); auto addr = server.addr(); do_for_each(prios, [&b, addr](const sstring& prio) { +#ifdef SEASTAR_USE_OPENSSL + b.set_cipher_string(prio); +#elif defined(SEASTAR_USE_GNUTLS) b.set_priority_string(prio); +#else +#error "Unknown cryptographic provider" +#endif return connect_to_ssl_addr(b.build_certificate_credentials(), addr); }).get(); } SEASTAR_THREAD_TEST_CASE(test_x509_client_with_priority_strings_fail) { +#ifdef SEASTAR_USE_OPENSSL + static std::vector prios( { + "RSA-MD5-AES256-CBC-SHA" + }); +#elif defined(SEASTAR_USE_GNUTLS) static std::vector prios( { "NONE", "NONE:+CURVE-SECP256R1" }); +#else +#error "Unknown cryptographic provider" +#endif + tls::credentials_builder b; + https_server server; + b.set_x509_trust_file(server.cert(), tls::x509_crt_format::PEM).get(); + auto addr = server.addr(); + do_for_each(prios, [&b, addr](const sstring& prio) { +#ifdef SEASTAR_USE_OPENSSL + b.set_cipher_string(prio); + b.set_minimum_tls_version(tls::tls_version::tlsv1_0); + b.set_maximum_tls_version(tls::tls_version::tlsv1_1); +#elif defined(SEASTAR_USE_GNUTLS) + b.set_priority_string(prio); +#else +#error "Unknown cryptographic provider" +#endif + try { + return connect_to_ssl_addr(b.build_certificate_credentials(), addr).then([] { + BOOST_FAIL("Expected exception"); + }).handle_exception([](auto ep) { + // ok. + }); + } catch (...) { + // also ok + } + return make_ready_future<>(); + }).get(); +} + +SEASTAR_THREAD_TEST_CASE(test_x509_client_tls13) { +#ifdef SEASTAR_USE_OPENSSL + static std::vector prios({ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256", + }); +#elif defined(SEASTAR_USE_GNUTLS) + static std::vector prios({ + "NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-128-GCM", + "NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-256-GCM", + "NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+CHACHA20-POLY1305" + }); +#else +#error "Unknown cryptographic provider" +#endif tls::credentials_builder b; https_server server; b.set_x509_trust_file(server.cert(), tls::x509_crt_format::PEM).get(); auto addr = server.addr(); do_for_each(prios, [&b, addr](const sstring& prio) { + BOOST_TEST_CHECKPOINT("Checking priority string " << prio); +#ifdef SEASTAR_USE_OPENSSL + b.set_ciphersuites(prio); + b.set_minimum_tls_version(tls::tls_version::tlsv1_3); + b.set_maximum_tls_version(tls::tls_version::tlsv1_3); +#elif defined(SEASTAR_USE_GNUTLS) b.set_priority_string(prio); +#else +#error "Unknown cryptographic provider" +#endif + return connect_to_ssl_addr(b.build_certificate_credentials(), addr); + }).get(); +} + +SEASTAR_THREAD_TEST_CASE(test_x509_client_tls13_fail) { +#ifdef SEASTAR_USE_OPENSSL + static std::vector prios({ + "TLS_AES_128_CCM_8_SHA256" + }); +#elif defined(SEASTAR_USE_GNUTLS) + static std::vector prios({ + "NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL:+AES-128-CCM-8" + }); +#else +#error "Unknown cryptographic provider" +#endif + tls::credentials_builder b; + https_server server; + b.set_x509_trust_file(server.cert(), tls::x509_crt_format::PEM).get(); + auto addr = server.addr(); + do_for_each(prios, [&b, addr](const sstring& prio) { + BOOST_TEST_CHECKPOINT("Checking priority string " << prio); +#ifdef SEASTAR_USE_OPENSSL + b.set_ciphersuites(prio); + b.set_minimum_tls_version(tls::tls_version::tlsv1_3); + b.set_maximum_tls_version(tls::tls_version::tlsv1_3); +#elif defined(SEASTAR_USE_GNUTLS) + b.set_priority_string(prio); +#else +#error "Unknown cryptographic provider" +#endif try { return connect_to_ssl_addr(b.build_certificate_credentials(), addr).then([] { BOOST_FAIL("Expected exception"); @@ -654,7 +802,7 @@ SEASTAR_TEST_CASE(test_simple_x509_client_server_again) { return run_echo_test(message, 20, certfile("catest.pem"), "test.scylladb.org"); } -#if GNUTLS_VERSION_NUMBER >= 0x030600 +#if GNUTLS_VERSION_NUMBER >= 0x030600 || SEASTAR_USE_OPENSSL // Test #769 - do not set dh_params in server certs - let gnutls negotiate. SEASTAR_TEST_CASE(test_simple_server_default_dhparams) { return run_echo_test(message, 20, certfile("catest.pem"), "test.scylladb.org", @@ -920,8 +1068,12 @@ SEASTAR_THREAD_TEST_CASE(test_reload_certificates) { } try { - f2.get(); - BOOST_FAIL("should not reach"); + auto res = f2.get(); + // If the server completes sending data to the client + // during the handshake before the client has fully + // closed its connection, then the get() call will + // succeed by return an empty buffer indicating EOF + BOOST_REQUIRE(res.size() == 0); } catch (...) { // ok } @@ -1348,6 +1500,10 @@ SEASTAR_THREAD_TEST_CASE(test_dn_name_handling) { fout.get(); auto dn = fdn.get(); + BOOST_REQUIRE(dn.has_value()); + BOOST_REQUIRE_EQUAL(dn->subject, fmt::format("C=GB,ST=London,L=London,O=Redpanda Data,OU=Core,CN={}", id)); + BOOST_REQUIRE_EQUAL(dn->issuer, "C=GB,ST=London,L=London,O=Redpanda Data,OU=Core,CN=redpanda.com"); + auto client_id = fin.get(); in.close().get(); @@ -1499,7 +1655,11 @@ SEASTAR_THREAD_TEST_CASE(test_tls13_session_tickets) { b.set_x509_key_file(certfile("test.crt"), certfile("test.key"), tls::x509_crt_format::PEM).get(); b.set_x509_trust_file(certfile("catest.pem"), tls::x509_crt_format::PEM).get(); b.set_session_resume_mode(tls::session_resume_mode::TLS13_SESSION_TICKET); +#ifdef SEASTAR_USE_OPENSSL + b.set_minimum_tls_version(tls::tls_version::tlsv1_3); +#else b.set_priority_string("SECURE128:+SECURE192:-VERS-TLS-ALL:+VERS-TLS1.3"); +#endif auto creds = b.build_certificate_credentials(); auto serv = b.build_server_credentials();