Skip to content

Commit

Permalink
net/proxy use perfect forwarding for templated callbacks
Browse files Browse the repository at this point in the history
This workarounds of a caveat in the boost asio library, where async operations prepare buffers and move self in a single line, see read_op [1] for example. When the callback is a value type instead of a reference, self could be moved before the buffer is accessed, resulting in a use-after-move error.

[1] https://github.com/boostorg/asio/blob/d6e7b5a547daaddfd19c548d2f602cb5b15361df/include/boost/asio/impl/read.hpp#L398
  • Loading branch information
iceboy233 committed Aug 26, 2024
1 parent 080c6b5 commit 6c2dc33
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 68 deletions.
28 changes: 14 additions & 14 deletions net/proxy/datagram.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,56 @@ class Datagram {

virtual ~Datagram() = default;

virtual void async_receive_from(
virtual void receive_from(
absl::Span<mutable_buffer const> buffers,
udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) = 0;

virtual void async_send_to(
virtual void send_to(
absl::Span<const_buffer const> buffers,
const udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) = 0;

virtual any_io_executor get_executor() = 0;
virtual void close() = 0;

template <typename BuffersT>
template <typename BuffersT, typename CallbackT>
void async_receive_from(
const BuffersT &buffers,
udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback);
CallbackT &&callback);

template <typename BuffersT>
template <typename BuffersT, typename CallbackT>
void async_send_to(
const BuffersT &buffers,
const udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback);
CallbackT &&callback);
};

template <typename BuffersT>
template <typename BuffersT, typename CallbackT>
void Datagram::async_receive_from(
const BuffersT &buffers,
udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
async_receive_from(
CallbackT &&callback) {
receive_from(
absl::Span<mutable_buffer const>(
buffer_sequence_begin(buffers),
buffer_sequence_end(buffers) - buffer_sequence_begin(buffers)),
endpoint,
std::move(callback));
std::forward<CallbackT>(callback));
}

template <typename BuffersT>
template <typename BuffersT, typename CallbackT>
void Datagram::async_send_to(
const BuffersT &buffers,
const udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
async_send_to(
CallbackT &&callback) {
send_to(
absl::Span<const_buffer const>(
buffer_sequence_begin(buffers),
buffer_sequence_end(buffers) - buffer_sequence_begin(buffers)),
endpoint,
std::move(callback));
std::forward<CallbackT>(callback));
}

} // namespace proxy
Expand Down
28 changes: 14 additions & 14 deletions net/proxy/shadowsocks/connector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ class Connector::TcpStream : public proxy::Stream {
const_buffer initial_data,
absl::AnyInvocable<void(std::error_code) &&> callback);

void async_read_some(
void read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

void async_write_some(
void write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

Expand All @@ -49,7 +49,7 @@ class Connector::TcpStream : public proxy::Stream {
private:
void connect(absl::AnyInvocable<void(std::error_code) &&> callback);

void read(
void read_internal(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback);

Expand Down Expand Up @@ -310,14 +310,14 @@ void Connector::TcpStream::connect(
}
}

void Connector::TcpStream::async_read_some(
void Connector::TcpStream::read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
while (true) {
switch (read_state_) {
case ReadState::init:
if (!decryptor_.init(connector_.pre_shared_key_)) {
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
if (!connector_.pre_shared_key_.method().is_spec_2022()) {
Expand All @@ -329,21 +329,21 @@ void Connector::TcpStream::async_read_some(
case ReadState::header:
if (!decryptor_.start_chunk(
connector_.pre_shared_key_.method().salt_size() + 11)) {
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
if (!connector_.salt_filter_.test_and_insert({
decryptor_.salt(),
connector_.pre_shared_key_.method().salt_size()})) {
LOG(warning) << "duplicated salt";
decryptor_.discard();
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
if (decryptor_.pop_u8() != 1) {
LOG(warning) << "unexpected header type";
decryptor_.discard();
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
if (std::abs(static_cast<int64_t>(decryptor_.pop_big_u64()) -
Expand All @@ -352,7 +352,7 @@ void Connector::TcpStream::async_read_some(
.count()) > 30) {
LOG(warning) << "time difference too large";
decryptor_.discard();
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
if (memcmp(
Expand All @@ -362,7 +362,7 @@ void Connector::TcpStream::async_read_some(
connector_.pre_shared_key_.method().salt_size())) {
LOG(warning) << "salt mismatch";
decryptor_.discard();
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
read_length_ = decryptor_.pop_big_u16();
Expand All @@ -371,7 +371,7 @@ void Connector::TcpStream::async_read_some(
continue;
case ReadState::length:
if (!decryptor_.start_chunk(2)) {
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
read_length_ = decryptor_.pop_big_u16();
Expand All @@ -380,7 +380,7 @@ void Connector::TcpStream::async_read_some(
[[fallthrough]];
case ReadState::payload:
if (!decryptor_.start_chunk(read_length_)) {
read(buffers, std::move(callback));
read_internal(buffers, std::move(callback));
return;
}
read_buffer_ = {decryptor_.pop_buffer(read_length_), read_length_};
Expand Down Expand Up @@ -409,7 +409,7 @@ void Connector::TcpStream::async_read_some(
}
}

void Connector::TcpStream::read(
void Connector::TcpStream::read_internal(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
absl::FixedArray<mutable_buffer, 1> buffers_copy(
Expand All @@ -429,7 +429,7 @@ void Connector::TcpStream::read(
});
}

void Connector::TcpStream::async_write_some(
void Connector::TcpStream::write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
size_t total_size = 0;
Expand Down
36 changes: 14 additions & 22 deletions net/proxy/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,40 @@ class Stream {

virtual ~Stream() = default;

virtual void async_read_some(
virtual void read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) = 0;

virtual void async_write_some(
virtual void write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) = 0;

virtual any_io_executor get_executor() = 0;
virtual void close() = 0;

template <typename BuffersT>
void async_read_some(
const BuffersT &buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback);
template <typename BuffersT, typename CallbackT>
void async_read_some(const BuffersT &buffers, CallbackT &&callback);

template <typename BuffersT>
void async_write_some(
const BuffersT &buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback);
template <typename BuffersT, typename CallbackT>
void async_write_some(const BuffersT &buffers, CallbackT &&callback);
};

template <typename BuffersT>
void Stream::async_read_some(
const BuffersT &buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
async_read_some(
template <typename BuffersT, typename CallbackT>
void Stream::async_read_some(const BuffersT &buffers, CallbackT &&callback) {
read(
absl::Span<mutable_buffer const>(
buffer_sequence_begin(buffers),
buffer_sequence_end(buffers) - buffer_sequence_begin(buffers)),
std::move(callback));
std::forward<CallbackT>(callback));
}

template <typename BuffersT>
void Stream::async_write_some(
const BuffersT &buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
async_write_some(
template <typename BuffersT, typename CallbackT>
void Stream::async_write_some(const BuffersT &buffers, CallbackT &&callback) {
write(
absl::Span<const_buffer const>(
buffer_sequence_begin(buffers),
buffer_sequence_end(buffers) - buffer_sequence_begin(buffers)),
std::move(callback));
std::forward<CallbackT>(callback));
}

} // namespace proxy
Expand Down
4 changes: 2 additions & 2 deletions net/proxy/system/stdio-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ StdioStream::StdioStream(const any_io_executor &executor)
#endif
{}

void StdioStream::async_read_some(
void StdioStream::read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
#ifndef _WIN32
Expand All @@ -49,7 +49,7 @@ void StdioStream::async_read_some(
#endif
}

void StdioStream::async_write_some(
void StdioStream::write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
#ifndef _WIN32
Expand Down
4 changes: 2 additions & 2 deletions net/proxy/system/stdio-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class StdioStream : public Stream {
StdioStream(const StdioStream &) = delete;
StdioStream &operator=(const StdioStream &) = delete;

void async_read_some(
void read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

void async_write_some(
void write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

Expand Down
4 changes: 2 additions & 2 deletions net/proxy/system/tcp-socket-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TcpSocketStream::TcpSocketStream(tcp::socket socket, TimerList &timer_list)
: socket_(std::move(socket)),
timer_(timer_list, [this]() { close(); }) {}

void TcpSocketStream::async_read_some(
void TcpSocketStream::read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
socket_.async_read_some(
Expand All @@ -19,7 +19,7 @@ void TcpSocketStream::async_read_some(
timer_.update();
}

void TcpSocketStream::async_write_some(
void TcpSocketStream::write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
socket_.async_write_some(
Expand Down
7 changes: 2 additions & 5 deletions net/proxy/system/tcp-socket-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,17 @@ class TcpSocketStream : public Stream {
TcpSocketStream(const TcpSocketStream &) = delete;
TcpSocketStream &operator=(const TcpSocketStream &) = delete;

void async_read_some(
void read(
absl::Span<mutable_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

void async_write_some(
void write(
absl::Span<const_buffer const> buffers,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

any_io_executor get_executor() override { return socket_.get_executor(); }
void close() override;

using Stream::async_read_some;
using Stream::async_write_some;

tcp::socket &socket() { return socket_; }
const tcp::socket &socket() const { return socket_; }

Expand Down
4 changes: 2 additions & 2 deletions net/proxy/system/udp-socket-datagram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace system {
UdpSocketDatagram::UdpSocketDatagram(udp::socket socket)
: socket_(std::move(socket)) {}

void UdpSocketDatagram::async_receive_from(
void UdpSocketDatagram::receive_from(
absl::Span<mutable_buffer const> buffers,
udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
Expand All @@ -19,7 +19,7 @@ void UdpSocketDatagram::async_receive_from(
std::move(callback));
}

void UdpSocketDatagram::async_send_to(
void UdpSocketDatagram::send_to(
absl::Span<const_buffer const> buffers,
const udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) {
Expand Down
7 changes: 2 additions & 5 deletions net/proxy/system/udp-socket-datagram.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,19 @@ class UdpSocketDatagram : public Datagram {
UdpSocketDatagram(const UdpSocketDatagram &) = delete;
UdpSocketDatagram &operator=(const UdpSocketDatagram &) = delete;

void async_receive_from(
void receive_from(
absl::Span<mutable_buffer const> buffers,
udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

void async_send_to(
void send_to(
absl::Span<const_buffer const> buffers,
const udp::endpoint &endpoint,
absl::AnyInvocable<void(std::error_code, size_t) &&> callback) override;

any_io_executor get_executor() override { return socket_.get_executor(); }
void close() override;

using Datagram::async_receive_from;
using Datagram::async_send_to;

udp::socket &socket() { return socket_; }
const udp::socket &socket() const { return socket_; }

Expand Down

0 comments on commit 6c2dc33

Please sign in to comment.