From 724ff214471012ca76544bec6fa8690c622cb864 Mon Sep 17 00:00:00 2001 From: anon Date: Mon, 28 Jun 2021 19:13:02 +0000 Subject: [PATCH 1/3] connection: add segfault and deadlocks demo --- tests/unit_tests/epee_boosted_tcp_server.cpp | 254 ++++++++++++++++++- 1 file changed, 252 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/epee_boosted_tcp_server.cpp b/tests/unit_tests/epee_boosted_tcp_server.cpp index 54d27be1b..d64431edf 100644 --- a/tests/unit_tests/epee_boosted_tcp_server.cpp +++ b/tests/unit_tests/epee_boosted_tcp_server.cpp @@ -31,6 +31,8 @@ #include #include #include +#include +#include #include "gtest/gtest.h" @@ -276,6 +278,11 @@ TEST(test_epee_connection, test_lifetime) ASSERT_TRUE(shared_state->get_connections_count() == 0); constexpr auto DELAY = 30; constexpr auto TIMEOUT = 1; + while (server.get_connections_count()) { + server.get_config_shared()->del_in_connections( + server.get_config_shared()->get_in_connections_count() + ); + } server.get_config_shared()->set_handler(new command_handler_t(DELAY), &command_handler_t::destroy); for (auto i = 0; i < N; ++i) { tag = create_connection(); @@ -332,7 +339,7 @@ TEST(test_epee_connection, test_lifetime) ), &command_handler_t::destroy ); - for (auto i = 0; i < N; ++i) { + for (auto i = 0; i < N * N * N; ++i) { { connection_ptr conn(new connection_t(io_context, shared_state, {}, {})); conn->socket().connect(endpoint); @@ -342,6 +349,7 @@ TEST(test_epee_connection, test_lifetime) } ASSERT_TRUE(shared_state->get_connections_count() == 1); shared_state->del_out_connections(1); + while (shared_state->sock_count); ASSERT_TRUE(shared_state->get_connections_count() == 0); } @@ -452,7 +460,11 @@ TEST(test_epee_connection, test_lifetime) } for (;workers.size(); workers.pop_back()) workers.back().join(); - + while (server.get_connections_count()) { + server.get_config_shared()->del_in_connections( + server.get_config_shared()->get_in_connections_count() + ); + } }); for (auto& w: workers) { @@ -462,3 +474,241 @@ TEST(test_epee_connection, test_lifetime) server.timed_wait_server_stop(5 * 1000); server.deinit_server(); } + +TEST(test_epee_connection, ssl_shutdown) +{ + struct context_t: epee::net_utils::connection_context_base { + static constexpr size_t get_max_bytes(int) noexcept { return -1; } + static constexpr int handshake_command() noexcept { return 1001; } + static constexpr bool handshake_complete() noexcept { return true; } + }; + + struct command_handler_t: epee::levin::levin_commands_handler { + virtual int invoke(int, const epee::span, epee::byte_stream&, context_t&) override { return {}; } + virtual int notify(int, const epee::span, context_t&) override { return {}; } + virtual void callback(context_t&) override {} + virtual void on_connection_new(context_t&) override {} + virtual void on_connection_close(context_t&) override { } + virtual ~command_handler_t() override {} + static void destroy(epee::levin::levin_commands_handler* ptr) { delete ptr; } + }; + + using handler_t = epee::levin::async_protocol_handler; + using io_context_t = boost::asio::io_service; + using endpoint_t = boost::asio::ip::tcp::endpoint; + using server_t = epee::net_utils::boosted_tcp_server; + using socket_t = boost::asio::ip::tcp::socket; + using ssl_socket_t = boost::asio::ssl::stream; + using ssl_context_t = boost::asio::ssl::context; + using ec_t = boost::system::error_code; + + endpoint_t endpoint(boost::asio::ip::address::from_string("127.0.0.1"), 5263); + server_t server(epee::net_utils::e_connection_type_P2P); + 
server.init_server(endpoint.port(), + endpoint.address().to_string(), + 0, + "", + false, + true, + epee::net_utils::ssl_support_t::e_ssl_support_enabled + ); + server.get_config_shared()->set_handler(new command_handler_t, &command_handler_t::destroy); + server.run_server(2, false); + + ssl_context_t ssl_context{boost::asio::ssl::context::sslv23}; + io_context_t io_context; + ssl_socket_t socket(io_context, ssl_context); + ec_t ec; + socket.next_layer().connect(endpoint, ec); + EXPECT_EQ(ec.value(), 0); + socket.handshake(boost::asio::ssl::stream_base::client, ec); + EXPECT_EQ(ec.value(), 0); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + while (server.get_config_shared()->get_connections_count() < 1); + server.get_config_shared()->del_in_connections(1); + while (server.get_config_shared()->get_connections_count() > 0); + server.send_stop_signal(); + EXPECT_TRUE(server.timed_wait_server_stop(5 * 1000)); + server.deinit_server(); + socket.next_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + socket.next_layer().close(ec); + socket.shutdown(ec); +} + +TEST(test_epee_connection, ssl_handshake) +{ + using io_context_t = boost::asio::io_service; + using work_t = boost::asio::io_service::work; + using work_ptr = std::shared_ptr; + using workers_t = std::vector; + using socket_t = boost::asio::ip::tcp::socket; + using ssl_socket_t = boost::asio::ssl::stream; + using ssl_socket_ptr = std::unique_ptr; + using ssl_options_t = epee::net_utils::ssl_options_t; + io_context_t io_context; + work_ptr work(std::make_shared(io_context)); + workers_t workers; + auto constexpr N = 2; + while (workers.size() < N) { + workers.emplace_back([&io_context]{ + io_context.run(); + }); + } + ssl_options_t ssl_options{{}}; + auto ssl_context = ssl_options.create_context(); + for (size_t i = 0; i < N * N * N; ++i) { + ssl_socket_ptr ssl_socket(new ssl_socket_t(io_context, ssl_context)); + ssl_socket->next_layer().open(boost::asio::ip::tcp::v4()); + for (size_t i = 0; i < N; ++i) { + io_context.post([]{ + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + }); + } + EXPECT_EQ( + ssl_options.handshake( + *ssl_socket, + ssl_socket_t::server, + {}, + {}, + std::chrono::milliseconds(0) + ), + false + ); + ssl_socket->next_layer().close(); + ssl_socket.reset(); + } + work.reset(); + for (;workers.size(); workers.pop_back()) + workers.back().join(); +} + + +TEST(boosted_tcp_server, strand_deadlock) +{ + using context_t = epee::net_utils::connection_context_base; + using lock_t = std::mutex; + using unique_lock_t = std::unique_lock; + + struct config_t { + using condition_t = std::condition_variable_any; + using lock_guard_t = std::lock_guard; + void notify_success() + { + lock_guard_t guard(lock); + success = true; + condition.notify_all(); + } + lock_t lock; + condition_t condition; + bool success; + }; + + struct handler_t { + using config_type = config_t; + using connection_context = context_t; + using byte_slice_t = epee::byte_slice; + using socket_t = epee::net_utils::i_service_endpoint; + + handler_t(socket_t *socket, config_t &config, context_t &context): + socket(socket), + config(config), + context(context) + {} + void after_init_connection() + { + unique_lock_t guard(lock); + if (not context.m_is_income) { + guard.unlock(); + socket->do_send(byte_slice_t{"."}); + } + } + void handle_qued_callback() + { + } + bool handle_recv(const char *data, size_t bytes_transferred) + { + unique_lock_t guard(lock); + if (not context.m_is_income) { + if (context.m_recv_cnt == 1024) { 
+ guard.unlock(); + socket->do_send(byte_slice_t{"."}); + } + } + else { + if (context.m_recv_cnt == 1) { + for(size_t i = 0; i < 1024; ++i) { + guard.unlock(); + socket->do_send(byte_slice_t{"."}); + guard.lock(); + } + } + else if(context.m_recv_cnt == 2) { + guard.unlock(); + socket->close(); + } + } + return true; + } + void release_protocol() + { + unique_lock_t guard(lock); + if(not context.m_is_income + and context.m_recv_cnt == 1024 + and context.m_send_cnt == 2 + ) { + guard.unlock(); + config.notify_success(); + } + } + + lock_t lock; + socket_t *socket; + config_t &config; + context_t &context; + }; + + using server_t = epee::net_utils::boosted_tcp_server; + using endpoint_t = boost::asio::ip::tcp::endpoint; + + endpoint_t endpoint(boost::asio::ip::address::from_string("127.0.0.1"), 5262); + server_t server(epee::net_utils::e_connection_type_P2P); + server.init_server( + endpoint.port(), + endpoint.address().to_string(), + {}, + {}, + {}, + true, + epee::net_utils::ssl_support_t::e_ssl_support_disabled + ); + server.run_server(2, {}); + server.async_call( + [&]{ + context_t context; + ASSERT_TRUE( + server.connect( + endpoint.address().to_string(), + std::to_string(endpoint.port()), + 5, + context, + "0.0.0.0", + epee::net_utils::ssl_support_t::e_ssl_support_disabled + ) + ); + } + ); + { + unique_lock_t guard(server.get_config_object().lock); + EXPECT_TRUE( + server.get_config_object().condition.wait_for( + guard, + std::chrono::seconds(5), + [&] { return server.get_config_object().success; } + ) + ); + } + + server.send_stop_signal(); + server.timed_wait_server_stop(5 * 1000); + server.deinit_server(); +} From 3be1dbd0963a76a05fd7f72f28100726daa1c4e7 Mon Sep 17 00:00:00 2001 From: anon Date: Mon, 28 Jun 2021 19:13:02 +0000 Subject: [PATCH 2/3] connection: fix implementation --- CMakeLists.txt | 1 + .../epee/include/net/abstract_tcp_server2.h | 231 ++- .../epee/include/net/abstract_tcp_server2.inl | 1824 +++++++++-------- contrib/epee/include/net/net_ssl.h | 5 + contrib/epee/src/net_ssl.cpp | 121 +- 5 files changed, 1288 insertions(+), 894 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3abd0722a..b05c087cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1076,6 +1076,7 @@ if(STATIC) set(Boost_USE_STATIC_RUNTIME ON) endif() find_package(Boost 1.58 QUIET REQUIRED COMPONENTS system filesystem thread date_time chrono regex serialization program_options locale) +add_definitions(-DBOOST_ASIO_ENABLE_SEQUENTIAL_STRAND_ALLOCATION) set(CMAKE_FIND_LIBRARY_SUFFIXES ${OLD_LIB_SUFFIXES}) if(NOT Boost_FOUND) diff --git a/contrib/epee/include/net/abstract_tcp_server2.h b/contrib/epee/include/net/abstract_tcp_server2.h index 51aa9f275..0684573f2 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.h +++ b/contrib/epee/include/net/abstract_tcp_server2.h @@ -44,12 +44,16 @@ #include #include #include +#include #include #include +#include +#include #include #include #include +#include #include "byte_slice.h" #include "net_utils_base.h" #include "syncobj.h" @@ -85,6 +89,181 @@ namespace net_utils public i_service_endpoint, public connection_basic { + private: + using string_t = std::string; + using handler_t = t_protocol_handler; + using context_t = typename handler_t::connection_context; + using connection_t = connection; + using connection_ptr = boost::shared_ptr; + using ssl_support_t = epee::net_utils::ssl_support_t; + using timer_t = boost::asio::steady_timer; + using duration_t = timer_t::duration; + using lock_t = std::mutex; + using condition_t = 
std::condition_variable_any; + using lock_guard_t = std::lock_guard; + using unique_lock_t = std::unique_lock; + using byte_slice_t = epee::byte_slice; + using ec_t = boost::system::error_code; + using handshake_t = boost::asio::ssl::stream_base::handshake_type; + + using io_context_t = boost::asio::io_service; + using strand_t = boost::asio::io_service::strand; + using socket_t = boost::asio::ip::tcp::socket; + + using write_queue_t = std::deque; + using read_buffer_t = std::array; + using network_throttle_t = epee::net_utils::network_throttle; + using network_throttle_manager_t = epee::net_utils::network_throttle_manager; + + unsigned int host_count(int delta = 0); + duration_t get_default_timeout(); + duration_t get_timeout_from_bytes_read(size_t bytes) const; + + void start_timer(duration_t duration, bool add = {}); + void async_wait_timer(); + void cancel_timer(); + + void start_handshake(); + void start_read(); + void start_write(); + void start_shutdown(); + void cancel_socket(); + + void cancel_handler(); + + void interrupt(); + void on_interrupted(); + + void terminate(); + void on_terminating(); + + bool send(byte_slice_t message); + bool start_internal( + bool is_income, + bool is_multithreaded, + boost::optional real_remote + ); + + struct state_t { + struct stat_t { + struct { + network_throttle_t throttle{"speed_in", "throttle_speed_in"}; + } in; + struct { + network_throttle_t throttle{"speed_out", "throttle_speed_out"}; + } out; + }; + + struct data_t { + struct { + read_buffer_t buffer; + } read; + struct { + write_queue_t queue; + bool wait_consume; + } write; + }; + + struct ssl_t { + bool enabled; + bool forced; + bool detected; + bool handshaked; + }; + + struct socket_t { + bool connected; + + bool wait_handshake; + bool cancel_handshake; + + bool wait_read; + bool handle_read; + bool cancel_read; + + bool wait_write; + bool handle_write; + bool cancel_write; + + bool wait_shutdown; + bool cancel_shutdown; + }; + + struct timer_t { + bool wait_expire; + bool cancel_expire; + bool reset_expire; + }; + + struct timers_t { + struct throttle_t { + timer_t in; + timer_t out; + }; + + timer_t general; + throttle_t throttle; + }; + + enum status_t { + TERMINATED, + RUNNING, + INTERRUPTED, + TERMINATING, + WASTED, + }; + + struct protocol_t { + size_t reference_counter; + bool released; + bool initialized; + + bool wait_release; + bool wait_init; + size_t wait_callback; + }; + + lock_t lock; + condition_t condition; + status_t status; + socket_t socket; + ssl_t ssl; + timers_t timers; + protocol_t protocol; + stat_t stat; + data_t data; + }; + + using status_t = typename state_t::status_t; + + struct timers_t { + timers_t(io_context_t &io_context): + general(io_context), + throttle(io_context) + {} + struct throttle_t { + throttle_t(io_context_t &io_context): + in(io_context), + out(io_context) + {} + timer_t in; + timer_t out; + }; + + timer_t general; + throttle_t throttle; + }; + + io_context_t &io_context; + t_connection_type connection_type; + context_t context{}; + strand_t strand; + timers_t timers; + connection_ptr self{}; + bool local{}; + string_t host{}; + state_t state{}; + handler_t handler; public: typedef typename t_protocol_handler::connection_context t_connection_context; @@ -141,58 +320,6 @@ namespace net_utils virtual bool add_ref(); virtual bool release(); //------------------------------------------------------ - bool do_send_chunk(byte_slice chunk); ///< will send (or queue) a part of data. 
internal use only - - boost::shared_ptr > safe_shared_from_this(); - bool shutdown(); - /// Handle completion of a receive operation. - void handle_receive(const boost::system::error_code& e, - std::size_t bytes_transferred); - - /// Handle completion of a read operation. - void handle_read(const boost::system::error_code& e, - std::size_t bytes_transferred); - - /// Handle completion of a write operation. - void handle_write(const boost::system::error_code& e, size_t cb); - - /// reset connection timeout timer and callback - void reset_timer(boost::posix_time::milliseconds ms, bool add); - boost::posix_time::milliseconds get_default_timeout(); - boost::posix_time::milliseconds get_timeout_from_bytes_read(size_t bytes); - - /// host connection count tracking - unsigned int host_count(const std::string &host, int delta = 0); - - /// Buffer for incoming data. - boost::array buffer_; - size_t buffer_ssl_init_fill; - - t_connection_context context; - - // TODO what do they mean about wait on destructor?? --rfree : - //this should be the last one, because it could be wait on destructor, while other activities possible on other threads - t_protocol_handler m_protocol_handler; - //typename t_protocol_handler::config_type m_dummy_config; - size_t m_reference_count = 0; // reference count managed through add_ref/release support - boost::shared_ptr > m_self_ref; // the reference to hold - critical_section m_self_refs_lock; - critical_section m_chunking_lock; // held while we add small chunks of the big do_send() to small do_send_chunk() - critical_section m_shutdown_lock; // held while shutting down - - t_connection_type m_connection_type; - - // for calculate speed (last 60 sec) - network_throttle m_throttle_speed_in; - network_throttle m_throttle_speed_out; - boost::mutex m_throttle_speed_in_mutex; - boost::mutex m_throttle_speed_out_mutex; - - boost::asio::deadline_timer m_timer; - bool m_local; - bool m_ready_to_close; - std::string m_host; - public: void setRpcStation(); }; diff --git a/contrib/epee/include/net/abstract_tcp_server2.inl b/contrib/epee/include/net/abstract_tcp_server2.inl index 0c3b457bc..0fc9228b1 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.inl +++ b/contrib/epee/include/net/abstract_tcp_server2.inl @@ -76,670 +76,12 @@ namespace net_utils /************************************************************************/ /* */ /************************************************************************/ - template - connection::connection( boost::asio::io_service& io_service, - std::shared_ptr state, - t_connection_type connection_type, - ssl_support_t ssl_support - ) - : connection(boost::asio::ip::tcp::socket{io_service}, std::move(state), connection_type, ssl_support) + template + unsigned int connection::host_count(int delta) { - } - - template - connection::connection( boost::asio::ip::tcp::socket&& sock, - std::shared_ptr state, - t_connection_type connection_type, - ssl_support_t ssl_support - ) - : - connection_basic(std::move(sock), state, ssl_support), - m_protocol_handler(this, check_and_get(state), context), - buffer_ssl_init_fill(0), - m_connection_type( connection_type ), - m_throttle_speed_in("speed_in", "throttle_speed_in"), - m_throttle_speed_out("speed_out", "throttle_speed_out"), - m_timer(GET_IO_SERVICE(socket_)), - m_local(false), - m_ready_to_close(false) - { - MDEBUG("test, connection constructor set m_connection_type="< - connection::~connection() noexcept(false) - { - if(!m_was_shutdown) - { - _dbg3("[sock " << socket().native_handle() << "] Socket 
destroyed without shutdown."); - shutdown(); - } - - _dbg3("[sock " << socket().native_handle() << "] Socket destroyed"); - } - //--------------------------------------------------------------------------------- - template - boost::shared_ptr > connection::safe_shared_from_this() - { - try - { - return connection::shared_from_this(); - } - catch (const boost::bad_weak_ptr&) - { - // It happens when the connection is being deleted - return boost::shared_ptr >(); - } - } - //--------------------------------------------------------------------------------- - template - bool connection::start(bool is_income, bool is_multithreaded) - { - TRY_ENTRY(); - - boost::system::error_code ec; - auto remote_ep = socket().remote_endpoint(ec); - CHECK_AND_NO_ASSERT_MES(!ec, false, "Failed to get remote endpoint: " << ec.message() << ':' << ec.value()); - CHECK_AND_NO_ASSERT_MES(remote_ep.address().is_v4() || remote_ep.address().is_v6(), false, "only IPv4 and IPv6 supported here"); - - if (remote_ep.address().is_v4()) - { - const unsigned long ip_ = boost::asio::detail::socket_ops::host_to_network_long(remote_ep.address().to_v4().to_ulong()); - return start(is_income, is_multithreaded, ipv4_network_address{uint32_t(ip_), remote_ep.port()}); - } - else - { - const auto ip_ = remote_ep.address().to_v6(); - return start(is_income, is_multithreaded, ipv6_network_address{ip_, remote_ep.port()}); - } - CATCH_ENTRY_L0("connection::start()", false); - } - //--------------------------------------------------------------------------------- - template - bool connection::start(bool is_income, bool is_multithreaded, network_address real_remote) - { - TRY_ENTRY(); - - // Use safe_shared_from_this, because of this is public method and it can be called on the object being deleted - auto self = safe_shared_from_this(); - if(!self) - return false; - - m_is_multithreaded = is_multithreaded; - m_local = real_remote.is_loopback() || real_remote.is_local(); - - // create a random uuid, we don't need crypto strength here - const boost::uuids::uuid random_uuid = boost::uuids::random_generator()(); - - context = t_connection_context{}; - bool ssl = m_ssl_support == epee::net_utils::ssl_support_t::e_ssl_support_enabled; - context.set_details(random_uuid, std::move(real_remote), is_income, ssl); - - boost::system::error_code ec; - auto local_ep = socket().local_endpoint(ec); - CHECK_AND_NO_ASSERT_MES(!ec, false, "Failed to get local endpoint: " << ec.message() << ':' << ec.value()); - - _dbg3("[sock " << socket_.native_handle() << "] new connection from " << print_connection_context_short(context) << - " to " << local_ep.address().to_string() << ':' << local_ep.port() << - ", total sockets objects " << get_state().sock_count); - - if(static_cast(get_state()).pfilter && !static_cast(get_state()).pfilter->is_remote_host_allowed(context.m_remote_address)) - { - _dbg2("[sock " << socket().native_handle() << "] host denied " << context.m_remote_address.host_str() << ", shutdowning connection"); - close(); - return false; - } - - m_host = context.m_remote_address.host_str(); - try { host_count(m_host, 1); } catch(...) { /* ignore */ } - - m_protocol_handler.after_init_connection(); - - reset_timer(boost::posix_time::milliseconds(m_local ? 
NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE), false); - - // first read on the raw socket to detect SSL for the server - buffer_ssl_init_fill = 0; - if (is_income && m_ssl_support != epee::net_utils::ssl_support_t::e_ssl_support_disabled) - socket().async_receive(boost::asio::buffer(buffer_), - strand_.wrap( - std::bind(&connection::handle_receive, self, - std::placeholders::_1, - std::placeholders::_2))); - else - async_read_some(boost::asio::buffer(buffer_), - strand_.wrap( - std::bind(&connection::handle_read, self, - std::placeholders::_1, - std::placeholders::_2))); -#if !defined(_WIN32) || !defined(__i686) - // not supported before Windows7, too lazy for runtime check - // Just exclude for 32bit windows builds - //set ToS flag - int tos = get_tos_flag(); - boost::asio::detail::socket_option::integer< IPPROTO_IP, IP_TOS > - optionTos( tos ); - socket().set_option( optionTos ); - //_dbg1("Set ToS flag to " << tos); -#endif - - boost::asio::ip::tcp::no_delay noDelayOption(false); - socket().set_option(noDelayOption); - - return true; - - CATCH_ENTRY_L0("connection::start()", false); - } - //--------------------------------------------------------------------------------- - template - bool connection::request_callback() - { - TRY_ENTRY(); - _dbg2("[" << print_connection_context_short(context) << "] request_callback"); - // Use safe_shared_from_this, because of this is public method and it can be called on the object being deleted - auto self = safe_shared_from_this(); - if(!self) - return false; - - strand_.post(boost::bind(&connection::call_back_starter, self)); - CATCH_ENTRY_L0("connection::request_callback()", false); - return true; - } - //--------------------------------------------------------------------------------- - template - boost::asio::io_service& connection::get_io_service() - { - return GET_IO_SERVICE(socket()); - } - //--------------------------------------------------------------------------------- - template - bool connection::add_ref() - { - TRY_ENTRY(); - - // Use safe_shared_from_this, because of this is public method and it can be called on the object being deleted - auto self = safe_shared_from_this(); - if(!self) - return false; - //_dbg3("[sock " << socket().native_handle() << "] add_ref, m_peer_number=" << mI->m_peer_number); - CRITICAL_REGION_LOCAL(self->m_self_refs_lock); - //_dbg3("[sock " << socket().native_handle() << "] add_ref 2, m_peer_number=" << mI->m_peer_number); - ++m_reference_count; - m_self_ref = std::move(self); - return true; - CATCH_ENTRY_L0("connection::add_ref()", false); - } - //--------------------------------------------------------------------------------- - template - bool connection::release() - { - TRY_ENTRY(); - boost::shared_ptr > back_connection_copy; - LOG_TRACE_CC(context, "[sock " << socket().native_handle() << "] release"); - CRITICAL_REGION_BEGIN(m_self_refs_lock); - CHECK_AND_ASSERT_MES(m_reference_count, false, "[sock " << socket().native_handle() << "] m_reference_count already at 0 at connection::release() call"); - // is this the last reference? 
- if (--m_reference_count == 0) { - // move the held reference to a local variable, keeping the object alive until the function terminates - std::swap(back_connection_copy, m_self_ref); - } - CRITICAL_REGION_END(); - return true; - CATCH_ENTRY_L0("connection::release()", false); - } - //--------------------------------------------------------------------------------- - template - void connection::call_back_starter() - { - TRY_ENTRY(); - _dbg2("[" << print_connection_context_short(context) << "] fired_callback"); - m_protocol_handler.handle_qued_callback(); - CATCH_ENTRY_L0("connection::call_back_starter()", void()); - } - //--------------------------------------------------------------------------------- - template - void connection::save_dbg_log() - { - std::string address, port; - boost::system::error_code e; - - boost::asio::ip::tcp::endpoint endpoint = socket().remote_endpoint(e); - if (e) - { - address = ""; - port = ""; - } - else - { - address = endpoint.address().to_string(); - port = boost::lexical_cast(endpoint.port()); - } - MDEBUG(" connection type " << to_string( m_connection_type ) << " " - << socket().local_endpoint().address().to_string() << ":" << socket().local_endpoint().port() - << " <--> " << context.m_remote_address.str() << " (via " << address << ":" << port << ")"); - } - //--------------------------------------------------------------------------------- - template - void connection::handle_read(const boost::system::error_code& e, - std::size_t bytes_transferred) - { - TRY_ENTRY(); - //_info("[sock " << socket().native_handle() << "] Async read calledback."); - - if (m_was_shutdown) - return; - - if (!e) - { - double current_speed_down; - { - CRITICAL_REGION_LOCAL(m_throttle_speed_in_mutex); - m_throttle_speed_in.handle_trafic_exact(bytes_transferred); - current_speed_down = m_throttle_speed_in.get_current_speed(); - } - context.m_current_speed_down = current_speed_down; - context.m_max_speed_down = std::max(context.m_max_speed_down, current_speed_down); - - { - CRITICAL_REGION_LOCAL( epee::net_utils::network_throttle_manager::network_throttle_manager::m_lock_get_global_throttle_in ); - epee::net_utils::network_throttle_manager::network_throttle_manager::get_global_throttle_in().handle_trafic_exact(bytes_transferred); - } - - double delay=0; // will be calculated - how much we should sleep to obey speed limit etc - - - if (speed_limit_is_enabled()) { - do // keep sleeping if we should sleep - { - { //_scope_dbg1("CRITICAL_REGION_LOCAL"); - CRITICAL_REGION_LOCAL( epee::net_utils::network_throttle_manager::m_lock_get_global_throttle_in ); - delay = epee::net_utils::network_throttle_manager::get_global_throttle_in().get_sleep_time_after_tick( bytes_transferred ); - } - - if (m_was_shutdown) - return; - - delay *= 0.5; - long int ms = (long int)(delay * 100); - if (ms > 0) { - reset_timer(boost::posix_time::milliseconds(ms + 1), true); - boost::this_thread::sleep_for(boost::chrono::milliseconds(ms)); - } - } while(delay > 0); - } // any form of sleeping - - //_info("[sock " << socket().native_handle() << "] RECV " << bytes_transferred); - logger_handle_net_read(bytes_transferred); - context.m_last_recv = time(NULL); - context.m_recv_cnt += bytes_transferred; - m_ready_to_close = false; - bool recv_res = m_protocol_handler.handle_recv(buffer_.data(), bytes_transferred); - if(!recv_res) - { - //_info("[sock " << socket().native_handle() << "] protocol_want_close"); - //some error in protocol, protocol handler ask to close connection - m_want_close_connection = true; - 
bool do_shutdown = false; - CRITICAL_REGION_BEGIN(m_send_que_lock); - if(!m_send_que.size()) - do_shutdown = true; - CRITICAL_REGION_END(); - if(do_shutdown) - shutdown(); - }else - { - reset_timer(get_timeout_from_bytes_read(bytes_transferred), false); - async_read_some(boost::asio::buffer(buffer_), - strand_.wrap( - boost::bind(&connection::handle_read, connection::shared_from_this(), - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred))); - //_info("[sock " << socket().native_handle() << "]Async read requested."); - } - }else - { - _dbg3("[sock " << socket().native_handle() << "] Some not success at read: " << e.message() << ':' << e.value()); - if(e.value() != 2) - { - _dbg3("[sock " << socket().native_handle() << "] Some problems at read: " << e.message() << ':' << e.value()); - shutdown(); - } - else - { - _dbg3("[sock " << socket().native_handle() << "] peer closed connection"); - bool do_shutdown = false; - CRITICAL_REGION_BEGIN(m_send_que_lock); - if(!m_send_que.size()) - do_shutdown = true; - CRITICAL_REGION_END(); - if (m_ready_to_close || do_shutdown) - shutdown(); - } - m_ready_to_close = true; - } - // If an error occurs then no new asynchronous operations are started. This - // means that all shared_ptr references to the connection object will - // disappear and the object will be destroyed automatically after this - // handler returns. The connection class's destructor closes the socket. - CATCH_ENTRY_L0("connection::handle_read", void()); - } - //--------------------------------------------------------------------------------- - template - void connection::handle_receive(const boost::system::error_code& e, - std::size_t bytes_transferred) - { - TRY_ENTRY(); - - if (m_was_shutdown) return; - - if (e) - { - // offload the error case - handle_read(e, bytes_transferred); - return; - } - - buffer_ssl_init_fill += bytes_transferred; - MTRACE("we now have " << buffer_ssl_init_fill << "/" << get_ssl_magic_size() << " bytes needed to detect SSL"); - if (buffer_ssl_init_fill < get_ssl_magic_size()) - { - socket().async_receive(boost::asio::buffer(buffer_.data() + buffer_ssl_init_fill, buffer_.size() - buffer_ssl_init_fill), - strand_.wrap( - boost::bind(&connection::handle_receive, connection::shared_from_this(), - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred))); - return; - } - - // detect SSL - if (m_ssl_support == epee::net_utils::ssl_support_t::e_ssl_support_autodetect) - { - if (is_ssl((const unsigned char*)buffer_.data(), buffer_ssl_init_fill)) - { - MDEBUG("That looks like SSL"); - m_ssl_support = epee::net_utils::ssl_support_t::e_ssl_support_enabled; // read/write to the SSL socket - } - else - { - MDEBUG("That does not look like SSL"); - m_ssl_support = epee::net_utils::ssl_support_t::e_ssl_support_disabled; // read/write to the raw socket - } - } - - if (m_ssl_support == epee::net_utils::ssl_support_t::e_ssl_support_enabled) - { - // Handshake - if (!handshake(boost::asio::ssl::stream_base::server, boost::asio::const_buffer(buffer_.data(), buffer_ssl_init_fill))) - { - MERROR("SSL handshake failed"); - m_want_close_connection = true; - m_ready_to_close = true; - bool do_shutdown = false; - CRITICAL_REGION_BEGIN(m_send_que_lock); - if(!m_send_que.size()) - do_shutdown = true; - CRITICAL_REGION_END(); - if(do_shutdown) - shutdown(); - return; - } - } - else - { - handle_read(e, buffer_ssl_init_fill); - return; - } - - async_read_some(boost::asio::buffer(buffer_), - strand_.wrap( - 
boost::bind(&connection::handle_read, connection::shared_from_this(), - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred))); - - // If an error occurs then no new asynchronous operations are started. This - // means that all shared_ptr references to the connection object will - // disappear and the object will be destroyed automatically after this - // handler returns. The connection class's destructor closes the socket. - CATCH_ENTRY_L0("connection::handle_receive", void()); - } - //--------------------------------------------------------------------------------- - template - bool connection::call_run_once_service_io() - { - TRY_ENTRY(); - if(!m_is_multithreaded) - { - //single thread model, we can wait in blocked call - size_t cnt = GET_IO_SERVICE(socket()).run_one(); - if(!cnt)//service is going to quit - return false; - }else - { - //multi thread model, we can't(!) wait in blocked call - //so we make non blocking call and releasing CPU by calling sleep(0); - //if no handlers were called - //TODO: Maybe we need to have have critical section + event + callback to upper protocol to - //ask it inside(!) critical region if we still able to go in event wait... - size_t cnt = GET_IO_SERVICE(socket()).poll_one(); - if(!cnt) - misc_utils::sleep_no_w(1); - } - - return true; - CATCH_ENTRY_L0("connection::call_run_once_service_io", false); - } - //--------------------------------------------------------------------------------- - template - bool connection::do_send(byte_slice message) { - TRY_ENTRY(); - - // Use safe_shared_from_this, because of this is public method and it can be called on the object being deleted - auto self = safe_shared_from_this(); - if (!self) return false; - if (m_was_shutdown) return false; - // TODO avoid copy - - std::uint8_t const* const message_data = message.data(); - const std::size_t message_size = message.size(); - - const double factor = 32; // TODO config - typedef long long signed int t_safe; // my t_size to avoid any overunderflow in arithmetic - const t_safe chunksize_good = (t_safe)( 1024 * std::max(1.0,factor) ); - const t_safe chunksize_max = chunksize_good * 2 ; - const bool allow_split = (m_connection_type == e_connection_type_RPC) ? false : true; // do not split RPC data - - CHECK_AND_ASSERT_MES(! 
(chunksize_max<0), false, "Negative chunksize_max" ); // make sure it is unsigned before removin sign with cast: - long long unsigned int chunksize_max_unsigned = static_cast( chunksize_max ) ; - - if (allow_split && (message_size > chunksize_max_unsigned)) { - { // LOCK: chunking - epee::critical_region_t send_guard(m_chunking_lock); // *** critical *** - - MDEBUG("do_send() will SPLIT into small chunks, from packet="<::do_send", false); - } // do_send() - - //--------------------------------------------------------------------------------- - template - bool connection::do_send_chunk(byte_slice chunk) - { - TRY_ENTRY(); - // Use safe_shared_from_this, because of this is public method and it can be called on the object being deleted - auto self = safe_shared_from_this(); - if(!self) - return false; - if(m_was_shutdown) - return false; - double current_speed_up; - { - CRITICAL_REGION_LOCAL(m_throttle_speed_out_mutex); - m_throttle_speed_out.handle_trafic_exact(chunk.size()); - current_speed_up = m_throttle_speed_out.get_current_speed(); - } - context.m_current_speed_up = current_speed_up; - context.m_max_speed_up = std::max(context.m_max_speed_up, current_speed_up); - - //_info("[sock " << socket().native_handle() << "] SEND " << cb); - context.m_last_send = time(NULL); - context.m_send_cnt += chunk.size(); - //some data should be wrote to stream - //request complete - - // No sleeping here; sleeping is done once and for all in "handle_write" - - m_send_que_lock.lock(); // *** critical *** - epee::misc_utils::auto_scope_leave_caller scope_exit_handler = epee::misc_utils::create_scope_leave_handler([&](){m_send_que_lock.unlock();}); - - long int retry=0; - const long int retry_limit = 5*4; - while (m_send_que.size() > ABSTRACT_SERVER_SEND_QUE_MAX_COUNT) - { - retry++; - - /* if ( ::cryptonote::core::get_is_stopping() ) { // TODO re-add fast stop - _fact("ABORT queue wait due to stopping"); - return false; // aborted - }*/ - - using engine = std::mt19937; - - engine rng; - std::random_device dev; - std::seed_seq::result_type rand[engine::state_size]{}; // Use complete bit space - - std::generate_n(rand, engine::state_size, std::ref(dev)); - std::seed_seq seed(rand, rand + engine::state_size); - rng.seed(seed); - - long int ms = 250 + (rng() % 50); - MDEBUG("Sleeping because QUEUE is FULL, in " << __FUNCTION__ << " for " << ms << " ms before packet_size="< retry_limit) { - MWARNING("send que size is more than ABSTRACT_SERVER_SEND_QUE_MAX_COUNT(" << ABSTRACT_SERVER_SEND_QUE_MAX_COUNT << "), shutting down connection"); - shutdown(); - return false; - } - } - - m_send_que.push_back(std::move(chunk)); - - if(m_send_que.size() > 1) - { // active operation should be in progress, nothing to do, just wait last operation callback - auto size_now = m_send_que.back().size(); - MDEBUG("do_send_chunk() NOW just queues: packet="<::handle_write, self, std::placeholders::_1, std::placeholders::_2) - ) - ); - //_dbg3("(chunk): " << size_now); - //logger_handle_net_write(size_now); - //_info("[sock " << socket().native_handle() << "] Async send requested " << m_send_que.front().size()); - } - - //do_send_handler_stop( ptr , cb ); // empty function - - return true; - - CATCH_ENTRY_L0("connection::do_send_chunk", false); - } // do_send_chunk - //--------------------------------------------------------------------------------- - template - boost::posix_time::milliseconds connection::get_default_timeout() - { - unsigned count; - try { count = host_count(m_host); } catch (...) 
{ count = 0; } - const unsigned shift = get_state().sock_count > AGGRESSIVE_TIMEOUT_THRESHOLD ? std::min(std::max(count, 1u) - 1, 8u) : 0; - boost::posix_time::milliseconds timeout(0); - if (m_local) - timeout = boost::posix_time::milliseconds(DEFAULT_TIMEOUT_MS_LOCAL >> shift); - else - timeout = boost::posix_time::milliseconds(DEFAULT_TIMEOUT_MS_REMOTE >> shift); - return timeout; - } - //--------------------------------------------------------------------------------- - template - boost::posix_time::milliseconds connection::get_timeout_from_bytes_read(size_t bytes) - { - boost::posix_time::milliseconds ms = (boost::posix_time::milliseconds)(unsigned)(bytes * TIMEOUT_EXTRA_MS_PER_BYTE); - const auto cur = m_timer.expires_from_now().total_milliseconds(); - if (cur > 0) - ms += (boost::posix_time::milliseconds)cur; - if (ms > get_default_timeout()) - ms = get_default_timeout(); - return ms; - } - //--------------------------------------------------------------------------------- - template - unsigned int connection::host_count(const std::string &host, int delta) - { - static boost::mutex hosts_mutex; - CRITICAL_REGION_LOCAL(hosts_mutex); - static std::map hosts; + static lock_t hosts_mutex; + lock_guard_t guard(hosts_mutex); + static std::map hosts; unsigned int &val = hosts[host]; if (delta > 0) MTRACE("New connection from host " << host << ": " << val); @@ -750,185 +92,1033 @@ namespace net_utils val += delta; return val; } - //--------------------------------------------------------------------------------- - template - void connection::reset_timer(boost::posix_time::milliseconds ms, bool add) + + template + typename connection::duration_t connection::get_default_timeout() { - const auto tms = ms.total_milliseconds(); - if (tms < 0 || (add && tms == 0)) - { - MWARNING("Ignoring negative timeout " << ms); + unsigned count{}; + try { count = host_count(); } catch (...) {} + const unsigned shift = ( + connection_basic::get_state().sock_count > AGGRESSIVE_TIMEOUT_THRESHOLD ? + std::min(std::max(count, 1u) - 1, 8u) : + 0 + ); + return ( + local ? + std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_LOCAL >> shift) : + std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_REMOTE >> shift) + ); + } + + template + typename connection::duration_t connection::get_timeout_from_bytes_read(size_t bytes) const + { + return std::chrono::duration_cast::duration_t>( + std::chrono::duration( + bytes * TIMEOUT_EXTRA_MS_PER_BYTE + ) + ); + } + + template + void connection::start_timer(duration_t duration, bool add) + { + if (state.timers.general.wait_expire) { + state.timers.general.cancel_expire = true; + state.timers.general.reset_expire = true; + ec_t ec; + timers.general.expires_from_now( + std::min( + duration + (add ? timers.general.expires_from_now() : duration_t{}), + get_default_timeout() + ), + ec + ); + } + else { + ec_t ec; + timers.general.expires_from_now( + std::min( + duration + (add ? timers.general.expires_from_now() : duration_t{}), + get_default_timeout() + ), + ec + ); + async_wait_timer(); + } + } + + template + void connection::async_wait_timer() + { + if (state.timers.general.wait_expire) return; - } - MTRACE((add ? 
"Adding" : "Setting") << " " << ms << " expiry"); - auto self = safe_shared_from_this(); - if(!self) - { - MERROR("Resetting timer on a dead object"); - return; - } - if (m_was_shutdown) - { - MERROR("Setting timer on a shut down object"); - return; - } - if (add) - { - const auto cur = m_timer.expires_from_now().total_milliseconds(); - if (cur > 0) - ms += (boost::posix_time::milliseconds)cur; - } - m_timer.expires_from_now(ms); - m_timer.async_wait([=](const boost::system::error_code& ec) - { - if(ec == boost::asio::error::operation_aborted) - return; - MDEBUG(context << "connection timeout, closing"); - self->close(); + state.timers.general.wait_expire = true; + auto self = connection::shared_from_this(); + timers.general.async_wait([this, self](const ec_t & ec){ + lock_guard_t guard(state.lock); + state.timers.general.wait_expire = false; + if (state.timers.general.cancel_expire) { + state.timers.general.cancel_expire = false; + if (state.timers.general.reset_expire) { + state.timers.general.reset_expire = false; + async_wait_timer(); + } + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + terminate(); }); } - //--------------------------------------------------------------------------------- - template - bool connection::shutdown() + + template + void connection::cancel_timer() { - CRITICAL_REGION_BEGIN(m_shutdown_lock); - if (m_was_shutdown) - return true; - m_was_shutdown = true; - // Initiate graceful connection closure. - m_timer.cancel(); - boost::system::error_code ignored_ec; - if (m_ssl_support == epee::net_utils::ssl_support_t::e_ssl_support_enabled) - { - const shared_state &state = static_cast(get_state()); - if (!state.stop_signal_sent) - socket_.shutdown(ignored_ec); - } - socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignored_ec); - if (!m_host.empty()) - { - try { host_count(m_host, -1); } catch (...) 
{ /* ignore */ } - m_host = ""; - } - CRITICAL_REGION_END(); - m_protocol_handler.release_protocol(); - return true; + if (not state.timers.general.wait_expire) + return; + state.timers.general.cancel_expire = true; + state.timers.general.reset_expire = false; + ec_t ec; + timers.general.cancel(ec); } - //--------------------------------------------------------------------------------- - template - bool connection::close() + + template + void connection::start_handshake() { - TRY_ENTRY(); - auto self = safe_shared_from_this(); - if(!self) + if (state.socket.wait_handshake) + return; + static_assert( + epee::net_utils::get_ssl_magic_size() <= sizeof(state.data.read.buffer), + "" + ); + auto self = connection::shared_from_this(); + if (not state.ssl.forced and not state.ssl.detected) { + state.socket.wait_read = true; + boost::asio::async_read( + connection_basic::socket_.next_layer(), + boost::asio::buffer( + state.data.read.buffer.data(), + state.data.read.buffer.size() + ), + boost::asio::transfer_exactly(epee::net_utils::get_ssl_magic_size()), + strand.wrap( + [this, self](const ec_t &ec, size_t bytes_transferred){ + lock_guard_t guard(state.lock); + state.socket.wait_read = false; + if (state.socket.cancel_read) { + state.socket.cancel_read = false; + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) { + terminate(); + } + else if ( + not epee::net_utils::is_ssl( + static_cast( + state.data.read.buffer.data() + ), + bytes_transferred + ) + ) { + state.ssl.enabled = false; + state.socket.handle_read = true; + connection_basic::strand_.post( + [this, self, bytes_transferred]{ + bool success = handler.handle_recv( + reinterpret_cast(state.data.read.buffer.data()), + bytes_transferred + ); + lock_guard_t guard(state.lock); + state.socket.handle_read = false; + if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + else if (not success) + interrupt(); + else { + start_read(); + } + } + ); + } + else { + state.ssl.detected = true; + start_handshake(); + } + } + ) + ); + return; + } + + state.socket.wait_handshake = true; + auto on_handshake = [this, self](const ec_t &ec, size_t bytes_transferred){ + lock_guard_t guard(state.lock); + state.socket.wait_handshake = false; + if (state.socket.cancel_handshake) { + state.socket.cancel_handshake = false; + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) { + ec_t ec; + connection_basic::socket_.next_layer().shutdown( + socket_t::shutdown_both, + ec + ); + connection_basic::socket_.next_layer().close(ec); + state.socket.connected = false; + interrupt(); + } + else { + state.ssl.handshaked = true; + start_write(); + start_read(); + } + }; + const auto handshake = handshake_t::server; + static_cast( + connection_basic::get_state() + ).ssl_options().configure(connection_basic::socket_, handshake); + strand.post( + [this, self, on_handshake]{ + connection_basic::socket_.async_handshake( + handshake, + boost::asio::buffer( + state.data.read.buffer.data(), + state.ssl.forced ? 
0 : + epee::net_utils::get_ssl_magic_size() + ), + strand.wrap(on_handshake) + ); + } + ); + } + + template + void connection::start_read() + { + if (state.timers.throttle.in.wait_expire || state.socket.wait_read || + state.socket.handle_read + ) + return; + auto self = connection::shared_from_this(); + if (connection_type != e_connection_type_RPC) { + auto calc_duration = []{ + CRITICAL_REGION_LOCAL( + network_throttle_manager_t::m_lock_get_global_throttle_in + ); + return std::chrono::duration_cast::duration_t>( + std::chrono::duration( + std::min( + network_throttle_manager_t::get_global_throttle_in( + ).get_sleep_time_after_tick(1), + 1.0 + ) + ) + ); + }; + const auto duration = calc_duration(); + if (duration > duration_t{}) { + ec_t ec; + timers.throttle.in.expires_from_now(duration, ec); + state.timers.throttle.in.wait_expire = true; + timers.throttle.in.async_wait([this, self](const ec_t &ec){ + lock_guard_t guard(state.lock); + state.timers.throttle.in.wait_expire = false; + if (state.timers.throttle.in.cancel_expire) { + state.timers.throttle.in.cancel_expire = false; + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) + interrupt(); + else + start_read(); + }); + return; + } + } + state.socket.wait_read = true; + auto on_read = [this, self](const ec_t &ec, size_t bytes_transferred){ + lock_guard_t guard(state.lock); + state.socket.wait_read = false; + if (state.socket.cancel_read) { + state.socket.cancel_read = false; + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) + terminate(); + else { + { + state.stat.in.throttle.handle_trafic_exact(bytes_transferred); + const auto speed = state.stat.in.throttle.get_current_speed(); + context.m_current_speed_down = speed; + context.m_max_speed_down = std::max( + context.m_max_speed_down, + speed + ); + { + CRITICAL_REGION_LOCAL( + network_throttle_manager_t::m_lock_get_global_throttle_in + ); + network_throttle_manager_t::get_global_throttle_in( + ).handle_trafic_exact(bytes_transferred); + } + connection_basic::logger_handle_net_read(bytes_transferred); + context.m_last_recv = time(NULL); + context.m_recv_cnt += bytes_transferred; + start_timer(get_timeout_from_bytes_read(bytes_transferred), true); + } + state.socket.handle_read = true; + connection_basic::strand_.post( + [this, self, bytes_transferred]{ + bool success = handler.handle_recv( + reinterpret_cast(state.data.read.buffer.data()), + bytes_transferred + ); + lock_guard_t guard(state.lock); + state.socket.handle_read = false; + if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + else if (not success) + interrupt(); + else { + start_read(); + } + } + ); + } + }; + if (not state.ssl.enabled) + connection_basic::socket_.next_layer().async_read_some( + boost::asio::buffer( + state.data.read.buffer.data(), + state.data.read.buffer.size() + ), + strand.wrap(on_read) + ); + else + strand.post( + [this, self, on_read]{ + connection_basic::socket_.async_read_some( + boost::asio::buffer( + state.data.read.buffer.data(), + state.data.read.buffer.size() + ), + strand.wrap(on_read) + ); + } + ); + } + + template + void connection::start_write() + { + if 
(state.timers.throttle.out.wait_expire || state.socket.wait_write || + state.data.write.queue.empty() || + (state.ssl.enabled && not state.ssl.handshaked) + ) + return; + auto self = connection::shared_from_this(); + if (connection_type != e_connection_type_RPC) { + auto calc_duration = [this]{ + CRITICAL_REGION_LOCAL( + network_throttle_manager_t::m_lock_get_global_throttle_out + ); + return std::chrono::duration_cast::duration_t>( + std::chrono::duration( + std::min( + network_throttle_manager_t::get_global_throttle_out( + ).get_sleep_time_after_tick( + state.data.write.queue.back().size() + ), + 1.0 + ) + ) + ); + }; + const auto duration = calc_duration(); + if (duration > duration_t{}) { + ec_t ec; + timers.throttle.out.expires_from_now(duration, ec); + state.timers.throttle.out.wait_expire = true; + timers.throttle.out.async_wait([this, self](const ec_t &ec){ + lock_guard_t guard(state.lock); + state.timers.throttle.out.wait_expire = false; + if (state.timers.throttle.out.cancel_expire) { + state.timers.throttle.out.cancel_expire = false; + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) + interrupt(); + else + start_write(); + }); + } + } + + state.socket.wait_write = true; + auto on_write = [this, self](const ec_t &ec, size_t bytes_transferred){ + lock_guard_t guard(state.lock); + state.socket.wait_write = false; + if (state.socket.cancel_write) { + state.socket.cancel_write = false; + state.data.write.queue.clear(); + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) { + state.data.write.queue.clear(); + interrupt(); + } + else { + { + state.stat.out.throttle.handle_trafic_exact(bytes_transferred); + const auto speed = state.stat.out.throttle.get_current_speed(); + context.m_current_speed_up = speed; + context.m_max_speed_down = std::max( + context.m_max_speed_down, + speed + ); + { + CRITICAL_REGION_LOCAL( + network_throttle_manager_t::m_lock_get_global_throttle_out + ); + network_throttle_manager_t::get_global_throttle_out( + ).handle_trafic_exact(bytes_transferred); + } + connection_basic::logger_handle_net_write(bytes_transferred); + context.m_last_send = time(NULL); + context.m_send_cnt += bytes_transferred; + + start_timer(get_default_timeout(), true); + } + assert(bytes_transferred == state.data.write.queue.back().size()); + state.data.write.queue.pop_back(); + state.condition.notify_all(); + start_write(); + } + }; + if (not state.ssl.enabled) + boost::asio::async_write( + connection_basic::socket_.next_layer(), + boost::asio::buffer( + state.data.write.queue.back().data(), + state.data.write.queue.back().size() + ), + strand.wrap(on_write) + ); + else + strand.post( + [this, self, on_write]{ + boost::asio::async_write( + connection_basic::socket_, + boost::asio::buffer( + state.data.write.queue.back().data(), + state.data.write.queue.back().size() + ), + strand.wrap(on_write) + ); + } + ); + } + + template + void connection::start_shutdown() + { + if (state.socket.wait_shutdown) + return; + auto self = connection::shared_from_this(); + state.socket.wait_shutdown = true; + auto on_shutdown = [this, self](const ec_t &ec){ + lock_guard_t guard(state.lock); + state.socket.wait_shutdown = false; + if (state.socket.cancel_shutdown) { + 
state.socket.cancel_shutdown = false; + if (state.status == status_t::RUNNING) + interrupt(); + else if (state.status == status_t::INTERRUPTED) + terminate(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + else if (ec.value()) + terminate(); + else { + cancel_timer(); + on_interrupted(); + } + }; + strand.post( + [this, self, on_shutdown]{ + connection_basic::socket_.async_shutdown( + strand.wrap(on_shutdown) + ); + } + ); + start_timer(get_default_timeout()); + } + + template + void connection::cancel_socket() + { + bool wait_socket = false; + if (state.socket.wait_handshake) + wait_socket = state.socket.cancel_handshake = true; + if (state.timers.throttle.in.wait_expire) { + state.timers.throttle.in.cancel_expire = true; + ec_t ec; + timers.throttle.in.cancel(ec); + } + if (state.socket.wait_read) + wait_socket = state.socket.cancel_read = true; + if (state.timers.throttle.out.wait_expire) { + state.timers.throttle.out.cancel_expire = true; + ec_t ec; + timers.throttle.out.cancel(ec); + } + if (state.socket.wait_write) + wait_socket = state.socket.cancel_write = true; + if (state.socket.wait_shutdown) + wait_socket = state.socket.cancel_shutdown = true; + if (wait_socket) { + ec_t ec; + connection_basic::socket_.next_layer().cancel(ec); + } + } + + template + void connection::cancel_handler() + { + if (state.protocol.released || state.protocol.wait_release) + return; + state.protocol.wait_release = true; + state.lock.unlock(); + handler.release_protocol(); + state.lock.lock(); + state.protocol.wait_release = false; + state.protocol.released = true; + if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + } + + template + void connection::interrupt() + { + if (state.status != status_t::RUNNING) + return; + state.status = status_t::INTERRUPTED; + cancel_timer(); + cancel_socket(); + on_interrupted(); + state.condition.notify_all(); + cancel_handler(); + } + + template + void connection::on_interrupted() + { + assert(state.status == status_t::INTERRUPTED); + if (state.timers.general.wait_expire) + return; + if (state.socket.wait_handshake) + return; + if (state.timers.throttle.in.wait_expire) + return; + if (state.socket.wait_read) + return; + if (state.socket.handle_read) + return; + if (state.timers.throttle.out.wait_expire) + return; + if (state.socket.wait_write) + return; + if (state.socket.wait_shutdown) + return; + if (state.protocol.wait_init) + return; + if (state.protocol.wait_callback) + return; + if (state.protocol.wait_release) + return; + if (state.socket.connected) { + if (not state.ssl.enabled) { + ec_t ec; + connection_basic::socket_.next_layer().shutdown( + socket_t::shutdown_both, + ec + ); + connection_basic::socket_.next_layer().close(ec); + state.socket.connected = false; + state.status = status_t::WASTED; + } + else + start_shutdown(); + } + else + state.status = status_t::WASTED; + } + + template + void connection::terminate() + { + if (state.status != status_t::RUNNING && + state.status != status_t::INTERRUPTED + ) + return; + state.status = status_t::TERMINATING; + cancel_timer(); + cancel_socket(); + on_terminating(); + state.condition.notify_all(); + cancel_handler(); + } + + template + void connection::on_terminating() + { + assert(state.status == status_t::TERMINATING); + if (state.timers.general.wait_expire) + return; + if (state.socket.wait_handshake) + return; + if (state.timers.throttle.in.wait_expire) + return; + if (state.socket.wait_read) + 
return; + if (state.socket.handle_read) + return; + if (state.timers.throttle.out.wait_expire) + return; + if (state.socket.wait_write) + return; + if (state.socket.wait_shutdown) + return; + if (state.protocol.wait_init) + return; + if (state.protocol.wait_callback) + return; + if (state.protocol.wait_release) + return; + if (state.socket.connected) { + ec_t ec; + connection_basic::socket_.next_layer().shutdown( + socket_t::shutdown_both, + ec + ); + connection_basic::socket_.next_layer().close(ec); + state.socket.connected = false; + } + state.status = status_t::WASTED; + } + + template + bool connection::send(byte_slice_t message) + { + lock_guard_t guard(state.lock); + if (state.status != status_t::RUNNING || state.socket.wait_handshake) return false; - //_info("[sock " << socket().native_handle() << "] Que Shutdown called."); - m_timer.cancel(); - size_t send_que_size = 0; - CRITICAL_REGION_BEGIN(m_send_que_lock); - send_que_size = m_send_que.size(); - CRITICAL_REGION_END(); - m_want_close_connection = true; - if(!send_que_size) - { - shutdown(); + auto wait_consume = [this] { + auto random_delay = []{ + using engine = std::mt19937; + std::random_device dev; + std::seed_seq::result_type rand[ + engine::state_size // Use complete bit space + ]{}; + std::generate_n(rand, engine::state_size, std::ref(dev)); + std::seed_seq seed(rand, rand + engine::state_size); + engine rng(seed); + return std::chrono::milliseconds( + std::uniform_int_distribution<>(5000, 6000)(rng) + ); + }; + if (state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT) + return true; + state.data.write.wait_consume = true; + bool success = state.condition.wait_for( + state.lock, + random_delay(), + [this]{ + return ( + state.status != status_t::RUNNING || + state.data.write.queue.size() <= + ABSTRACT_SERVER_SEND_QUE_MAX_COUNT + ); + } + ); + state.data.write.wait_consume = false; + if (not success) { + terminate(); + return false; + } + else + return state.status == status_t::RUNNING; + }; + auto wait_sender = [this] { + state.condition.wait( + state.lock, + [this] { + return ( + state.status != status_t::RUNNING || + not state.data.write.wait_consume + ); + } + ); + return state.status == status_t::RUNNING; + }; + if (not wait_sender()) + return false; + constexpr size_t CHUNK_SIZE = 32 * 1024; + if (connection_type == e_connection_type_RPC || + message.size() <= 2 * CHUNK_SIZE + ) { + if (not wait_consume()) + return false; + state.data.write.queue.emplace_front(std::move(message)); + start_write(); } - + else { + while (!message.empty()) { + if (not wait_consume()) + return false; + state.data.write.queue.emplace_front( + message.take_slice(CHUNK_SIZE) + ); + start_write(); + } + } + state.condition.notify_all(); return true; - CATCH_ENTRY_L0("connection::close", false); } - //--------------------------------------------------------------------------------- - template - bool connection::send_done() + + template + bool connection::start_internal( + bool is_income, + bool is_multithreaded, + boost::optional real_remote + ) { - if (m_ready_to_close) - return close(); - m_ready_to_close = true; + unique_lock_t guard(state.lock); + if (state.status != status_t::TERMINATED) + return false; + if (not real_remote) { + ec_t ec; + auto endpoint = connection_basic::socket_.next_layer().remote_endpoint( + ec + ); + if (ec.value()) + return false; + real_remote = ( + endpoint.address().is_v6() ? 
+ network_address{ + ipv6_network_address{endpoint.address().to_v6(), endpoint.port()} + } : + network_address{ + ipv4_network_address{ + uint32_t{ + boost::asio::detail::socket_ops::host_to_network_long( + endpoint.address().to_v4().to_ulong() + ) + }, + endpoint.port() + } + } + ); + } + auto *filter = static_cast( + connection_basic::get_state() + ).pfilter; + if (filter and not filter->is_remote_host_allowed(*real_remote)) + return false; + ec_t ec; + #if !defined(_WIN32) || !defined(__i686) + connection_basic::socket_.next_layer().set_option( + boost::asio::detail::socket_option::integer{ + connection_basic::get_tos_flag() + }, + ec + ); + if (ec.value()) + return false; + #endif + connection_basic::socket_.next_layer().set_option( + boost::asio::ip::tcp::no_delay{false}, + ec + ); + if (ec.value()) + return false; + connection_basic::m_is_multithreaded = is_multithreaded; + context.set_details( + boost::uuids::random_generator()(), + *real_remote, + is_income, + connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled + ); + host = real_remote->host_str(); + try { host_count(1); } catch(...) { /* ignore */ } + local = real_remote->is_loopback() || real_remote->is_local(); + state.ssl.enabled = ( + connection_basic::m_ssl_support != ssl_support_t::e_ssl_support_disabled + ); + state.ssl.forced = ( + connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled + ); + state.socket.connected = true; + state.status = status_t::RUNNING; + start_timer( + std::chrono::milliseconds( + local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE + ) + ); + state.protocol.wait_init = true; + guard.unlock(); + handler.after_init_connection(); + guard.lock(); + state.protocol.wait_init = false; + state.protocol.initialized = true; + if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + else if (not is_income || not state.ssl.enabled) + start_read(); + else + start_handshake(); return true; } - //--------------------------------------------------------------------------------- - template - bool connection::cancel() + + template + connection::connection( + io_context_t &io_context, + std::shared_ptr shared_state, + t_connection_type connection_type, + ssl_support_t ssl_support + ): + connection( + std::move(socket_t{io_context}), + std::move(shared_state), + connection_type, + ssl_support + ) + { + } + + template + connection::connection( + socket_t &&socket, + std::shared_ptr shared_state, + t_connection_type connection_type, + ssl_support_t ssl_support + ): + connection_basic(std::move(socket), shared_state, ssl_support), + handler(this, *shared_state, context), + connection_type(connection_type), + io_context{GET_IO_SERVICE(connection_basic::socket_)}, + strand{io_context}, + timers{io_context} + { + } + + template + connection::~connection() noexcept(false) + { + lock_guard_t guard(state.lock); + assert(state.status == status_t::TERMINATED || + state.status == status_t::WASTED || + io_context.stopped() + ); + if (state.status != status_t::WASTED) + return; + try { host_count(-1); } catch (...) 
{ /* ignore */ } + } + + template + bool connection::start( + bool is_income, + bool is_multithreaded + ) + { + return start_internal(is_income, is_multithreaded, {}); + } + + template + bool connection::start( + bool is_income, + bool is_multithreaded, + network_address real_remote + ) + { + return start_internal(is_income, is_multithreaded, real_remote); + } + + template + void connection::save_dbg_log() + { + lock_guard_t guard(state.lock); + string_t address; + string_t port; + ec_t ec; + auto endpoint = connection_basic::socket().remote_endpoint(ec); + if (ec.value()) { + address = ""; + port = ""; + } + else { + address = endpoint.address().to_string(); + port = std::to_string(endpoint.port()); + } + MDEBUG( + " connection type " << std::to_string(connection_type) << + " " << connection_basic::socket().local_endpoint().address().to_string() << + ":" << connection_basic::socket().local_endpoint().port() << + " <--> " << context.m_remote_address.str() << + " (via " << address << ":" << port << ")" + ); + } + + template + bool connection::speed_limit_is_enabled() const + { + return connection_type != e_connection_type_RPC; + } + + template + bool connection::cancel() { return close(); } - //--------------------------------------------------------------------------------- - template - void connection::handle_write(const boost::system::error_code& e, size_t cb) + + template + bool connection::do_send(byte_slice message) { - TRY_ENTRY(); - LOG_TRACE_CC(context, "[sock " << socket().native_handle() << "] Async send calledback " << cb); - - if (e) - { - _dbg1("[sock " << socket().native_handle() << "] Some problems at write: " << e.message() << ':' << e.value()); - shutdown(); - return; - } - logger_handle_net_write(cb); - - // The single sleeping that is needed for correctly handling "out" speed throttling - if (speed_limit_is_enabled()) { - sleep_before_packet(cb, 1, 1); - } - - bool do_shutdown = false; - CRITICAL_REGION_BEGIN(m_send_que_lock); - if(m_send_que.empty()) - { - _erro("[sock " << socket().native_handle() << "] m_send_que.size() == 0 at handle_write!"); - return; - } - - m_send_que.pop_front(); - if(m_send_que.empty()) - { - if(m_want_close_connection) - { - do_shutdown = true; - } - }else - { - //have more data to send - reset_timer(get_default_timeout(), false); - auto size_now = m_send_que.front().size(); - MDEBUG("handle_write() NOW SENDS: packet="<::handle_write, connection::shared_from_this(), std::placeholders::_1, std::placeholders::_2) - ) - ); - //_dbg3("(normal)" << size_now); - } - CRITICAL_REGION_END(); - - if(do_shutdown) - { - shutdown(); - } - CATCH_ENTRY_L0("connection::handle_write", void()); + return send(std::move(message)); } - //--------------------------------------------------------------------------------- - template - void connection::setRpcStation() + template + bool connection::send_done() { - m_connection_type = e_connection_type_RPC; - MDEBUG("set m_connection_type = RPC "); + return true; } + template + bool connection::close() + { + lock_guard_t guard(state.lock); + if (state.status != status_t::RUNNING) + return false; + terminate(); + return true; + } - template - bool connection::speed_limit_is_enabled() const { - return m_connection_type != e_connection_type_RPC ; - } + template + bool connection::call_run_once_service_io() + { + if(connection_basic::m_is_multithreaded) { + if (not io_context.poll_one()) + misc_utils::sleep_no_w(1); + } + else { + if (!io_context.run_one()) + return false; + } + return true; + } - 
/************************************************************************/ - /* */ - /************************************************************************/ + template + bool connection::request_callback() + { + lock_guard_t guard(state.lock); + if (state.status != status_t::RUNNING) + return false; + auto self = connection::shared_from_this(); + ++state.protocol.wait_callback; + connection_basic::strand_.post([this, self]{ + handler.handle_qued_callback(); + lock_guard_t guard(state.lock); + --state.protocol.wait_callback; + if (state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (state.status == status_t::TERMINATING) + on_terminating(); + }); + return true; + } + + template + typename connection::io_context_t &connection::get_io_service() + { + return io_context; + } + + template + bool connection::add_ref() + { + try { + auto self = connection::shared_from_this(); + lock_guard_t guard(state.lock); + this->self = std::move(self); + ++state.protocol.reference_counter; + return true; + } + catch (boost::bad_weak_ptr &exception) { + return false; + } + } + + template + bool connection::release() + { + connection_ptr self; + lock_guard_t guard(state.lock); + if (not --state.protocol.reference_counter) + self = std::move(this->self); + return true; + } + + template + void connection::setRpcStation() + { + lock_guard_t guard(state.lock); + connection_type = e_connection_type_RPC; + } template boosted_tcp_server::boosted_tcp_server( t_connection_type connection_type ) : diff --git a/contrib/epee/include/net/net_ssl.h b/contrib/epee/include/net/net_ssl.h index 108e6771b..c79a3acc1 100644 --- a/contrib/epee/include/net/net_ssl.h +++ b/contrib/epee/include/net/net_ssl.h @@ -110,6 +110,11 @@ namespace net_utils //! Search against internal fingerprints. Always false if `behavior() != user_certificate_check`. bool has_fingerprint(boost::asio::ssl::verify_context &ctx) const; + //! configure ssl_stream handshake verification + void configure( + boost::asio::ssl::stream &socket, + boost::asio::ssl::stream_base::handshake_type type, + const std::string& host = {}) const; boost::asio::ssl::context create_context() const; /*! 
\note If `this->support == autodetect && this->verification != none`, diff --git a/contrib/epee/src/net_ssl.cpp b/contrib/epee/src/net_ssl.cpp index 7dfb56068..7dda65bb5 100644 --- a/contrib/epee/src/net_ssl.cpp +++ b/contrib/epee/src/net_ssl.cpp @@ -32,6 +32,8 @@ #include #include #include +#include +#include #include #include #include @@ -488,12 +490,10 @@ bool ssl_options_t::has_fingerprint(boost::asio::ssl::verify_context &ctx) const return false; } -bool ssl_options_t::handshake( +void ssl_options_t::configure( boost::asio::ssl::stream &socket, boost::asio::ssl::stream_base::handshake_type type, - boost::asio::const_buffer buffer, - const std::string& host, - std::chrono::milliseconds timeout) const + const std::string& host) const { socket.next_layer().set_option(boost::asio::ip::tcp::no_delay(true)); @@ -538,30 +538,101 @@ bool ssl_options_t::handshake( return true; }); } +} - auto& io_service = GET_IO_SERVICE(socket); - boost::asio::steady_timer deadline(io_service, timeout); - deadline.async_wait([&socket](const boost::system::error_code& error) { - if (error != boost::asio::error::operation_aborted) +bool ssl_options_t::handshake( + boost::asio::ssl::stream &socket, + boost::asio::ssl::stream_base::handshake_type type, + boost::asio::const_buffer buffer, + const std::string& host, + std::chrono::milliseconds timeout) const +{ + configure(socket, type, host); + + auto start_handshake = [&]{ + using ec_t = boost::system::error_code; + using timer_t = boost::asio::steady_timer; + using strand_t = boost::asio::io_service::strand; + using lock_t = std::mutex; + using lock_guard_t = std::lock_guard; + using condition_t = std::condition_variable_any; + using socket_t = boost::asio::ip::tcp::socket; + + auto &io_context = GET_IO_SERVICE(socket); + if (io_context.stopped()) + io_context.reset(); + strand_t strand(io_context); + timer_t deadline(io_context, timeout); + + struct state_t { + lock_t lock; + condition_t condition; + ec_t result; + bool wait_timer; + bool wait_handshake; + bool cancel_timer; + bool cancel_handshake; + }; + state_t state{}; + + state.wait_timer = true; + auto on_timer = [&](const ec_t &ec){ + lock_guard_t guard(state.lock); + state.wait_timer = false; + state.condition.notify_all(); + if (not state.cancel_timer) { + state.cancel_handshake = true; + ec_t ec; + socket.next_layer().cancel(ec); + } + }; + + state.wait_handshake = true; + auto on_handshake = [&](const ec_t &ec, size_t bytes_transferred){ + lock_guard_t guard(state.lock); + state.wait_handshake = false; + state.condition.notify_all(); + state.result = ec; + if (not state.cancel_handshake) { + state.cancel_timer = true; + ec_t ec; + deadline.cancel(ec); + } + }; + + deadline.async_wait(on_timer); + strand.post( + [&]{ + socket.async_handshake( + type, + boost::asio::buffer(buffer), + strand.wrap(on_handshake) + ); + } + ); + + while (!io_context.stopped()) { - socket.next_layer().close(); + io_context.poll_one(); + lock_guard_t guard(state.lock); + state.condition.wait_for( + state.lock, + std::chrono::milliseconds(30), + [&]{ + return not state.wait_timer and not state.wait_handshake; + } + ); + if (not state.wait_timer and not state.wait_handshake) + break; } - }); - - boost::system::error_code ec = boost::asio::error::would_block; - socket.async_handshake(type, boost::asio::buffer(buffer), boost::lambda::var(ec) = boost::lambda::_1); - if (io_service.stopped()) - { - io_service.reset(); - } - while (ec == boost::asio::error::would_block && !io_service.stopped()) - { - // should poll_one(), can't 
run_one() because it can block if there is - // another worker thread executing io_service's tasks - // TODO: once we get Boost 1.66+, replace with run_one_for/run_until - std::this_thread::sleep_for(std::chrono::milliseconds(30)); - io_service.poll_one(); - } + if (state.result.value()) { + ec_t ec; + socket.next_layer().shutdown(socket_t::shutdown_both, ec); + socket.next_layer().close(ec); + } + return state.result; + }; + const auto ec = start_handshake(); if (ec) { From a82fba4b7b944a54d2a14922f44d7eee367e4912 Mon Sep 17 00:00:00 2001 From: j-berman Date: Wed, 6 Jul 2022 16:47:34 -0700 Subject: [PATCH 3/3] address PR comments --- .../epee/include/net/abstract_tcp_server2.h | 82 +-- .../epee/include/net/abstract_tcp_server2.inl | 672 +++++++++--------- contrib/epee/src/net_ssl.cpp | 21 +- tests/unit_tests/epee_boosted_tcp_server.cpp | 10 +- 4 files changed, 390 insertions(+), 395 deletions(-) diff --git a/contrib/epee/include/net/abstract_tcp_server2.h b/contrib/epee/include/net/abstract_tcp_server2.h index 0684573f2..bc0da66e2 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.h +++ b/contrib/epee/include/net/abstract_tcp_server2.h @@ -89,20 +89,14 @@ namespace net_utils public i_service_endpoint, public connection_basic { + public: + typedef typename t_protocol_handler::connection_context t_connection_context; private: - using string_t = std::string; - using handler_t = t_protocol_handler; - using context_t = typename handler_t::connection_context; - using connection_t = connection; + using connection_t = connection; using connection_ptr = boost::shared_ptr; using ssl_support_t = epee::net_utils::ssl_support_t; using timer_t = boost::asio::steady_timer; using duration_t = timer_t::duration; - using lock_t = std::mutex; - using condition_t = std::condition_variable_any; - using lock_guard_t = std::lock_guard; - using unique_lock_t = std::unique_lock; - using byte_slice_t = epee::byte_slice; using ec_t = boost::system::error_code; using handshake_t = boost::asio::ssl::stream_base::handshake_type; @@ -110,8 +104,6 @@ namespace net_utils using strand_t = boost::asio::io_service::strand; using socket_t = boost::asio::ip::tcp::socket; - using write_queue_t = std::deque; - using read_buffer_t = std::array; using network_throttle_t = epee::net_utils::network_throttle; using network_throttle_manager_t = epee::net_utils::network_throttle_manager; @@ -119,6 +111,8 @@ namespace net_utils duration_t get_default_timeout(); duration_t get_timeout_from_bytes_read(size_t bytes) const; + void state_status_check(); + void start_timer(duration_t duration, bool add = {}); void async_wait_timer(); void cancel_timer(); @@ -137,13 +131,21 @@ namespace net_utils void terminate(); void on_terminating(); - bool send(byte_slice_t message); + bool send(epee::byte_slice message); bool start_internal( bool is_income, bool is_multithreaded, boost::optional real_remote ); + enum status_t { + TERMINATED, + RUNNING, + INTERRUPTED, + TERMINATING, + WASTED, + }; + struct state_t { struct stat_t { struct { @@ -156,10 +158,10 @@ namespace net_utils struct data_t { struct { - read_buffer_t buffer; + std::array buffer; } read; struct { - write_queue_t queue; + std::deque queue; bool wait_consume; } write; }; @@ -171,7 +173,7 @@ namespace net_utils bool handshaked; }; - struct socket_t { + struct socket_status_t { bool connected; bool wait_handshake; @@ -189,30 +191,22 @@ namespace net_utils bool cancel_shutdown; }; - struct timer_t { + struct timer_status_t { bool wait_expire; bool cancel_expire; bool reset_expire; 
}; - struct timers_t { + struct timers_status_t { struct throttle_t { - timer_t in; - timer_t out; + timer_status_t in; + timer_status_t out; }; - timer_t general; + timer_status_t general; throttle_t throttle; }; - enum status_t { - TERMINATED, - RUNNING, - INTERRUPTED, - TERMINATING, - WASTED, - }; - struct protocol_t { size_t reference_counter; bool released; @@ -223,19 +217,17 @@ namespace net_utils size_t wait_callback; }; - lock_t lock; - condition_t condition; + std::mutex lock; + std::condition_variable_any condition; status_t status; - socket_t socket; + socket_status_t socket; ssl_t ssl; - timers_t timers; + timers_status_t timers; protocol_t protocol; stat_t stat; data_t data; }; - using status_t = typename state_t::status_t; - struct timers_t { timers_t(io_context_t &io_context): general(io_context), @@ -254,19 +246,17 @@ namespace net_utils throttle_t throttle; }; - io_context_t &io_context; - t_connection_type connection_type; - context_t context{}; - strand_t strand; - timers_t timers; + io_context_t &m_io_context; + t_connection_type m_connection_type; + t_connection_context m_conn_context{}; + strand_t m_strand; + timers_t m_timers; connection_ptr self{}; - bool local{}; - string_t host{}; - state_t state{}; - handler_t handler; + bool m_local{}; + std::string m_host{}; + state_t m_state{}; + t_protocol_handler m_handler; public: - typedef typename t_protocol_handler::connection_context t_connection_context; - struct shared_state : connection_basic_shared_state, t_protocol_handler::config_type { shared_state() @@ -298,7 +288,7 @@ namespace net_utils // `real_remote` is the actual endpoint (if connection is to proxy, etc.) bool start(bool is_income, bool is_multithreaded, network_address real_remote); - void get_context(t_connection_context& context_){context_ = context;} + void get_context(t_connection_context& context_){context_ = m_conn_context;} void call_back_starter(); diff --git a/contrib/epee/include/net/abstract_tcp_server2.inl b/contrib/epee/include/net/abstract_tcp_server2.inl index 0fc9228b1..81aa725d1 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.inl +++ b/contrib/epee/include/net/abstract_tcp_server2.inl @@ -79,14 +79,14 @@ namespace net_utils template unsigned int connection::host_count(int delta) { - static lock_t hosts_mutex; - lock_guard_t guard(hosts_mutex); - static std::map hosts; - unsigned int &val = hosts[host]; + static std::mutex hosts_mutex; + std::lock_guard guard(hosts_mutex); + static std::map hosts; + unsigned int &val = hosts[m_host]; if (delta > 0) - MTRACE("New connection from host " << host << ": " << val); + MTRACE("New connection from host " << m_host << ": " << val); else if (delta < 0) - MTRACE("Closed connection from host " << host << ": " << val); + MTRACE("Closed connection from host " << m_host << ": " << val); CHECK_AND_ASSERT_THROW_MES(delta >= 0 || val >= (unsigned)-delta, "Count would go negative"); CHECK_AND_ASSERT_THROW_MES(delta <= 0 || val <= std::numeric_limits::max() - (unsigned)delta, "Count would wrap"); val += delta; @@ -104,7 +104,7 @@ namespace net_utils 0 ); return ( - local ? + m_local ? 
std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_LOCAL >> shift) : std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_REMOTE >> shift) ); @@ -120,16 +120,35 @@ namespace net_utils ); } + template + void connection::state_status_check() + { + switch (m_state.status) + { + case status_t::RUNNING: + interrupt(); + break; + case status_t::INTERRUPTED: + on_interrupted(); + break; + case status_t::TERMINATING: + on_terminating(); + break; + default: + break; + } + } + template void connection::start_timer(duration_t duration, bool add) { - if (state.timers.general.wait_expire) { - state.timers.general.cancel_expire = true; - state.timers.general.reset_expire = true; + if (m_state.timers.general.wait_expire) { + m_state.timers.general.cancel_expire = true; + m_state.timers.general.reset_expire = true; ec_t ec; - timers.general.expires_from_now( + m_timers.general.expires_from_now( std::min( - duration + (add ? timers.general.expires_from_now() : duration_t{}), + duration + (add ? m_timers.general.expires_from_now() : duration_t{}), get_default_timeout() ), ec @@ -137,9 +156,9 @@ namespace net_utils } else { ec_t ec; - timers.general.expires_from_now( + m_timers.general.expires_from_now( std::min( - duration + (add ? timers.general.expires_from_now() : duration_t{}), + duration + (add ? m_timers.general.expires_from_now() : duration_t{}), get_default_timeout() ), ec @@ -151,27 +170,27 @@ namespace net_utils template void connection::async_wait_timer() { - if (state.timers.general.wait_expire) + if (m_state.timers.general.wait_expire) return; - state.timers.general.wait_expire = true; + m_state.timers.general.wait_expire = true; auto self = connection::shared_from_this(); - timers.general.async_wait([this, self](const ec_t & ec){ - lock_guard_t guard(state.lock); - state.timers.general.wait_expire = false; - if (state.timers.general.cancel_expire) { - state.timers.general.cancel_expire = false; - if (state.timers.general.reset_expire) { - state.timers.general.reset_expire = false; + m_timers.general.async_wait([this, self](const ec_t & ec){ + std::lock_guard guard(m_state.lock); + m_state.timers.general.wait_expire = false; + if (m_state.timers.general.cancel_expire) { + m_state.timers.general.cancel_expire = false; + if (m_state.timers.general.reset_expire) { + m_state.timers.general.reset_expire = false; async_wait_timer(); } - else if (state.status == status_t::INTERRUPTED) + else if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); } - else if (state.status == status_t::RUNNING) + else if (m_state.status == status_t::RUNNING) interrupt(); - else if (state.status == status_t::INTERRUPTED) + else if (m_state.status == status_t::INTERRUPTED) terminate(); }); } @@ -179,72 +198,67 @@ namespace net_utils template void connection::cancel_timer() { - if (not state.timers.general.wait_expire) + if (!m_state.timers.general.wait_expire) return; - state.timers.general.cancel_expire = true; - state.timers.general.reset_expire = false; + m_state.timers.general.cancel_expire = true; + m_state.timers.general.reset_expire = false; ec_t ec; - timers.general.cancel(ec); + m_timers.general.cancel(ec); } template void connection::start_handshake() { - if (state.socket.wait_handshake) + if (m_state.socket.wait_handshake) return; static_assert( - epee::net_utils::get_ssl_magic_size() <= sizeof(state.data.read.buffer), + epee::net_utils::get_ssl_magic_size() <= sizeof(m_state.data.read.buffer), "" 
); auto self = connection::shared_from_this(); - if (not state.ssl.forced and not state.ssl.detected) { - state.socket.wait_read = true; + if (!m_state.ssl.forced && !m_state.ssl.detected) { + m_state.socket.wait_read = true; boost::asio::async_read( connection_basic::socket_.next_layer(), boost::asio::buffer( - state.data.read.buffer.data(), - state.data.read.buffer.size() + m_state.data.read.buffer.data(), + m_state.data.read.buffer.size() ), boost::asio::transfer_exactly(epee::net_utils::get_ssl_magic_size()), - strand.wrap( + m_strand.wrap( [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_read = false; - if (state.socket.cancel_read) { - state.socket.cancel_read = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard guard(m_state.lock); + m_state.socket.wait_read = false; + if (m_state.socket.cancel_read) { + m_state.socket.cancel_read = false; + state_status_check(); } else if (ec.value()) { terminate(); } else if ( - not epee::net_utils::is_ssl( + !epee::net_utils::is_ssl( static_cast( - state.data.read.buffer.data() + m_state.data.read.buffer.data() ), bytes_transferred ) ) { - state.ssl.enabled = false; - state.socket.handle_read = true; + m_state.ssl.enabled = false; + m_state.socket.handle_read = true; connection_basic::strand_.post( [this, self, bytes_transferred]{ - bool success = handler.handle_recv( - reinterpret_cast(state.data.read.buffer.data()), + bool success = m_handler.handle_recv( + reinterpret_cast(m_state.data.read.buffer.data()), bytes_transferred ); - lock_guard_t guard(state.lock); - state.socket.handle_read = false; - if (state.status == status_t::INTERRUPTED) + std::lock_guard guard(m_state.lock); + m_state.socket.handle_read = false; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); - else if (not success) + else if (!success) interrupt(); else { start_read(); @@ -253,7 +267,7 @@ namespace net_utils ); } else { - state.ssl.detected = true; + m_state.ssl.detected = true; start_handshake(); } } @@ -262,18 +276,13 @@ namespace net_utils return; } - state.socket.wait_handshake = true; + m_state.socket.wait_handshake = true; auto on_handshake = [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_handshake = false; - if (state.socket.cancel_handshake) { - state.socket.cancel_handshake = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard guard(m_state.lock); + m_state.socket.wait_handshake = false; + if (m_state.socket.cancel_handshake) { + m_state.socket.cancel_handshake = false; + state_status_check(); } else if (ec.value()) { ec_t ec; @@ -282,11 +291,11 @@ namespace net_utils ec ); connection_basic::socket_.next_layer().close(ec); - state.socket.connected = false; + m_state.socket.connected = false; interrupt(); } else { - state.ssl.handshaked = true; + m_state.ssl.handshaked = true; start_write(); start_read(); } @@ -295,16 +304,16 @@ namespace net_utils static_cast( connection_basic::get_state() ).ssl_options().configure(connection_basic::socket_, handshake); - strand.post( + 
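// A minimal, self-contained sketch of the peek-then-handshake pattern used above
// (illustrative only, not part of this patch; the standalone names below are assumed).
// The first bytes are read off the raw socket to decide between the plain protocol and
// TLS; when TLS is detected, the same peeked bytes are handed to async_handshake so
// OpenSSL still sees a complete ClientHello record.
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include <array>
#include <functional>
void detect_then_handshake(
  boost::asio::ssl::stream<boost::asio::ip::tcp::socket> &sock,
  std::array<unsigned char, 3> &magic,                  // must outlive the async ops
  std::function<void(bool ok, bool is_tls)> done)
{
  boost::asio::async_read(
    sock.next_layer(),
    boost::asio::buffer(magic),
    boost::asio::transfer_exactly(magic.size()),
    [&sock, &magic, done](const boost::system::error_code &ec, std::size_t) {
      if (ec) { done(false, false); return; }
      const bool is_tls = magic[0] == 0x16 && magic[1] == 0x03; // simplified TLS record check
      if (!is_tls) { done(true, false); return; }               // plain protocol: reuse the bytes
      sock.async_handshake(
        boost::asio::ssl::stream_base::server,
        boost::asio::buffer(magic),                             // re-feed the peeked bytes
        [done](const boost::system::error_code &ec, std::size_t) { done(!ec, true); });
    });
}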
m_strand.post( [this, self, on_handshake]{ connection_basic::socket_.async_handshake( handshake, boost::asio::buffer( - state.data.read.buffer.data(), - state.ssl.forced ? 0 : + m_state.data.read.buffer.data(), + m_state.ssl.forced ? 0 : epee::net_utils::get_ssl_magic_size() ), - strand.wrap(on_handshake) + m_strand.wrap(on_handshake) ); } ); @@ -313,12 +322,13 @@ namespace net_utils template void connection::start_read() { - if (state.timers.throttle.in.wait_expire || state.socket.wait_read || - state.socket.handle_read - ) + if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read || + m_state.socket.handle_read + ) { return; + } auto self = connection::shared_from_this(); - if (connection_type != e_connection_type_RPC) { + if (m_connection_type != e_connection_type_RPC) { auto calc_duration = []{ CRITICAL_REGION_LOCAL( network_throttle_manager_t::m_lock_get_global_throttle_in @@ -336,19 +346,14 @@ namespace net_utils const auto duration = calc_duration(); if (duration > duration_t{}) { ec_t ec; - timers.throttle.in.expires_from_now(duration, ec); - state.timers.throttle.in.wait_expire = true; - timers.throttle.in.async_wait([this, self](const ec_t &ec){ - lock_guard_t guard(state.lock); - state.timers.throttle.in.wait_expire = false; - if (state.timers.throttle.in.cancel_expire) { - state.timers.throttle.in.cancel_expire = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + m_timers.throttle.in.expires_from_now(duration, ec); + m_state.timers.throttle.in.wait_expire = true; + m_timers.throttle.in.async_wait([this, self](const ec_t &ec){ + std::lock_guard guard(m_state.lock); + m_state.timers.throttle.in.wait_expire = false; + if (m_state.timers.throttle.in.cancel_expire) { + m_state.timers.throttle.in.cancel_expire = false; + state_status_check(); } else if (ec.value()) interrupt(); @@ -358,28 +363,23 @@ namespace net_utils return; } } - state.socket.wait_read = true; + m_state.socket.wait_read = true; auto on_read = [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_read = false; - if (state.socket.cancel_read) { - state.socket.cancel_read = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard guard(m_state.lock); + m_state.socket.wait_read = false; + if (m_state.socket.cancel_read) { + m_state.socket.cancel_read = false; + state_status_check(); } else if (ec.value()) terminate(); else { { - state.stat.in.throttle.handle_trafic_exact(bytes_transferred); - const auto speed = state.stat.in.throttle.get_current_speed(); - context.m_current_speed_down = speed; - context.m_max_speed_down = std::max( - context.m_max_speed_down, + m_state.stat.in.throttle.handle_trafic_exact(bytes_transferred); + const auto speed = m_state.stat.in.throttle.get_current_speed(); + m_conn_context.m_current_speed_down = speed; + m_conn_context.m_max_speed_down = std::max( + m_conn_context.m_max_speed_down, speed ); { @@ -390,24 +390,30 @@ namespace net_utils ).handle_trafic_exact(bytes_transferred); } connection_basic::logger_handle_net_read(bytes_transferred); - context.m_last_recv = time(NULL); - context.m_recv_cnt += bytes_transferred; + m_conn_context.m_last_recv = time(NULL); + m_conn_context.m_recv_cnt += 
bytes_transferred; start_timer(get_timeout_from_bytes_read(bytes_transferred), true); } - state.socket.handle_read = true; + + // Post handle_recv to a separate `strand_`, distinct from `m_strand` + // which is listening for reads/writes. This avoids a circular dep. + // handle_recv can queue many writes, and `m_strand` will process those + // writes until the connection terminates without deadlocking waiting + // for handle_recv. + m_state.socket.handle_read = true; connection_basic::strand_.post( [this, self, bytes_transferred]{ - bool success = handler.handle_recv( - reinterpret_cast(state.data.read.buffer.data()), + bool success = m_handler.handle_recv( + reinterpret_cast(m_state.data.read.buffer.data()), bytes_transferred ); - lock_guard_t guard(state.lock); - state.socket.handle_read = false; - if (state.status == status_t::INTERRUPTED) + std::lock_guard guard(m_state.lock); + m_state.socket.handle_read = false; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); - else if (not success) + else if (!success) interrupt(); else { start_read(); @@ -416,23 +422,23 @@ namespace net_utils ); } }; - if (not state.ssl.enabled) + if (!m_state.ssl.enabled) connection_basic::socket_.next_layer().async_read_some( boost::asio::buffer( - state.data.read.buffer.data(), - state.data.read.buffer.size() + m_state.data.read.buffer.data(), + m_state.data.read.buffer.size() ), - strand.wrap(on_read) + m_strand.wrap(on_read) ); else - strand.post( + m_strand.post( [this, self, on_read]{ connection_basic::socket_.async_read_some( boost::asio::buffer( - state.data.read.buffer.data(), - state.data.read.buffer.size() + m_state.data.read.buffer.data(), + m_state.data.read.buffer.size() ), - strand.wrap(on_read) + m_strand.wrap(on_read) ); } ); @@ -441,13 +447,14 @@ namespace net_utils template void connection::start_write() { - if (state.timers.throttle.out.wait_expire || state.socket.wait_write || - state.data.write.queue.empty() || - (state.ssl.enabled && not state.ssl.handshaked) - ) + if (m_state.timers.throttle.out.wait_expire || m_state.socket.wait_write || + m_state.data.write.queue.empty() || + (m_state.ssl.enabled && !m_state.ssl.handshaked) + ) { return; + } auto self = connection::shared_from_this(); - if (connection_type != e_connection_type_RPC) { + if (m_connection_type != e_connection_type_RPC) { auto calc_duration = [this]{ CRITICAL_REGION_LOCAL( network_throttle_manager_t::m_lock_get_global_throttle_out @@ -457,7 +464,7 @@ namespace net_utils std::min( network_throttle_manager_t::get_global_throttle_out( ).get_sleep_time_after_tick( - state.data.write.queue.back().size() + m_state.data.write.queue.back().size() ), 1.0 ) @@ -467,19 +474,14 @@ namespace net_utils const auto duration = calc_duration(); if (duration > duration_t{}) { ec_t ec; - timers.throttle.out.expires_from_now(duration, ec); - state.timers.throttle.out.wait_expire = true; - timers.throttle.out.async_wait([this, self](const ec_t &ec){ - lock_guard_t guard(state.lock); - state.timers.throttle.out.wait_expire = false; - if (state.timers.throttle.out.cancel_expire) { - state.timers.throttle.out.cancel_expire = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + m_timers.throttle.out.expires_from_now(duration, ec); + 
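// A minimal sketch of the non-blocking throttle pattern used here (illustrative only, not
// part of this patch; standalone names are assumed). The sleep time computed from the
// global throttle becomes a steady_timer wait, and the deferred read/write is retried from
// the completion handler instead of blocking a worker thread as the removed
// sleep_before_packet() path did.
#include <boost/asio/steady_timer.hpp>
#include <chrono>
#include <functional>
void throttled_retry(boost::asio::steady_timer &timer,
                     double sleep_seconds,
                     std::function<void()> start_io)
{
  if (sleep_seconds <= 0.0) { start_io(); return; }       // no delay needed, go now
  timer.expires_from_now(std::chrono::duration_cast<boost::asio::steady_timer::duration>(
    std::chrono::duration<double>(sleep_seconds)));
  timer.async_wait([start_io](const boost::system::error_code &ec) {
    if (!ec) start_io();                                   // retry the deferred read/write
  });
}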
m_state.timers.throttle.out.wait_expire = true; + m_timers.throttle.out.async_wait([this, self](const ec_t &ec){ + std::lock_guard guard(m_state.lock); + m_state.timers.throttle.out.wait_expire = false; + if (m_state.timers.throttle.out.cancel_expire) { + m_state.timers.throttle.out.cancel_expire = false; + state_status_check(); } else if (ec.value()) interrupt(); @@ -489,31 +491,26 @@ namespace net_utils } } - state.socket.wait_write = true; + m_state.socket.wait_write = true; auto on_write = [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_write = false; - if (state.socket.cancel_write) { - state.socket.cancel_write = false; - state.data.write.queue.clear(); - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard guard(m_state.lock); + m_state.socket.wait_write = false; + if (m_state.socket.cancel_write) { + m_state.socket.cancel_write = false; + m_state.data.write.queue.clear(); + state_status_check(); } else if (ec.value()) { - state.data.write.queue.clear(); + m_state.data.write.queue.clear(); interrupt(); } else { { - state.stat.out.throttle.handle_trafic_exact(bytes_transferred); - const auto speed = state.stat.out.throttle.get_current_speed(); - context.m_current_speed_up = speed; - context.m_max_speed_down = std::max( - context.m_max_speed_down, + m_state.stat.out.throttle.handle_trafic_exact(bytes_transferred); + const auto speed = m_state.stat.out.throttle.get_current_speed(); + m_conn_context.m_current_speed_up = speed; + m_conn_context.m_max_speed_down = std::max( + m_conn_context.m_max_speed_down, speed ); { @@ -524,36 +521,36 @@ namespace net_utils ).handle_trafic_exact(bytes_transferred); } connection_basic::logger_handle_net_write(bytes_transferred); - context.m_last_send = time(NULL); - context.m_send_cnt += bytes_transferred; + m_conn_context.m_last_send = time(NULL); + m_conn_context.m_send_cnt += bytes_transferred; start_timer(get_default_timeout(), true); } - assert(bytes_transferred == state.data.write.queue.back().size()); - state.data.write.queue.pop_back(); - state.condition.notify_all(); + assert(bytes_transferred == m_state.data.write.queue.back().size()); + m_state.data.write.queue.pop_back(); + m_state.condition.notify_all(); start_write(); } }; - if (not state.ssl.enabled) + if (!m_state.ssl.enabled) boost::asio::async_write( connection_basic::socket_.next_layer(), boost::asio::buffer( - state.data.write.queue.back().data(), - state.data.write.queue.back().size() + m_state.data.write.queue.back().data(), + m_state.data.write.queue.back().size() ), - strand.wrap(on_write) + m_strand.wrap(on_write) ); else - strand.post( + m_strand.post( [this, self, on_write]{ boost::asio::async_write( connection_basic::socket_, boost::asio::buffer( - state.data.write.queue.back().data(), - state.data.write.queue.back().size() + m_state.data.write.queue.back().data(), + m_state.data.write.queue.back().size() ), - strand.wrap(on_write) + m_strand.wrap(on_write) ); } ); @@ -562,21 +559,29 @@ namespace net_utils template void connection::start_shutdown() { - if (state.socket.wait_shutdown) + if (m_state.socket.wait_shutdown) return; auto self = connection::shared_from_this(); - state.socket.wait_shutdown = true; + m_state.socket.wait_shutdown = true; auto on_shutdown = [this, self](const ec_t &ec){ - lock_guard_t guard(state.lock); - state.socket.wait_shutdown = 
false; - if (state.socket.cancel_shutdown) { - state.socket.cancel_shutdown = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - terminate(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard guard(m_state.lock); + m_state.socket.wait_shutdown = false; + if (m_state.socket.cancel_shutdown) { + m_state.socket.cancel_shutdown = false; + switch (m_state.status) + { + case status_t::RUNNING: + interrupt(); + break; + case status_t::INTERRUPTED: + terminate(); + break; + case status_t::TERMINATING: + on_terminating(); + break; + default: + break; + } } else if (ec.value()) terminate(); @@ -585,10 +590,10 @@ namespace net_utils on_interrupted(); } }; - strand.post( + m_strand.post( [this, self, on_shutdown]{ connection_basic::socket_.async_shutdown( - strand.wrap(on_shutdown) + m_strand.wrap(on_shutdown) ); } ); @@ -599,24 +604,24 @@ namespace net_utils void connection::cancel_socket() { bool wait_socket = false; - if (state.socket.wait_handshake) - wait_socket = state.socket.cancel_handshake = true; - if (state.timers.throttle.in.wait_expire) { - state.timers.throttle.in.cancel_expire = true; + if (m_state.socket.wait_handshake) + wait_socket = m_state.socket.cancel_handshake = true; + if (m_state.timers.throttle.in.wait_expire) { + m_state.timers.throttle.in.cancel_expire = true; ec_t ec; - timers.throttle.in.cancel(ec); + m_timers.throttle.in.cancel(ec); } - if (state.socket.wait_read) - wait_socket = state.socket.cancel_read = true; - if (state.timers.throttle.out.wait_expire) { - state.timers.throttle.out.cancel_expire = true; + if (m_state.socket.wait_read) + wait_socket = m_state.socket.cancel_read = true; + if (m_state.timers.throttle.out.wait_expire) { + m_state.timers.throttle.out.cancel_expire = true; ec_t ec; - timers.throttle.out.cancel(ec); + m_timers.throttle.out.cancel(ec); } - if (state.socket.wait_write) - wait_socket = state.socket.cancel_write = true; - if (state.socket.wait_shutdown) - wait_socket = state.socket.cancel_shutdown = true; + if (m_state.socket.wait_write) + wait_socket = m_state.socket.cancel_write = true; + if (m_state.socket.wait_shutdown) + wait_socket = m_state.socket.cancel_shutdown = true; if (wait_socket) { ec_t ec; connection_basic::socket_.next_layer().cancel(ec); @@ -626,136 +631,139 @@ namespace net_utils template void connection::cancel_handler() { - if (state.protocol.released || state.protocol.wait_release) + if (m_state.protocol.released || m_state.protocol.wait_release) return; - state.protocol.wait_release = true; - state.lock.unlock(); - handler.release_protocol(); - state.lock.lock(); - state.protocol.wait_release = false; - state.protocol.released = true; - if (state.status == status_t::INTERRUPTED) + m_state.protocol.wait_release = true; + m_state.lock.unlock(); + m_handler.release_protocol(); + m_state.lock.lock(); + m_state.protocol.wait_release = false; + m_state.protocol.released = true; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); } template void connection::interrupt() { - if (state.status != status_t::RUNNING) + if (m_state.status != status_t::RUNNING) return; - state.status = status_t::INTERRUPTED; + m_state.status = status_t::INTERRUPTED; cancel_timer(); cancel_socket(); on_interrupted(); - state.condition.notify_all(); + m_state.condition.notify_all(); cancel_handler(); } template void 
connection::on_interrupted() { - assert(state.status == status_t::INTERRUPTED); - if (state.timers.general.wait_expire) + assert(m_state.status == status_t::INTERRUPTED); + if (m_state.timers.general.wait_expire) return; - if (state.socket.wait_handshake) + if (m_state.socket.wait_handshake) return; - if (state.timers.throttle.in.wait_expire) + if (m_state.timers.throttle.in.wait_expire) return; - if (state.socket.wait_read) + if (m_state.socket.wait_read) return; - if (state.socket.handle_read) + if (m_state.socket.handle_read) return; - if (state.timers.throttle.out.wait_expire) + if (m_state.timers.throttle.out.wait_expire) return; - if (state.socket.wait_write) + if (m_state.socket.wait_write) return; - if (state.socket.wait_shutdown) + if (m_state.socket.wait_shutdown) return; - if (state.protocol.wait_init) + if (m_state.protocol.wait_init) return; - if (state.protocol.wait_callback) + if (m_state.protocol.wait_callback) return; - if (state.protocol.wait_release) + if (m_state.protocol.wait_release) return; - if (state.socket.connected) { - if (not state.ssl.enabled) { + if (m_state.socket.connected) { + if (!m_state.ssl.enabled) { ec_t ec; connection_basic::socket_.next_layer().shutdown( socket_t::shutdown_both, ec ); connection_basic::socket_.next_layer().close(ec); - state.socket.connected = false; - state.status = status_t::WASTED; + m_state.socket.connected = false; + m_state.status = status_t::WASTED; } else start_shutdown(); } else - state.status = status_t::WASTED; + m_state.status = status_t::WASTED; } template void connection::terminate() { - if (state.status != status_t::RUNNING && - state.status != status_t::INTERRUPTED + if (m_state.status != status_t::RUNNING && + m_state.status != status_t::INTERRUPTED ) return; - state.status = status_t::TERMINATING; + m_state.status = status_t::TERMINATING; cancel_timer(); cancel_socket(); on_terminating(); - state.condition.notify_all(); + m_state.condition.notify_all(); cancel_handler(); } template void connection::on_terminating() { - assert(state.status == status_t::TERMINATING); - if (state.timers.general.wait_expire) + assert(m_state.status == status_t::TERMINATING); + if (m_state.timers.general.wait_expire) return; - if (state.socket.wait_handshake) + if (m_state.socket.wait_handshake) return; - if (state.timers.throttle.in.wait_expire) + if (m_state.timers.throttle.in.wait_expire) return; - if (state.socket.wait_read) + if (m_state.socket.wait_read) return; - if (state.socket.handle_read) + if (m_state.socket.handle_read) return; - if (state.timers.throttle.out.wait_expire) + if (m_state.timers.throttle.out.wait_expire) return; - if (state.socket.wait_write) + if (m_state.socket.wait_write) return; - if (state.socket.wait_shutdown) + if (m_state.socket.wait_shutdown) return; - if (state.protocol.wait_init) + if (m_state.protocol.wait_init) return; - if (state.protocol.wait_callback) + if (m_state.protocol.wait_callback) return; - if (state.protocol.wait_release) + if (m_state.protocol.wait_release) return; - if (state.socket.connected) { + if (m_state.socket.connected) { ec_t ec; connection_basic::socket_.next_layer().shutdown( socket_t::shutdown_both, ec ); connection_basic::socket_.next_layer().close(ec); - state.socket.connected = false; + m_state.socket.connected = false; } - state.status = status_t::WASTED; + m_state.status = status_t::WASTED; } template - bool connection::send(byte_slice_t message) + bool connection::send(epee::byte_slice message) { - lock_guard_t guard(state.lock); - if (state.status != 
status_t::RUNNING || state.socket.wait_handshake) + std::lock_guard guard(m_state.lock); + if (m_state.status != status_t::RUNNING || m_state.socket.wait_handshake) return false; + + // Wait for the write queue to fall below the max. If it doesn't after a + // randomized delay, drop the connection. auto wait_consume = [this] { auto random_delay = []{ using engine = std::mt19937; @@ -770,62 +778,62 @@ namespace net_utils std::uniform_int_distribution<>(5000, 6000)(rng) ); }; - if (state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT) + if (m_state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT) return true; - state.data.write.wait_consume = true; - bool success = state.condition.wait_for( - state.lock, + m_state.data.write.wait_consume = true; + bool success = m_state.condition.wait_for( + m_state.lock, random_delay(), [this]{ return ( - state.status != status_t::RUNNING || - state.data.write.queue.size() <= + m_state.status != status_t::RUNNING || + m_state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT ); } ); - state.data.write.wait_consume = false; - if (not success) { + m_state.data.write.wait_consume = false; + if (!success) { terminate(); return false; } else - return state.status == status_t::RUNNING; + return m_state.status == status_t::RUNNING; }; auto wait_sender = [this] { - state.condition.wait( - state.lock, + m_state.condition.wait( + m_state.lock, [this] { return ( - state.status != status_t::RUNNING || - not state.data.write.wait_consume + m_state.status != status_t::RUNNING || + !m_state.data.write.wait_consume ); } ); - return state.status == status_t::RUNNING; + return m_state.status == status_t::RUNNING; }; - if (not wait_sender()) + if (!wait_sender()) return false; constexpr size_t CHUNK_SIZE = 32 * 1024; - if (connection_type == e_connection_type_RPC || + if (m_connection_type == e_connection_type_RPC || message.size() <= 2 * CHUNK_SIZE ) { - if (not wait_consume()) + if (!wait_consume()) return false; - state.data.write.queue.emplace_front(std::move(message)); + m_state.data.write.queue.emplace_front(std::move(message)); start_write(); } else { while (!message.empty()) { - if (not wait_consume()) + if (!wait_consume()) return false; - state.data.write.queue.emplace_front( + m_state.data.write.queue.emplace_front( message.take_slice(CHUNK_SIZE) ); start_write(); } } - state.condition.notify_all(); + m_state.condition.notify_all(); return true; } @@ -836,10 +844,10 @@ namespace net_utils boost::optional real_remote ) { - unique_lock_t guard(state.lock); - if (state.status != status_t::TERMINATED) + std::unique_lock guard(m_state.lock); + if (m_state.status != status_t::TERMINATED) return false; - if (not real_remote) { + if (!real_remote) { ec_t ec; auto endpoint = connection_basic::socket_.next_layer().remote_endpoint( ec @@ -866,7 +874,7 @@ namespace net_utils auto *filter = static_cast( connection_basic::get_state() ).pfilter; - if (filter and not filter->is_remote_host_allowed(*real_remote)) + if (filter && !filter->is_remote_host_allowed(*real_remote)) return false; ec_t ec; #if !defined(_WIN32) || !defined(__i686) @@ -886,39 +894,39 @@ namespace net_utils if (ec.value()) return false; connection_basic::m_is_multithreaded = is_multithreaded; - context.set_details( + m_conn_context.set_details( boost::uuids::random_generator()(), *real_remote, is_income, connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled ); - host = real_remote->host_str(); + m_host = real_remote->host_str(); try { host_count(1); } 
catch(...) { /* ignore */ } - local = real_remote->is_loopback() || real_remote->is_local(); - state.ssl.enabled = ( + m_local = real_remote->is_loopback() || real_remote->is_local(); + m_state.ssl.enabled = ( connection_basic::m_ssl_support != ssl_support_t::e_ssl_support_disabled ); - state.ssl.forced = ( + m_state.ssl.forced = ( connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled ); - state.socket.connected = true; - state.status = status_t::RUNNING; + m_state.socket.connected = true; + m_state.status = status_t::RUNNING; start_timer( std::chrono::milliseconds( - local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE + m_local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE ) ); - state.protocol.wait_init = true; + m_state.protocol.wait_init = true; guard.unlock(); - handler.after_init_connection(); + m_handler.after_init_connection(); guard.lock(); - state.protocol.wait_init = false; - state.protocol.initialized = true; - if (state.status == status_t::INTERRUPTED) + m_state.protocol.wait_init = false; + m_state.protocol.initialized = true; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); - else if (not is_income || not state.ssl.enabled) + else if (!is_income || !m_state.ssl.enabled) start_read(); else start_handshake(); @@ -949,23 +957,23 @@ namespace net_utils ssl_support_t ssl_support ): connection_basic(std::move(socket), shared_state, ssl_support), - handler(this, *shared_state, context), - connection_type(connection_type), - io_context{GET_IO_SERVICE(connection_basic::socket_)}, - strand{io_context}, - timers{io_context} + m_handler(this, *shared_state, m_conn_context), + m_connection_type(connection_type), + m_io_context{GET_IO_SERVICE(connection_basic::socket_)}, + m_strand{m_io_context}, + m_timers{m_io_context} { } template connection::~connection() noexcept(false) { - lock_guard_t guard(state.lock); - assert(state.status == status_t::TERMINATED || - state.status == status_t::WASTED || - io_context.stopped() + std::lock_guard guard(m_state.lock); + assert(m_state.status == status_t::TERMINATED || + m_state.status == status_t::WASTED || + m_io_context.stopped() ); - if (state.status != status_t::WASTED) + if (m_state.status != status_t::WASTED) return; try { host_count(-1); } catch (...) 
{ /* ignore */ } } @@ -992,9 +1000,9 @@ namespace net_utils template void connection::save_dbg_log() { - lock_guard_t guard(state.lock); - string_t address; - string_t port; + std::lock_guard guard(m_state.lock); + std::string address; + std::string port; ec_t ec; auto endpoint = connection_basic::socket().remote_endpoint(ec); if (ec.value()) { @@ -1006,10 +1014,10 @@ namespace net_utils port = std::to_string(endpoint.port()); } MDEBUG( - " connection type " << std::to_string(connection_type) << + " connection type " << std::to_string(m_connection_type) << " " << connection_basic::socket().local_endpoint().address().to_string() << ":" << connection_basic::socket().local_endpoint().port() << - " <--> " << context.m_remote_address.str() << + " <--> " << m_conn_context.m_remote_address.str() << " (via " << address << ":" << port << ")" ); } @@ -1017,7 +1025,7 @@ namespace net_utils template bool connection::speed_limit_is_enabled() const { - return connection_type != e_connection_type_RPC; + return m_connection_type != e_connection_type_RPC; } template @@ -1041,8 +1049,8 @@ namespace net_utils template bool connection::close() { - lock_guard_t guard(state.lock); - if (state.status != status_t::RUNNING) + std::lock_guard guard(m_state.lock); + if (m_state.status != status_t::RUNNING) return false; terminate(); return true; @@ -1052,11 +1060,11 @@ namespace net_utils bool connection::call_run_once_service_io() { if(connection_basic::m_is_multithreaded) { - if (not io_context.poll_one()) + if (!m_io_context.poll_one()) misc_utils::sleep_no_w(1); } else { - if (!io_context.run_one()) + if (!m_io_context.run_one()) return false; } return true; @@ -1065,18 +1073,18 @@ namespace net_utils template bool connection::request_callback() { - lock_guard_t guard(state.lock); - if (state.status != status_t::RUNNING) + std::lock_guard guard(m_state.lock); + if (m_state.status != status_t::RUNNING) return false; auto self = connection::shared_from_this(); - ++state.protocol.wait_callback; + ++m_state.protocol.wait_callback; connection_basic::strand_.post([this, self]{ - handler.handle_qued_callback(); - lock_guard_t guard(state.lock); - --state.protocol.wait_callback; - if (state.status == status_t::INTERRUPTED) + m_handler.handle_qued_callback(); + std::lock_guard guard(m_state.lock); + --m_state.protocol.wait_callback; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); }); return true; @@ -1085,7 +1093,7 @@ namespace net_utils template typename connection::io_context_t &connection::get_io_service() { - return io_context; + return m_io_context; } template @@ -1093,9 +1101,9 @@ namespace net_utils { try { auto self = connection::shared_from_this(); - lock_guard_t guard(state.lock); + std::lock_guard guard(m_state.lock); this->self = std::move(self); - ++state.protocol.reference_counter; + ++m_state.protocol.reference_counter; return true; } catch (boost::bad_weak_ptr &exception) { @@ -1107,8 +1115,8 @@ namespace net_utils bool connection::release() { connection_ptr self; - lock_guard_t guard(state.lock); - if (not --state.protocol.reference_counter) + std::lock_guard guard(m_state.lock); + if (!(--m_state.protocol.reference_counter)) self = std::move(this->self); return true; } @@ -1116,8 +1124,8 @@ namespace net_utils template void connection::setRpcStation() { - lock_guard_t guard(state.lock); - connection_type = e_connection_type_RPC; + std::lock_guard guard(m_state.lock); + 
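// A small aside on the synchronization primitive chosen above (illustrative only, not part
// of this patch): std::condition_variable_any is used because it can wait directly on a
// std::mutex (or any Lockable), which lets the code keep a plain std::lock_guard and still
// call wait_for() on the same mutex, as in the handshake and send paths, e.g.:
#include <condition_variable>
#include <mutex>
#include <chrono>
bool wait_until_ready(std::mutex &m, std::condition_variable_any &cv, bool &ready)
{
  std::lock_guard<std::mutex> guard(m);   // held on entry; wait_for unlocks and relocks m
  return cv.wait_for(m, std::chrono::milliseconds(30), [&ready]{ return ready; });
}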
m_connection_type = e_connection_type_RPC; } template diff --git a/contrib/epee/src/net_ssl.cpp b/contrib/epee/src/net_ssl.cpp index 7dda65bb5..2d0b7d791 100644 --- a/contrib/epee/src/net_ssl.cpp +++ b/contrib/epee/src/net_ssl.cpp @@ -553,9 +553,6 @@ bool ssl_options_t::handshake( using ec_t = boost::system::error_code; using timer_t = boost::asio::steady_timer; using strand_t = boost::asio::io_service::strand; - using lock_t = std::mutex; - using lock_guard_t = std::lock_guard; - using condition_t = std::condition_variable_any; using socket_t = boost::asio::ip::tcp::socket; auto &io_context = GET_IO_SERVICE(socket); @@ -565,8 +562,8 @@ bool ssl_options_t::handshake( timer_t deadline(io_context, timeout); struct state_t { - lock_t lock; - condition_t condition; + std::mutex lock; + std::condition_variable_any condition; ec_t result; bool wait_timer; bool wait_handshake; @@ -577,10 +574,10 @@ bool ssl_options_t::handshake( state.wait_timer = true; auto on_timer = [&](const ec_t &ec){ - lock_guard_t guard(state.lock); + std::lock_guard guard(state.lock); state.wait_timer = false; state.condition.notify_all(); - if (not state.cancel_timer) { + if (!state.cancel_timer) { state.cancel_handshake = true; ec_t ec; socket.next_layer().cancel(ec); @@ -589,11 +586,11 @@ bool ssl_options_t::handshake( state.wait_handshake = true; auto on_handshake = [&](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); + std::lock_guard guard(state.lock); state.wait_handshake = false; state.condition.notify_all(); state.result = ec; - if (not state.cancel_handshake) { + if (!state.cancel_handshake) { state.cancel_timer = true; ec_t ec; deadline.cancel(ec); @@ -614,15 +611,15 @@ bool ssl_options_t::handshake( while (!io_context.stopped()) { io_context.poll_one(); - lock_guard_t guard(state.lock); + std::lock_guard guard(state.lock); state.condition.wait_for( state.lock, std::chrono::milliseconds(30), [&]{ - return not state.wait_timer and not state.wait_handshake; + return !state.wait_timer && !state.wait_handshake; } ); - if (not state.wait_timer and not state.wait_handshake) + if (!state.wait_timer && !state.wait_handshake) break; } if (state.result.value()) { diff --git a/tests/unit_tests/epee_boosted_tcp_server.cpp b/tests/unit_tests/epee_boosted_tcp_server.cpp index d64431edf..c08a86a5e 100644 --- a/tests/unit_tests/epee_boosted_tcp_server.cpp +++ b/tests/unit_tests/epee_boosted_tcp_server.cpp @@ -617,7 +617,7 @@ TEST(boosted_tcp_server, strand_deadlock) void after_init_connection() { unique_lock_t guard(lock); - if (not context.m_is_income) { + if (!context.m_is_income) { guard.unlock(); socket->do_send(byte_slice_t{"."}); } @@ -628,7 +628,7 @@ TEST(boosted_tcp_server, strand_deadlock) bool handle_recv(const char *data, size_t bytes_transferred) { unique_lock_t guard(lock); - if (not context.m_is_income) { + if (!context.m_is_income) { if (context.m_recv_cnt == 1024) { guard.unlock(); socket->do_send(byte_slice_t{"."}); @@ -652,9 +652,9 @@ TEST(boosted_tcp_server, strand_deadlock) void release_protocol() { unique_lock_t guard(lock); - if(not context.m_is_income - and context.m_recv_cnt == 1024 - and context.m_send_cnt == 2 + if(!context.m_is_income + && context.m_recv_cnt == 1024 + && context.m_send_cnt == 2 ) { guard.unlock(); config.notify_success();