From 42c005bf3a8d60a77618fac6faf66f88d023cb1d Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Fri, 5 Apr 2024 21:19:30 -0500 Subject: [PATCH] clean up copies of file descriptor and remove copyinit --- gojo/net/dial.mojo | 25 +++++++++++++++++++++++++ gojo/net/fd.mojo | 1 - gojo/net/net.mojo | 6 +++--- gojo/net/socket.mojo | 15 ++++++++++++--- gojo/net/tcp.mojo | 21 ++++++++++++++++----- test_get_addr.mojo | 31 ++++++++++++++++++++++++++++++- 6 files changed, 86 insertions(+), 13 deletions(-) diff --git a/gojo/net/dial.mojo b/gojo/net/dial.mojo index e69de29..bb04afe 100644 --- a/gojo/net/dial.mojo +++ b/gojo/net/dial.mojo @@ -0,0 +1,25 @@ +from .tcp import TCPAddr, TCPConnection, resolve_internet_addr +from .socket import Socket + + +@value +struct Dialer(): + var local_address: TCPAddr + + fn dial(self, network: String, address: String) raises -> TCPConnection: + var tcp_addr = resolve_internet_addr(network, address) + var socket = Socket(local_address=self.local_address) + socket.connect(tcp_addr.ip, tcp_addr.port) + print(String("Connected to ") + socket.remote_address) + return TCPConnection(socket ^) + + +fn dial_tcp(network: String, local_address: TCPAddr) raises -> TCPConnection: + # TODO: Add conversion of domain name to ip address + return Dialer(local_address).dial( + network, local_address.ip + ":" + str(local_address.port) + ) + + +fn dial_tcp(network: String, ip: String, port: Int) raises -> TCPConnection: + return Dialer(TCPAddr(ip, port)).dial(network, ip + ":" + str(port)) diff --git a/gojo/net/fd.mojo b/gojo/net/fd.mojo index 673137b..d3458c2 100644 --- a/gojo/net/fd.mojo +++ b/gojo/net/fd.mojo @@ -16,7 +16,6 @@ trait FileDescriptorBase(io.Reader, io.Writer, io.Closer): ... -@value struct FileDescriptor(FileDescriptorBase): var fd: Int var is_closed: Bool diff --git a/gojo/net/net.mojo b/gojo/net/net.mojo index e5c5ee6..daec034 100644 --- a/gojo/net/net.mojo +++ b/gojo/net/net.mojo @@ -9,7 +9,7 @@ alias DEFAULT_BUFFER_SIZE = 4096 trait Conn(io.Writer, io.Reader, io.Closer): - fn __init__(inout self, socket: Socket): + fn __init__(inout self, owned socket: Socket): ... """Conn is a generic stream-oriented network connection.""" @@ -72,8 +72,8 @@ struct Connection(Conn): var fd: Arc[Socket] - fn __init__(inout self, socket: Socket): - self.fd = Arc(socket) + fn __init__(inout self, owned socket: Socket): + self.fd = Arc(socket ^) fn read(inout self, inout dest: List[Byte]) -> Result[Int]: """Reads data from the underlying file descriptor. diff --git a/gojo/net/socket.mojo b/gojo/net/socket.mojo index 04eb0e5..1075775 100644 --- a/gojo/net/socket.mojo +++ b/gojo/net/socket.mojo @@ -55,7 +55,6 @@ from .address import Addr, TCPAddr, HostPort alias SocketClosedError = Error("Socket: Socket is already closed") -@value struct Socket(FileDescriptorBase): """Represents a network file descriptor. Wraps around a file descriptor and provides network functions. @@ -135,8 +134,18 @@ struct Socket(FileDescriptorBase): self._closed = False self._is_connected = True - fn __enter__(self) -> Self: - return self + fn __moveinit__(inout self, owned existing: Self): + self.sockfd = existing.sockfd ^ + self.address_family = existing.address_family + self.socket_type = existing.socket_type + self.protocol = existing.protocol + self.local_address = existing.local_address ^ + self.remote_address = existing.remote_address ^ + self._closed = existing._closed + self._is_connected = existing._is_connected + + # fn __enter__(self) -> Self: + # return self # fn __exit__(inout self) raises: # if self._is_connected: diff --git a/gojo/net/tcp.mojo b/gojo/net/tcp.mojo index d543170..253206a 100644 --- a/gojo/net/tcp.mojo +++ b/gojo/net/tcp.mojo @@ -56,10 +56,10 @@ struct ListenConfig(CollectionElement): socket.set_socket_option(SO_REUSEADDR, 1) socket.listen() print(String("Listening on ") + socket.local_address) - return TCPListener(socket, self, network, address) + return TCPListener(socket ^, self, network, address) -trait Listener(CollectionElement): +trait Listener(Movable): # Raising here because a Result[Optional[Connection], Optional[WrappedError]] is funky. fn accept(self) raises -> Connection: ... @@ -84,8 +84,8 @@ struct TCPConnection(Conn): fn __init__(inout self, connection: Connection): self._connection = connection - fn __init__(inout self, socket: Socket): - self._connection = Connection(socket) + fn __init__(inout self, owned socket: Socket): + self._connection = Connection(socket ^) fn __moveinit__(inout self, owned existing: Self): self._connection = existing._connection ^ @@ -162,13 +162,24 @@ fn listen_tcp(network: String, ip: String, port: Int) raises -> TCPListener: return ListenConfig(DEFAULT_TCP_KEEP_ALIVE).listen(network, ip + ":" + str(port)) -@value struct TCPListener(Listener): var _file_descriptor: Socket var listen_config: ListenConfig var network_type: String var address: String + fn __init__(inout self, owned file_descriptor: Socket, listen_config: ListenConfig, network_type: String, address: String): + self._file_descriptor = file_descriptor ^ + self.listen_config = listen_config + self.network_type = network_type + self.address = address + + fn __moveinit__(inout self, owned existing: Self): + self._file_descriptor = existing._file_descriptor ^ + self.listen_config = existing.listen_config ^ + self.network_type = existing.network_type + self.address = existing.address + fn listen(self) raises -> Self: return self.listen_config.listen(self.network_type, self.address) diff --git a/test_get_addr.mojo b/test_get_addr.mojo index 9972706..bc34faf 100644 --- a/test_get_addr.mojo +++ b/test_get_addr.mojo @@ -2,12 +2,40 @@ from gojo.net.socket import Socket from gojo.net.ip import get_ip_address from gojo.net.tcp import listen_tcp, TCPAddr from gojo.syscall.net import SO_REUSEADDR, PF_UNIX, SO_RCVTIMEO +from gojo.net.dial import dial_tcp # fn main() raises: # var ip = get_ip_address("localhost") # print(ip) +fn test_dial() raises: + var connection = dial_tcp("tcp", get_ip_address("www.example.com"), 80) + var result = connection.write(String("GET / HTTP/1.1\r\n\r\n").as_bytes()) + if result.error: + raise result.unwrap_error().error + + if result.value == 0: + print("No bytes sent to peer.") + return + + var response = List[Int8](capacity=4096) + result = connection.read(response) + if result.error: + raise result.unwrap_error().error + + if result.value == 0: + print("No bytes received from peer.") + return + + response.append(0) + print(String(response)) + + var err = connection.close() + if err: + raise err.value().error + + fn test_listener() raises: var listener = listen_tcp("tcp", TCPAddr("0.0.0.0", 8081)) while True: @@ -52,5 +80,6 @@ fn test_stuff() raises: fn main() raises: - test_stuff() + # test_stuff() # test_listener() + test_dial()