Skip to content

Commit

Permalink
net/proxy/ares support custom server
Browse files Browse the repository at this point in the history
  • Loading branch information
iceboy233 committed Oct 22, 2023
1 parent 91658f1 commit 4c50b39
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 13 deletions.
2 changes: 2 additions & 0 deletions net/proxy/ares/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ cc_library(
"//util:int-allocator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/types:span",
"@org_iceboy_trunk//net:asio",
"@org_iceboy_trunk//net:endpoint",
"@org_iceboy_trunk//net:timer-list",
],
)
Expand Down
33 changes: 31 additions & 2 deletions net/proxy/ares/resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ Resolver::Resolver(
abort();
}
ares_set_socket_functions(channel_, &funcs_, this);
if (!options.servers.empty()) {
set_servers(options.servers);
}
}

Resolver::~Resolver() {
Expand Down Expand Up @@ -92,6 +95,30 @@ void Resolver::wait() {
});
}

void Resolver::set_servers(absl::Span<const Endpoint> servers) {
ares_addr_port_node nodes[servers.size()];
for (size_t i = 0; i < servers.size(); ++i) {
nodes[i].next = i + 1 < servers.size() ? &nodes[i + 1] : nullptr;
const address &address = servers[i].address();
if (address.is_v4()) {
nodes[i].family = AF_INET;
auto address_bytes = address.to_v4().to_bytes();
static_assert(address_bytes.size() == 4);
memcpy(&nodes[i].addr.addr4, address_bytes.data(), 4);
} else {
nodes[i].family = AF_INET6;
auto address_bytes = address.to_v6().to_bytes();
static_assert(address_bytes.size() == 16);
memcpy(&nodes[i].addr.addr6, address_bytes.data(), 16);
}
nodes[i].udp_port = servers[i].port();
nodes[i].tcp_port = servers[i].port();
}
if (ares_set_servers_ports(channel_, nodes) != ARES_SUCCESS) {
abort();
}
}

ares_socket_t Resolver::asocket(
int domain, int type, int protocol, void *user_data) {
auto *resolver = reinterpret_cast<Resolver *>(user_data);
Expand Down Expand Up @@ -212,14 +239,16 @@ void Resolver::Operation::parse(int status, ares_addrinfo *ai) {
continue;
}
auto *addr4 = reinterpret_cast<sockaddr_in *>(node->ai_addr);
addresses_.push_back(address_v4(ntohl(addr4->sin_addr.s_addr)));
address_v4::bytes_type bytes;
memcpy(bytes.data(), &addr4->sin_addr, 4);
addresses_.push_back(address_v4(bytes));
} else if (node->ai_family == AF_INET6) {
if (node->ai_addrlen < sizeof(sockaddr_in6)) {
continue;
}
auto *addr6 = reinterpret_cast<sockaddr_in6 *>(node->ai_addr);
address_v6::bytes_type bytes;
memcpy(bytes.data(), addr6->sin6_addr.s6_addr, 16);
memcpy(bytes.data(), &addr6->sin6_addr, 16);
addresses_.push_back(address_v6(bytes));
}
}
Expand Down
4 changes: 4 additions & 0 deletions net/proxy/ares/resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/types/span.h"
#include "net/asio.h"
#include "net/endpoint.h"
#include "net/proxy/ares/socket.h"
#include "net/proxy/connector.h"
#include "net/timer-list.h"
Expand All @@ -22,6 +24,7 @@ namespace ares {
class Resolver {
public:
struct Options {
std::vector<Endpoint> servers;
std::chrono::milliseconds query_timeout = std::chrono::seconds(1);
std::chrono::nanoseconds cache_timeout = std::chrono::minutes(1);
};
Expand All @@ -41,6 +44,7 @@ class Resolver {
class Operation;

void wait();
void set_servers(absl::Span<const Endpoint> servers);

static ares_socket_t asocket(
int domain, int type, int protocol, void *user_data);
Expand Down
10 changes: 6 additions & 4 deletions net/proxy/ares/socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ bool parse_addr(
return false;
}
const auto *addr4 = reinterpret_cast<const sockaddr_in *>(addr);
address = address_v4(ntohl(addr4->sin_addr.s_addr));
address_v4::bytes_type bytes;
memcpy(bytes.data(), &addr4->sin_addr, 4);
address = address_v4(bytes);
port = ntohs(addr4->sin_port);
return true;
} else if (addr->sa_family == AF_INET6) {
Expand All @@ -35,7 +37,7 @@ bool parse_addr(
}
const auto *addr6 = reinterpret_cast<const sockaddr_in6 *>(addr);
address_v6::bytes_type bytes;
memcpy(bytes.data(), addr6->sin6_addr.s6_addr, 16);
memcpy(bytes.data(), &addr6->sin6_addr, 16);
address = address_v6(bytes);
port = ntohs(addr6->sin6_port);
return true;
Expand All @@ -54,7 +56,7 @@ bool populate_addr(
auto *addr4 = reinterpret_cast<sockaddr_in *>(addr);
addr4->sin_family = AF_INET;
addr4->sin_port = htons(port);
addr4->sin_addr.s_addr = htonl(address.to_v4().to_uint());
memcpy(&addr4->sin_addr, address.to_v4().to_bytes().data(), 4);
*addr_len = sizeof(sockaddr_in);
return true;
} else {
Expand All @@ -64,7 +66,7 @@ bool populate_addr(
auto *addr6 = reinterpret_cast<sockaddr_in6 *>(addr);
addr6->sin6_family = AF_INET6;
addr6->sin6_port = htons(port);
memcpy(addr6->sin6_addr.s6_addr, address.to_v6().to_bytes().data(), 16);
memcpy(&addr6->sin6_addr, address.to_v6().to_bytes().data(), 16);
*addr_len = sizeof(sockaddr_in6);
return true;
}
Expand Down
1 change: 1 addition & 0 deletions net/proxy/system/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cc_library(
":connector",
"//net/proxy",
"@org_boost_boost//:property_tree",
"@org_iceboy_trunk//base:logging",
],
alwayslink = 1,
)
Expand Down
13 changes: 13 additions & 0 deletions net/proxy/system/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <memory>
#include <boost/property_tree/ptree.hpp>

#include "base/logging.h"
#include "net/proxy/proxy.h"
#include "net/proxy/registry.h"
#include "net/proxy/system/connector.h"
Expand All @@ -17,6 +18,18 @@ REGISTER_CONNECTOR(system, [](
options.timeout = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::duration<double>(config.get<double>("timeout", 300)));
options.tcp_no_delay = config.get<bool>("tcp_no_delay", true);
const auto &resolver_config = config.get_child("resolver", {});
for (auto iters = resolver_config.equal_range("server");
iters.first != iters.second;
++iters.first) {
std::string server_str = iters.first->second.get_value<std::string>();
auto server_endpoint = Endpoint::from_string(server_str);
if (!server_endpoint) {
LOG(error) << "invalid server endpoint: " << server_str;
continue;
}
options.resolver_options.servers.push_back(*server_endpoint);
}
return std::make_unique<Connector>(proxy.executor(), options);
});

Expand Down
2 changes: 1 addition & 1 deletion net/proxy/system/connector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace system {

Connector::Connector(const any_io_executor &executor, const Options &options)
: executor_(executor),
resolver_(executor_, *this, {}),
resolver_(executor_, *this, options.resolver_options),
timer_list_(executor_, options.timeout),
tcp_no_delay_(options.tcp_no_delay) {}

Expand Down
3 changes: 3 additions & 0 deletions net/proxy/system/connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Connector : public proxy::Connector {
struct Options {
std::chrono::nanoseconds timeout = std::chrono::minutes(5);
bool tcp_no_delay = true;
ares::Resolver::Options resolver_options;
};

Connector(const any_io_executor &executor, const Options &options);
Expand Down Expand Up @@ -49,6 +50,8 @@ class Connector : public proxy::Connector {
std::error_code bind_udp_v4(std::unique_ptr<Datagram> &datagram) override;
std::error_code bind_udp_v6(std::unique_ptr<Datagram> &datagram) override;

ares::Resolver &resolver() { return resolver_; }

private:
template <typename EndpointsT>
void connect_tcp(
Expand Down
8 changes: 2 additions & 6 deletions net/tools/ares-resolve.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <memory>
#include <ostream>
#include <string_view>
#include <system_error>
Expand All @@ -20,12 +19,9 @@ int main(int argc, char *argv[]) {
io_context io_context;
auto executor = io_context.get_executor();
proxy::system::Connector connector(executor, {});
proxy::ares::Resolver resolver(executor, connector, {});

using Result = BlockingResult<std::error_code, std::vector<address>>;
auto results = std::make_unique<Result[]>(argc - 1);
BlockingResult<std::error_code, std::vector<address>> results[argc - 1];
for (int i = 1; i < argc; ++i) {
resolver.resolve(argv[i], results[i - 1].callback());
connector.resolver().resolve(argv[i], results[i - 1].callback());
}

io::OStream os(io::posix::stdout);
Expand Down

0 comments on commit 4c50b39

Please sign in to comment.