diff --git a/gojo/bufio/bufio.mojo b/gojo/bufio/bufio.mojo index 95386fa..186366b 100644 --- a/gojo/bufio/bufio.mojo +++ b/gojo/bufio/bufio.mojo @@ -603,7 +603,7 @@ struct Reader[R: io.Reader]( var result = writer.write(self.buf[self.read_pos : self.write_pos]) if result.error: return Result(Int64(result.value), result.error) - + var bytes_written = result.value if bytes_written < 0: panic(ERR_NEGATIVE_WRITE) @@ -885,13 +885,13 @@ struct Writer[W: io.Writer]( while nr < MAX_CONSECUTIVE_EMPTY_READS: # TODO: should really be using a slice that returns refs and not a copy. # Read into remaining unused space in the buffer. We need to reserve capacity for the slice otherwise read will never hit EOF. - var sl = self.buf[self.bytes_written:len(self.buf)] + var sl = self.buf[self.bytes_written : len(self.buf)] sl.reserve(self.buf.capacity) var result = reader.read(sl) bytes_read = result.value err = result.get_error() _ = copy(self.buf, sl, self.bytes_written) - + if bytes_read != 0 or err: break nr += 1 @@ -910,7 +910,7 @@ struct Writer[W: io.Writer]( err = self.flush() else: err = None - + return Result(total_bytes_written, None) diff --git a/gojo/builtins/bytes.mojo b/gojo/builtins/bytes.mojo index 2d72ee4..0504d16 100644 --- a/gojo/builtins/bytes.mojo +++ b/gojo/builtins/bytes.mojo @@ -35,21 +35,21 @@ fn has_suffix(bytes: List[Byte], suffix: List[Byte]) -> Bool: fn index_byte(bytes: List[Byte], delim: Byte) -> Int: - """Return the index of the first occurrence of the byte delim. + """Return the index of the first occurrence of the byte delim. - Args: - bytes: The List[Byte] struct to search. - delim: The byte to search for. + Args: + bytes: The List[Byte] struct to search. + delim: The byte to search for. - Returns: - The index of the first occurrence of the byte delim. - """ - var i = 0 - for i in range(len(bytes)): - if bytes[i] == delim: - return i + Returns: + The index of the first occurrence of the byte delim. + """ + var i = 0 + for i in range(len(bytes)): + if bytes[i] == delim: + return i - return -1 + return -1 fn to_string(bytes: List[Byte]) -> String: diff --git a/gojo/builtins/list.mojo b/gojo/builtins/list.mojo index ddbfdbb..cb32504 100644 --- a/gojo/builtins/list.mojo +++ b/gojo/builtins/list.mojo @@ -130,4 +130,4 @@ fn equals(left: List[Bool], right: List[Bool]) -> Bool: for i in range(len(left)): if left[i] != right[i]: return False - return True \ No newline at end of file + return True diff --git a/gojo/bytes/reader.mojo b/gojo/bytes/reader.mojo index 7fe1945..c337ccc 100644 --- a/gojo/bytes/reader.mojo +++ b/gojo/bytes/reader.mojo @@ -216,4 +216,3 @@ fn new_reader(buffer: String) -> Reader: """ return Reader(buffer.as_bytes(), 0, -1) - diff --git a/gojo/io/traits.mojo b/gojo/io/traits.mojo index 7b9f4d6..bac52f6 100644 --- a/gojo/io/traits.mojo +++ b/gojo/io/traits.mojo @@ -106,7 +106,7 @@ trait Closer(Movable): Specific implementations may document their own behavior. """ - fn close(inout self) raises: + fn close(inout self) -> Optional[WrappedError]: ... diff --git a/gojo/net/__init__.mojo b/gojo/net/__init__.mojo new file mode 100644 index 0000000..2587673 --- /dev/null +++ b/gojo/net/__init__.mojo @@ -0,0 +1,4 @@ +"""Adapted from go's net package + +A good chunk of the leg work here came from the lightbug_http project! https://github.com/saviorand/lightbug_http/tree/main +""" diff --git a/gojo/net/address.mojo b/gojo/net/address.mojo new file mode 100644 index 0000000..3ba6d70 --- /dev/null +++ b/gojo/net/address.mojo @@ -0,0 +1,148 @@ +@value +struct NetworkType: + var value: String + + alias empty = NetworkType("") + alias tcp = NetworkType("tcp") + alias tcp4 = NetworkType("tcp4") + alias tcp6 = NetworkType("tcp6") + alias udp = NetworkType("udp") + alias udp4 = NetworkType("udp4") + alias udp6 = NetworkType("udp6") + alias ip = NetworkType("ip") + alias ip4 = NetworkType("ip4") + alias ip6 = NetworkType("ip6") + alias unix = NetworkType("unix") + + +trait Addr(CollectionElement, Stringable): + fn network(self) -> String: + """Name of the network (for example, "tcp", "udp").""" + ... + + +@value +struct TCPAddr(Addr): + """Addr struct representing a TCP address. + + Args: + ip: IP address. + port: Port number. + zone: IPv6 addressing zone. + """ + var ip: String + var port: Int + var zone: String # IPv6 addressing zone + + fn __init__(inout self): + self.ip = String("127.0.0.1") + self.port = 8000 + self.zone = "" + + fn __init__(inout self, ip: String, port: Int): + self.ip = ip + self.port = port + self.zone = "" + + fn __str__(self) -> String: + if self.zone != "": + return join_host_port(String(self.ip) + "%" + self.zone, self.port) + return join_host_port(self.ip, self.port) + + fn network(self) -> String: + return NetworkType.tcp.value + + +fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: + var host: String = "" + var port: String = "" + var portnum: Int = 0 + if ( + network == NetworkType.tcp.value + or network == NetworkType.tcp4.value + or network == NetworkType.tcp6.value + or network == NetworkType.udp.value + or network == NetworkType.udp4.value + or network == NetworkType.udp6.value + ): + if address != "": + var host_port = split_host_port(address) + host = host_port.host + port = host_port.port + portnum = atol(port.__str__()) + elif ( + network == NetworkType.ip.value + or network == NetworkType.ip4.value + or network == NetworkType.ip6.value + ): + if address != "": + host = address + elif network == NetworkType.unix.value: + raise Error("Unix addresses not supported yet") + else: + raise Error("unsupported network type: " + network) + return TCPAddr(host, portnum) + + +fn join_host_port(host: String, port: String) -> String: + if host.find(":") != -1: # must be IPv6 literal + return "[" + host + "]:" + port + return host + ":" + port + + +alias missingPortError = Error("missing port in address") +alias tooManyColonsError = Error("too many colons in address") + + +struct HostPort(Stringable): + var host: String + var port: Int + + fn __init__(inout self, host: String, port: Int): + self.host = host + self.port = port + + fn __str__(self) -> String: + return join_host_port(self.host, str(self.port)) + + +fn split_host_port(hostport: String) raises -> HostPort: + var host: String = "" + var port: String = "" + var colon_index = hostport.rfind(":") + var j: Int = 0 + var k: Int = 0 + + if colon_index == -1: + raise missingPortError + if hostport[0] == "[": + var end_bracket_index = hostport.find("]") + if end_bracket_index == -1: + raise Error("missing ']' in address") + if end_bracket_index + 1 == len(hostport): + raise missingPortError + elif end_bracket_index + 1 == colon_index: + host = hostport[1:end_bracket_index] + j = 1 + k = end_bracket_index + 1 + else: + if hostport[end_bracket_index + 1] == ":": + raise tooManyColonsError + else: + raise missingPortError + else: + host = hostport[:colon_index] + if host.find(":") != -1: + raise tooManyColonsError + if hostport[j:].find("[") != -1: + raise Error("unexpected '[' in address") + if hostport[k:].find("]") != -1: + raise Error("unexpected ']' in address") + port = hostport[colon_index + 1 :] + + if port == "": + raise missingPortError + if host == "": + raise Error("missing host") + + return HostPort(host, atol(port)) diff --git a/gojo/net/fd.mojo b/gojo/net/fd.mojo new file mode 100644 index 0000000..4780431 --- /dev/null +++ b/gojo/net/fd.mojo @@ -0,0 +1,111 @@ +from collections.optional import Optional +import ..io +from ..builtins import Byte, Result, WrappedError +from ..syscall.file import close +from ..syscall.types import ( + c_void, + c_uint, + c_char, + c_int, +) +from ..syscall.net import ( + sockaddr, + sockaddr_in, + addrinfo, + addrinfo_unix, + socklen_t, + socket, + connect, + recv, + send, + shutdown, + inet_pton, + inet_ntoa, + inet_ntop, + to_char_ptr, + htons, + ntohs, + strlen, + getaddrinfo, + getaddrinfo_unix, + gai_strerror, + c_charptr_to_string, + bind, + listen, + accept, + setsockopt, + getsockopt, + getsockname, + getpeername, + c_charptr_to_string, + AF_INET, + SOCK_STREAM, + SHUT_RDWR, + AI_PASSIVE, + SOL_SOCKET, + SO_REUSEADDR, + SO_RCVTIMEO, +) +from external.libc import Str, c_ssize_t, c_size_t, char_pointer + +alias O_RDWR = 0o2 + + +trait FileDescriptorBase(io.Reader, io.Writer, io.Closer): + ... + + +@value +struct FileDescriptor(FileDescriptorBase): + var fd: Int + + # This takes ownership of a POSIX file descriptor. + fn __moveinit__(inout self, owned existing: Self): + self.fd = existing.fd + + fn __init__(inout self, fd: Int): + self.fd = fd + + fn __del__(owned self): + var err = self.close() + if err: + print(err.value()) + + fn close(inout self) -> Optional[WrappedError]: + """Mark the file descriptor as closed.""" + var close_status = close(self.fd) + if close_status == -1: + return WrappedError("FileDescriptor.close: Failed to close socket") + + return None + + fn dup(self) -> Self: + """Duplicate the file descriptor.""" + var new_fd = external_call["dup", Int, Int](self.fd) + return Self(new_fd) + + fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + """Receive data from the file descriptor and write it to the buffer provided.""" + var ptr = Pointer[UInt8]().alloc(dest.capacity) + var bytes_received = recv(self.fd, ptr, dest.capacity, 0) + if bytes_received == -1: + return Result(0, WrappedError("Failed to receive message from socket.")) + + var int8_ptr = ptr.bitcast[Int8]() + for i in range(bytes_received): + dest.append(int8_ptr[i]) + + if bytes_received < dest.capacity: + return Result(bytes_received, WrappedError(io.EOF)) + + return bytes_received + + fn write(inout self, src: List[Byte]) -> Result[Int]: + """Write data from the buffer to the file descriptor.""" + var header_pointer = Pointer[Int8](src.data.value).bitcast[UInt8]() + + var bytes_sent = send(self.fd, header_pointer, strlen(header_pointer), 0) + if bytes_sent == -1: + return Result(0, WrappedError("Failed to send message")) + + return bytes_sent diff --git a/gojo/net/ip.mojo b/gojo/net/ip.mojo new file mode 100644 index 0000000..ba8a7b2 --- /dev/null +++ b/gojo/net/ip.mojo @@ -0,0 +1,186 @@ +from utils.variant import Variant +from sys.info import os_is_linux, os_is_macos +from ..syscall.types import ( + c_int, + c_char, + c_void, + c_uint, +) +from ..syscall.net import ( + addrinfo, + addrinfo_unix, + AF_INET, + SOCK_STREAM, + AI_PASSIVE, + sockaddr, + sockaddr_in, + htons, + ntohs, + inet_pton, + inet_ntop, + getaddrinfo, + getaddrinfo_unix, + gai_strerror, + to_char_ptr, + c_charptr_to_string +) + +alias AddrInfo = Variant[addrinfo, addrinfo_unix] + +fn get_addr_info(host: String) raises -> AddrInfo: + var status: Int32 = 0 + if os_is_macos(): + var servinfo = Pointer[addrinfo]().alloc(1) + servinfo.store(addrinfo()) + var hints = addrinfo() + hints.ai_family = AF_INET + hints.ai_socktype = SOCK_STREAM + hints.ai_flags = AI_PASSIVE + + var host_ptr = to_char_ptr(host) + + var status = getaddrinfo( + host_ptr, + Pointer[UInt8](), + Pointer.address_of(hints), + Pointer.address_of(servinfo), + ) + if status != 0: + print("getaddrinfo failed to execute with status:", status) + var msg_ptr = gai_strerror(c_int(status)) + _ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]]( + to_char_ptr("gai_strerror: %s"), msg_ptr + ) + var msg = c_charptr_to_string(msg_ptr) + print("getaddrinfo error message: ", msg) + + if not servinfo: + print("servinfo is null") + raise Error("Failed to get address info. Pointer to addrinfo is null.") + + return servinfo.load() + elif os_is_linux(): + var servinfo = Pointer[addrinfo_unix]().alloc(1) + servinfo.store(addrinfo_unix()) + var hints = addrinfo_unix() + hints.ai_family = AF_INET + hints.ai_socktype = SOCK_STREAM + hints.ai_flags = AI_PASSIVE + + var host_ptr = to_char_ptr(host) + + var status = getaddrinfo_unix( + host_ptr, + Pointer[UInt8](), + Pointer.address_of(hints), + Pointer.address_of(servinfo), + ) + if status != 0: + print("getaddrinfo failed to execute with status:", status) + var msg_ptr = gai_strerror(c_int(status)) + _ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]]( + to_char_ptr("gai_strerror: %s"), msg_ptr + ) + var msg = c_charptr_to_string(msg_ptr) + print("getaddrinfo error message: ", msg) + + if not servinfo: + print("servinfo is null") + raise Error("Failed to get address info. Pointer to addrinfo is null.") + + return servinfo.load() + else: + raise Error("Windows is not supported yet! Sorry!") + + +fn get_ip_address(host: String) raises -> String: + """Get the IP address of a host.""" + # Call getaddrinfo to get the IP address of the host. + var result = get_addr_info(host) + var ai_addr: Pointer[sockaddr] + var address_family: Int32 = 0 + var address_length: UInt32 = 0 + if result.isa[addrinfo](): + var addrinfo = result.get[addrinfo]() + ai_addr = addrinfo[].ai_addr + address_family = addrinfo[].ai_family + address_length = addrinfo[].ai_addrlen + else: + var addrinfo = result.get[addrinfo_unix]() + ai_addr = addrinfo[].ai_addr + address_family = addrinfo[].ai_family + address_length = addrinfo[].ai_addrlen + + if not ai_addr: + print("ai_addr is null") + raise Error( + "Failed to get IP address. getaddrinfo was called successfully, but ai_addr" + " is null." + ) + + # Cast sockaddr struct to sockaddr_in struct and convert the binary IP to a string using inet_ntop. + var addr_in = ai_addr.bitcast[sockaddr_in]().load() + + return convert_binary_ip_to_string( + addr_in.sin_addr.s_addr, address_family, address_length + ).strip() + + +fn convert_port_to_binary(port: Int) -> UInt16: + return htons(UInt16(port)) + + +fn convert_binary_port_to_int(port: UInt16) -> Int: + return int(ntohs(port)) + + +fn convert_ip_to_binary(ip_address: String, address_family: Int) -> UInt32: + var ip_buffer = Pointer[c_void].alloc(4) + var status = inet_pton(address_family, to_char_ptr(ip_address), ip_buffer) + if status == -1: + print("Failed to convert IP address to binary") + + return ip_buffer.bitcast[c_uint]().load() + + +fn convert_binary_ip_to_string( + owned ip_address: UInt32, address_family: Int32, address_length: UInt32 +) -> String: + """Convert a binary IP address to a string by calling inet_ntop. + + Args: + ip_address: The binary IP address. + address_family: The address family of the IP address. + address_length: The length of the address. + + Returns: + The IP address as a string. + """ + # It seems like the len of the buffer depends on the length of the string IP. + # Allocating 10 works for localhost (127.0.0.1) which I suspect is 9 bytes + 1 null terminator byte. So max should be 16 (15 + 1). + var ip_buffer = Pointer[c_void].alloc(16) + var ip_address_ptr = Pointer.address_of(ip_address).bitcast[c_void]() + _ = inet_ntop(address_family, ip_address_ptr, ip_buffer, 16) + + var string_buf = ip_buffer.bitcast[Int8]() + var index = 0 + while True: + if string_buf[index] == 0: + break + index += 1 + + return StringRef(string_buf, index) + + +fn build_sockaddr_pointer( + ip_address: String, port: Int, address_family: Int +) -> Pointer[sockaddr]: + """Build a sockaddr pointer from an IP address and port number. + https://learn.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 + https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-sockaddr_in. + """ + var bin_port = convert_port_to_binary(port) + var bin_ip = convert_ip_to_binary(ip_address, address_family) + + var ai = sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8]()) + return Pointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() diff --git a/gojo/net/net.mojo b/gojo/net/net.mojo new file mode 100644 index 0000000..5e4d8d5 --- /dev/null +++ b/gojo/net/net.mojo @@ -0,0 +1,108 @@ +from collections.optional import Optional +from memory._arc import Arc +import ..io +from ..builtins import Byte, Result, WrappedError +from .socket import Socket +from .address import Addr, TCPAddr + + +# Time in nanoseconds +alias Duration = Int +alias DEFAULT_BUFFER_SIZE = 4096 +alias DEFAULT_TCP_KEEP_ALIVE = Duration(15 * 1000 * 1000 * 1000) # 15 seconds + + +trait Listener(Movable): + fn accept(borrowed self) raises -> Connection: + ... + + fn close(self) -> Optional[WrappedError]: + ... + + fn addr(self) -> Addr: + ... + + +trait Conn(io.Writer, io.Reader, io.Closer): + fn local_address(self) -> TCPAddr: + """Returns the local network address, if known.""" + ... + + fn remote_address(self) -> TCPAddr: + """Returns the local network address, if known.""" + ... + + # fn set_deadline(self, t: time.Time) -> Error: + # """Sets the read and write deadlines associated + # with the connection. It is equivalent to calling both + # SetReadDeadline and SetWriteDeadline. + + # A deadline is an absolute time after which I/O operations + # fail instead of blocking. The deadline applies to all future + # and pending I/O, not just the immediately following call to + # read or write. After a deadline has been exceeded, the + # connection can be refreshed by setting a deadline in the future. + + # If the deadline is exceeded a call to read or write or to other + # I/O methods will return an error that wraps os.ErrDeadlineExceeded. + # This can be tested using errors.Is(err, os.ErrDeadlineExceeded). + # The error's Timeout method will return true, but note that there + # are other possible errors for which the Timeout method will + # return true even if the deadline has not been exceeded. + + # An idle timeout can be implemented by repeatedly extending + # the deadline after successful read or write calls. + + # A zero value for t means I/O operations will not time out.""" + # ... + + # fn set_read_deadline(self, t: time.Time) -> Error: + # """Sets the deadline for future read calls + # and any currently-blocked read call. + # A zero value for t means read will not time out.""" + # ... + + # fn set_write_deadline(self, t: time.Time) -> Error: + # """Sets the deadline for future write calls + # and any currently-blocked write call. + # Even if write times out, it may return n > 0, indicating that + # some of the data was successfully written. + # A zero value for t means write will not time out.""" + # ... + + +@value +struct Connection(Conn): + var fd: Arc[Socket] + + fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + var result = self.fd[].read(dest) + if result.error: + if str(result.unwrap_error()) != io.EOF: + return Result[Int](0, result.unwrap_error()) + + return result.value + + fn write(inout self, src: List[Byte]) -> Result[Int]: + var result = self.fd[].write(src) + if result.error: + return Result[Int](0, result.unwrap_error()) + + return result.value + + fn close(inout self) -> Optional[WrappedError]: + var err = self.fd[].close() + if err: + return err.value() + + return None + + fn local_address(self) -> TCPAddr: + """Returns the local network address. + The Addr returned is shared by all invocations of local_address, so do not modify it.""" + return self.fd[].local_address + + fn remote_address(self) -> TCPAddr: + """Returns the remote network address. + The Addr returned is shared by all invocations of remote_address, so do not modify it.""" + return self.fd[].remote_address diff --git a/gojo/net/socket.mojo b/gojo/net/socket.mojo new file mode 100644 index 0000000..1288752 --- /dev/null +++ b/gojo/net/socket.mojo @@ -0,0 +1,419 @@ +from collections.optional import Optional +from ..builtins import Byte, Result, WrappedError +from ..syscall.file import close +from ..syscall.types import ( + c_void, + c_uint, + c_char, + c_int, +) +from ..syscall.net import ( + sockaddr, + sockaddr_in, + addrinfo, + addrinfo_unix, + socklen_t, + socket, + connect, + recv, + send, + shutdown, + inet_pton, + inet_ntoa, + inet_ntop, + to_char_ptr, + htons, + ntohs, + strlen, + getaddrinfo, + getaddrinfo_unix, + gai_strerror, + c_charptr_to_string, + bind, + listen, + accept, + setsockopt, + getsockopt, + getsockname, + getpeername, + AF_INET, + SOCK_STREAM, + SHUT_RDWR, + AI_PASSIVE, + SOL_SOCKET, + SO_REUSEADDR, + SO_RCVTIMEO, +) +from .fd import FileDescriptor, FileDescriptorBase +from .ip import convert_binary_ip_to_string, build_sockaddr_pointer, convert_binary_port_to_int +from .address import Addr, TCPAddr, HostPort + +alias SocketClosedError = Error("Socket: Socket is already closed") + + +@value +struct Socket(FileDescriptorBase): + """Represents a network file descriptor. Wraps around a file descriptor and provides network functions. + + Args: + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + """ + var sockfd: FileDescriptor + var address_family: Int + var socket_type: UInt8 + var protocol: UInt8 + var local_address: TCPAddr + var remote_address: TCPAddr + var _closed: Bool + var _is_connected: Bool + + fn __init__( + inout self, + local_address: TCPAddr = TCPAddr(), + remote_address: TCPAddr = TCPAddr(), + address_family: Int = AF_INET, + socket_type: UInt8 = SOCK_STREAM, + protocol: UInt8 = 0, + ) raises: + """Create a new socket object. + + Args: + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + """ + self.address_family = address_family + self.socket_type = socket_type + self.protocol = protocol + + var fd = socket(address_family, SOCK_STREAM, 0) + if fd == -1: + raise Error("Socket creation error") + self.sockfd = FileDescriptor(int(fd)) + self.local_address = local_address + self.remote_address = remote_address + self._closed = False + self._is_connected = False + + fn __init__( + inout self, + fd: Int32, + address_family: Int, + socket_type: UInt8, + protocol: UInt8, + local_address: TCPAddr = TCPAddr(), + remote_address: TCPAddr = TCPAddr(), + ): + """ + Create a new socket object when you already have a socket file descriptor. Typically through socket.accept(). + + Args: + fd: The file descriptor of the socket. + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + local_address: Local address of socket. + remote_address: Remote address of port. + """ + self.sockfd = FileDescriptor(int(fd)) + self.address_family = address_family + self.socket_type = socket_type + self.protocol = protocol + self.local_address = local_address + self.remote_address = remote_address + self._closed = False + self._is_connected = True + + fn __enter__(self) -> Self: + return self + + # fn __exit__(inout self) raises: + # if self._is_connected: + # self.shutdown() + # if not self._closed: + # self.close() + + fn __del__(owned self): + if self._is_connected: + self.shutdown() + if not self._closed: + var err = self.close() + if err: + print("Failed to close socket during deletion:", err.value()) + + @always_inline + fn accept(self) raises -> Self: + """Accept a connection. The socket must be bound to an address and listening for connections. + The return value is a connection where conn is a new socket object usable to send and receive data on the connection, + and address is the address bound to the socket on the other end of the connection. + """ + var their_addr_ptr = Pointer[sockaddr].alloc(1) + var sin_size = socklen_t(sizeof[socklen_t]()) + var new_sockfd = accept( + self.sockfd.fd, their_addr_ptr, Pointer[socklen_t].address_of(sin_size) + ) + if new_sockfd == -1: + raise Error("Failed to accept connection") + + var remote = self.get_peer_name() + return Self(new_sockfd, self.address_family, self.socket_type, self.protocol, self.local_address, TCPAddr(remote.host, remote.port)) + + fn listen(self, backlog: Int = 0) raises: + """Enable a server to accept connections. + + Args: + backlog: The maximum number of queued connections. Should be at least 0, and the maximum is system-dependent (usually 5). + """ + var queued = backlog + if backlog < 0: + queued = 0 + if listen(self.sockfd.fd, queued) == -1: + raise Error("Failed to listen for connections") + + @always_inline + fn bind(inout self, address: String, port: Int) raises: + """Bind the socket to address. The socket must not already be bound. (The format of address depends on the address family). + + When a socket is created with Socket(), it exists in a name + space (address family) but has no address assigned to it. bind() + assigns the address specified by addr to the socket referred to + by the file descriptor sockfd. addrlen specifies the size, in + bytes, of the address structure pointed to by addr. + Traditionally, this operation is called 'assigning a name to a + socket'. + + Args: + address: String - The IP address to bind the socket to. + port: The port number to bind the socket to. + """ + var sockaddr_pointer = build_sockaddr_pointer( + address, port, self.address_family + ) + + if bind(self.sockfd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: + _ = shutdown(self.sockfd.fd, SHUT_RDWR) + raise Error("Binding socket failed. Wait a few seconds and try again?") + + var local = self.get_sock_name() + self.local_address = TCPAddr(local.host, local.port) + + @always_inline + fn file_no(self) -> Int32: + """Return the file descriptor of the socket.""" + return self.sockfd.fd + + @always_inline + fn get_sock_name(self) raises -> HostPort: + """Return the address of the socket.""" + if self._closed: + raise SocketClosedError + + # TODO: Add check to see if the socket is bound and error if not. + + var local_address_ptr = Pointer[sockaddr].alloc(1) + var local_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var status = getsockname( + self.sockfd.fd, + local_address_ptr, + Pointer[socklen_t].address_of(local_address_ptr_size), + ) + if status == -1: + raise Error("Socket.get_sock_name: Failed to get address of local socket.") + var addr_in = local_address_ptr.bitcast[sockaddr_in]().load() + + return HostPort( + host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16), + port=convert_binary_port_to_int(addr_in.sin_port) + ) + + fn get_peer_name(self) raises -> HostPort: + """Return the address of the peer connected to the socket.""" + if self._closed: + raise SocketClosedError + + # TODO: Add check to see if the socket is bound and error if not. + var remote_address_ptr = Pointer[sockaddr].alloc(1) + var remote_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var status = getpeername( + self.sockfd.fd, + remote_address_ptr, + Pointer[socklen_t].address_of(remote_address_ptr_size), + ) + if status == -1: + raise Error("Socket.get_peer_name: Failed to get address of remote socket.") + + # Cast sockaddr struct to sockaddr_in to convert binary IP to string. + var addr_in = remote_address_ptr.bitcast[sockaddr_in]().load() + + return HostPort( + host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16), + port=convert_binary_port_to_int(addr_in.sin_port) + ) + + fn get_socket_option(self, option_name: Int) raises -> Int: + """Return the value of the given socket option. + + Args: + option_name: The socket option to get. + """ + var option_value_pointer = Pointer[c_void].alloc(1) + var option_len = socklen_t(sizeof[socklen_t]()) + var option_len_pointer = Pointer.address_of(option_len) + var status = getsockopt( + self.sockfd.fd, + SOL_SOCKET, + option_name, + option_value_pointer, + option_len_pointer, + ) + if status == -1: + raise Error("Socket.get_sock_opt failed with status: " + str(status)) + + return option_value_pointer.bitcast[Int]().load() + + fn set_socket_option(self, option_name: Int, owned option_value: UInt8 = 1) raises: + """Return the value of the given socket option. + + Args: + option_name: The socket option to set. + option_value: The value to set the socket option to. + """ + var option_value_pointer = Pointer[c_void].address_of(option_value) + var option_len = sizeof[socklen_t]() + var status = setsockopt( + self.sockfd.fd, SOL_SOCKET, option_name, option_value_pointer, option_len + ) + if status == -1: + raise Error("Socket.set_sock_opt failed with status: " + str(status)) + + fn connect(inout self, address: String, port: Int) raises: + """Connect to a remote socket at address. + + Args: + address: String - The IP address to connect to. + port: The port number to connect to. + """ + var sockaddr_pointer = build_sockaddr_pointer( + address, port, self.address_family + ) + + if connect(self.sockfd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: + self.shutdown() + raise Error("Socket.connect: Failed to connect to the remote socket at: " + address + ":" + str(port)) + + var remote = self.get_peer_name() + self.remote_address = TCPAddr(remote.host, remote.port) + + fn write(inout self: Self, src: List[Byte]) -> Result[Int]: + """Send data to the socket. The socket must be connected to a remote socket. + + Args: + src: The data to send. + + Returns: + The number of bytes sent. + """ + var result = self.sockfd.write(src) + if result.error: + return Result(0, result.unwrap_error()) + + return result.value + + fn send_all(self, src: List[Byte], max_attempts: Int = 3) raises: + """Send data to the socket. The socket must be connected to a remote socket. + + Args: + src: The data to send. + max_attempts: The maximum number of attempts to send the data. + """ + var header_pointer = Pointer[Int8](src.data.value).bitcast[UInt8]() + var total_bytes_sent = 0 + var attempts = 0 + + # Try to send all the data in the buffer. If it did not send all the data, keep trying but start from the offset of the last successful send. + while total_bytes_sent < len(src): + if attempts > max_attempts: + raise Error( + "Failed to send message after " + + String(max_attempts) + + " attempts." + ) + + var bytes_sent = send( + self.sockfd.fd, + header_pointer.offset(total_bytes_sent), + strlen(header_pointer.offset(total_bytes_sent)), + 0, + ) + if bytes_sent == -1: + raise Error( + "Failed to send message, wrote" + + String(total_bytes_sent) + + "bytes before failing." + ) + total_bytes_sent += bytes_sent + attempts += 1 + + fn send_to(inout self, src: List[Byte], address: String, port: Int) raises -> Int: + """Send data to the a remote address by connecting to the remote socket before sending. + The socket must be not already be connected to a remote socket. + + Args: + src: The data to send. + address: The IP address to connect to. + port: The port number to connect to. + """ + var header_pointer = Pointer[Int8](src.data.value).bitcast[UInt8]() + self.connect(address, port) + var result = self.write(src) + if result.error: + raise result.unwrap_error().error + return result.value + + fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + """Receive data from the socket.""" + # Not ideal since we can't use the pointer from the List[Byte] struct directly. So we use a temporary pointer to receive the data. + # Then we copy all the data over. + var result = self.sockfd.read(dest) + if result.error: + return Result(0, result.unwrap_error()) + + return result.value + + fn shutdown(self): + _ = shutdown(self.sockfd.fd, SHUT_RDWR) + + fn close(inout self) -> Optional[WrappedError]: + """Mark the socket closed. + Once that happens, all future operations on the socket object will fail. + The remote end will receive no more data (after queued data is flushed). + """ + self.shutdown() + var err = self.sockfd.close() + if err: + return err.value() + + self._closed = True + return None + + # TODO: Trying to set timeout fails, but some other options don't? + # fn get_timeout(self) raises -> Seconds: + # """Return the timeout value for the socket.""" + # return self.get_socket_option(SO_RCVTIMEO) + + # fn set_timeout(self, owned duration: Seconds) raises: + # """Set the timeout value for the socket. + + # Args: + # duration: Seconds - The timeout duration in seconds. + # """ + # self.set_socket_option(SO_RCVTIMEO, duration) + + fn send_file(self, file: FileHandle, offset: Int = 0) raises: + self.send_all(file.read_bytes()) diff --git a/gojo/syscall/__init__.mojo b/gojo/syscall/__init__.mojo new file mode 100644 index 0000000..e69de29 diff --git a/gojo/syscall/file.mojo b/gojo/syscall/file.mojo new file mode 100644 index 0000000..aa2fb4e --- /dev/null +++ b/gojo/syscall/file.mojo @@ -0,0 +1,122 @@ +from .types import c_int, c_char, c_void, c_size_t, c_ssize_t + + +# --- ( File Related Syscalls & Structs )--------------------------------------- +alias O_NONBLOCK = 16384 +alias O_ACCMODE = 3 +alias O_CLOEXEC = 524288 + + +fn close(fildes: c_int) -> c_int: + """Libc POSIX `close` function + Reference: https://man7.org/linux/man-pages/man3/close.3p.html + Fn signature: int close(int fildes). + + Args: + fildes: A File Descriptor to close. + + Returns: + Upon successful completion, 0 shall be returned; otherwise, -1 + shall be returned and errno set to indicate the error. + """ + return external_call["close", c_int, c_int](fildes) + + +fn open[*T: AnyType](path: Pointer[c_char], oflag: c_int, *args: *T) -> c_int: + """Libc POSIX `open` function + Reference: https://man7.org/linux/man-pages/man3/open.3p.html + Fn signature: int open(const char *path, int oflag, ...). + + Args: + path: A pointer to a C string containing the path to open. + oflag: The flags to open the file with. + args: The optional arguments. + Returns: + A File Descriptor or -1 in case of failure + """ + return external_call[ + "open", c_int, Pointer[c_char], c_int # FnName, RetType # Args + ](path, oflag, args) + + +fn openat[ + *T: AnyType +](fd: c_int, path: Pointer[c_char], oflag: c_int, *args: *T) -> c_int: + """Libc POSIX `open` function + Reference: https://man7.org/linux/man-pages/man3/open.3p.html + Fn signature: int openat(int fd, const char *path, int oflag, ...). + + Args: + fd: A File Descriptor. + path: A pointer to a C string containing the path to open. + oflag: The flags to open the file with. + args: The optional arguments. + Returns: + A File Descriptor or -1 in case of failure + """ + return external_call[ + "openat", c_int, c_int, Pointer[c_char], c_int # FnName, RetType # Args + ](fd, path, oflag, args) + + +fn printf[*T: AnyType](format: Pointer[c_char], *args: *T) -> c_int: + """Libc POSIX `printf` function + Reference: https://man7.org/linux/man-pages/man3/fprintf.3p.html + Fn signature: int printf(const char *restrict format, ...). + + Args: format: A pointer to a C string containing the format. + args: The optional arguments. + Returns: The number of bytes written or -1 in case of failure. + """ + return external_call[ + "printf", + c_int, # FnName, RetType + Pointer[c_char], # Args + ](format, args) + + +fn sprintf[ + *T: AnyType +](s: Pointer[c_char], format: Pointer[c_char], *args: *T) -> c_int: + """Libc POSIX `sprintf` function + Reference: https://man7.org/linux/man-pages/man3/fprintf.3p.html + Fn signature: int sprintf(char *restrict s, const char *restrict format, ...). + + Args: s: A pointer to a buffer to store the result. + format: A pointer to a C string containing the format. + args: The optional arguments. + Returns: The number of bytes written or -1 in case of failure. + """ + return external_call[ + "sprintf", c_int, Pointer[c_char], Pointer[c_char] # FnName, RetType # Args + ](s, format, args) + + +fn read(fildes: c_int, buf: Pointer[c_void], nbyte: c_size_t) -> c_int: + """Libc POSIX `read` function + Reference: https://man7.org/linux/man-pages/man3/read.3p.html + Fn signature: sssize_t read(int fildes, void *buf, size_t nbyte). + + Args: fildes: A File Descriptor. + buf: A pointer to a buffer to store the read data. + nbyte: The number of bytes to read. + Returns: The number of bytes read or -1 in case of failure. + """ + return external_call["read", c_ssize_t, c_int, Pointer[c_void], c_size_t]( + fildes, buf, nbyte + ) + + +fn write(fildes: c_int, buf: Pointer[c_void], nbyte: c_size_t) -> c_int: + """Libc POSIX `write` function + Reference: https://man7.org/linux/man-pages/man3/write.3p.html + Fn signature: ssize_t write(int fildes, const void *buf, size_t nbyte). + + Args: fildes: A File Descriptor. + buf: A pointer to a buffer to write. + nbyte: The number of bytes to write. + Returns: The number of bytes written or -1 in case of failure. + """ + return external_call["write", c_ssize_t, c_int, Pointer[c_void], c_size_t]( + fildes, buf, nbyte + ) diff --git a/gojo/syscall/net.mojo b/gojo/syscall/net.mojo new file mode 100644 index 0000000..00df707 --- /dev/null +++ b/gojo/syscall/net.mojo @@ -0,0 +1,773 @@ +from .types import c_char, c_int, c_ushort, c_uint, c_void, c_size_t, c_ssize_t, strlen +from .file import O_CLOEXEC, O_NONBLOCK + +alias IPPROTO_IPV6 = 41 +alias IPV6_V6ONLY = 26 +alias EPROTONOSUPPORT = 93 + +# Adapted from https://github.com/gabrieldemarmiesse/mojo-stdlib-extensions/ . Huge thanks to Gabriel! + +alias FD_STDIN: c_int = 0 +alias FD_STDOUT: c_int = 1 +alias FD_STDERR: c_int = 2 + +alias SUCCESS = 0 +alias GRND_NONBLOCK: UInt8 = 1 + +alias char_pointer = AnyPointer[c_char] + + +# --- ( error.h Constants )----------------------------------------------------- +alias EPERM = 1 +alias ENOENT = 2 +alias ESRCH = 3 +alias EINTR = 4 +alias EIO = 5 +alias ENXIO = 6 +alias E2BIG = 7 +alias ENOEXEC = 8 +alias EBADF = 9 +alias ECHILD = 10 +alias EAGAIN = 11 +alias ENOMEM = 12 +alias EACCES = 13 +alias EFAULT = 14 +alias ENOTBLK = 15 +alias EBUSY = 16 +alias EEXIST = 17 +alias EXDEV = 18 +alias ENODEV = 19 +alias ENOTDIR = 20 +alias EISDIR = 21 +alias EINVAL = 22 +alias ENFILE = 23 +alias EMFILE = 24 +alias ENOTTY = 25 +alias ETXTBSY = 26 +alias EFBIG = 27 +alias ENOSPC = 28 +alias ESPIPE = 29 +alias EROFS = 30 +alias EMLINK = 31 +alias EPIPE = 32 +alias EDOM = 33 +alias ERANGE = 34 +alias EWOULDBLOCK = EAGAIN + + +fn to_char_ptr(s: String) -> Pointer[c_char]: + """Only ASCII-based strings.""" + var ptr = Pointer[c_char]().alloc(len(s)) + for i in range(len(s)): + ptr.store(i, ord(s[i])) + return ptr + + +fn c_charptr_to_string(s: Pointer[c_char]) -> String: + return String(s.bitcast[Int8](), strlen(s)) + + +fn cftob(val: c_int) -> Bool: + """Convert C-like failure (-1) to Bool.""" + return rebind[Bool](val > 0) + + +# --- ( Network Related Constants )--------------------------------------------- +alias sa_family_t = c_ushort +alias socklen_t = c_uint +alias in_addr_t = c_uint +alias in_port_t = c_ushort + +# Address Family Constants +alias AF_UNSPEC = 0 +alias AF_UNIX = 1 +alias AF_LOCAL = AF_UNIX +alias AF_INET = 2 +alias AF_AX25 = 3 +alias AF_IPX = 4 +alias AF_APPLETALK = 5 +alias AF_NETROM = 6 +alias AF_BRIDGE = 7 +alias AF_ATMPVC = 8 +alias AF_X25 = 9 +alias AF_INET6 = 10 +alias AF_ROSE = 11 +alias AF_DECnet = 12 +alias AF_NETBEUI = 13 +alias AF_SECURITY = 14 +alias AF_KEY = 15 +alias AF_NETLINK = 16 +alias AF_ROUTE = AF_NETLINK +alias AF_PACKET = 17 +alias AF_ASH = 18 +alias AF_ECONET = 19 +alias AF_ATMSVC = 20 +alias AF_RDS = 21 +alias AF_SNA = 22 +alias AF_IRDA = 23 +alias AF_PPPOX = 24 +alias AF_WANPIPE = 25 +alias AF_LLC = 26 +alias AF_CAN = 29 +alias AF_TIPC = 30 +alias AF_BLUETOOTH = 31 +alias AF_IUCV = 32 +alias AF_RXRPC = 33 +alias AF_ISDN = 34 +alias AF_PHONET = 35 +alias AF_IEEE802154 = 36 +alias AF_CAIF = 37 +alias AF_ALG = 38 +alias AF_NFC = 39 +alias AF_VSOCK = 40 +alias AF_KCM = 41 +alias AF_QIPCRTR = 42 +alias AF_MAX = 43 + +# Protocol family constants +alias PF_UNSPEC = AF_UNSPEC +alias PF_UNIX = AF_UNIX +alias PF_LOCAL = AF_LOCAL +alias PF_INET = AF_INET +alias PF_AX25 = AF_AX25 +alias PF_IPX = AF_IPX +alias PF_APPLETALK = AF_APPLETALK +alias PF_NETROM = AF_NETROM +alias PF_BRIDGE = AF_BRIDGE +alias PF_ATMPVC = AF_ATMPVC +alias PF_X25 = AF_X25 +alias PF_INET6 = AF_INET6 +alias PF_ROSE = AF_ROSE +alias PF_DECnet = AF_DECnet +alias PF_NETBEUI = AF_NETBEUI +alias PF_SECURITY = AF_SECURITY +alias PF_KEY = AF_KEY +alias PF_NETLINK = AF_NETLINK +alias PF_ROUTE = AF_ROUTE +alias PF_PACKET = AF_PACKET +alias PF_ASH = AF_ASH +alias PF_ECONET = AF_ECONET +alias PF_ATMSVC = AF_ATMSVC +alias PF_RDS = AF_RDS +alias PF_SNA = AF_SNA +alias PF_IRDA = AF_IRDA +alias PF_PPPOX = AF_PPPOX +alias PF_WANPIPE = AF_WANPIPE +alias PF_LLC = AF_LLC +alias PF_CAN = AF_CAN +alias PF_TIPC = AF_TIPC +alias PF_BLUETOOTH = AF_BLUETOOTH +alias PF_IUCV = AF_IUCV +alias PF_RXRPC = AF_RXRPC +alias PF_ISDN = AF_ISDN +alias PF_PHONET = AF_PHONET +alias PF_IEEE802154 = AF_IEEE802154 +alias PF_CAIF = AF_CAIF +alias PF_ALG = AF_ALG +alias PF_NFC = AF_NFC +alias PF_VSOCK = AF_VSOCK +alias PF_KCM = AF_KCM +alias PF_QIPCRTR = AF_QIPCRTR +alias PF_MAX = AF_MAX + +# Socket Type constants +alias SOCK_STREAM = 1 +alias SOCK_DGRAM = 2 +alias SOCK_RAW = 3 +alias SOCK_RDM = 4 +alias SOCK_SEQPACKET = 5 +alias SOCK_DCCP = 6 +alias SOCK_PACKET = 10 +alias SOCK_CLOEXEC = O_CLOEXEC +alias SOCK_NONBLOCK = O_NONBLOCK + +# Address Information +alias AI_PASSIVE = 1 +alias AI_CANONNAME = 2 +alias AI_NUMERICHOST = 4 +alias AI_V4MAPPED = 2048 +alias AI_ALL = 256 +alias AI_ADDRCONFIG = 1024 +alias AI_IDN = 64 + +alias INET_ADDRSTRLEN = 16 +alias INET6_ADDRSTRLEN = 46 + +alias SHUT_RD = 0 +alias SHUT_WR = 1 +alias SHUT_RDWR = 2 + +alias SOL_SOCKET = 65535 + +# Socket Options +alias SO_DEBUG = 1 +alias SO_REUSEADDR = 4 +alias SO_TYPE = 4104 +alias SO_ERROR = 4103 +alias SO_DONTROUTE = 16 +alias SO_BROADCAST = 32 +alias SO_SNDBUF = 4097 +alias SO_RCVBUF = 4098 +alias SO_KEEPALIVE = 8 +alias SO_OOBINLINE = 256 +alias SO_LINGER = 128 +alias SO_REUSEPORT = 512 +alias SO_RCVLOWAT = 4100 +alias SO_SNDLOWAT = 4099 +alias SO_RCVTIMEO = 4102 +alias SO_SNDTIMEO = 4101 +alias SO_RCVTIMEO_OLD = 4102 +alias SO_SNDTIMEO_OLD = 4101 +alias SO_ACCEPTCONN = 2 + +# unsure of these socket options, they weren't available via python +alias SO_NO_CHECK = 11 +alias SO_PRIORITY = 12 +alias SO_BSDCOMPAT = 14 +alias SO_PASSCRED = 16 +alias SO_PEERCRED = 17 +alias SO_SECURITY_AUTHENTICATION = 22 +alias SO_SECURITY_ENCRYPTION_TRANSPORT = 23 +alias SO_SECURITY_ENCRYPTION_NETWORK = 24 +alias SO_BINDTODEVICE = 25 +alias SO_ATTACH_FILTER = 26 +alias SO_DETACH_FILTER = 27 +alias SO_GET_FILTER = SO_ATTACH_FILTER +alias SO_PEERNAME = 28 +alias SO_TIMESTAMP = 29 +alias SO_TIMESTAMP_OLD = 29 +alias SO_PEERSEC = 31 +alias SO_SNDBUFFORCE = 32 +alias SO_RCVBUFFORCE = 33 +alias SO_PASSSEC = 34 +alias SO_TIMESTAMPNS = 35 +alias SO_TIMESTAMPNS_OLD = 35 +alias SO_MARK = 36 +alias SO_TIMESTAMPING = 37 +alias SO_TIMESTAMPING_OLD = 37 +alias SO_PROTOCOL = 38 +alias SO_DOMAIN = 39 +alias SO_RXQ_OVFL = 40 +alias SO_WIFI_STATUS = 41 +alias SCM_WIFI_STATUS = SO_WIFI_STATUS +alias SO_PEEK_OFF = 42 +alias SO_NOFCS = 43 +alias SO_LOCK_FILTER = 44 +alias SO_SELECT_ERR_QUEUE = 45 +alias SO_BUSY_POLL = 46 +alias SO_MAX_PACING_RATE = 47 +alias SO_BPF_EXTENSIONS = 48 +alias SO_INCOMING_CPU = 49 +alias SO_ATTACH_BPF = 50 +alias SO_DETACH_BPF = SO_DETACH_FILTER +alias SO_ATTACH_REUSEPORT_CBPF = 51 +alias SO_ATTACH_REUSEPORT_EBPF = 52 +alias SO_CNX_ADVICE = 53 +alias SCM_TIMESTAMPING_OPT_STATS = 54 +alias SO_MEMINFO = 55 +alias SO_INCOMING_NAPI_ID = 56 +alias SO_COOKIE = 57 +alias SCM_TIMESTAMPING_PKTINFO = 58 +alias SO_PEERGROUPS = 59 +alias SO_ZEROCOPY = 60 +alias SO_TXTIME = 61 +alias SCM_TXTIME = SO_TXTIME +alias SO_BINDTOIFINDEX = 62 +alias SO_TIMESTAMP_NEW = 63 +alias SO_TIMESTAMPNS_NEW = 64 +alias SO_TIMESTAMPING_NEW = 65 +alias SO_RCVTIMEO_NEW = 66 +alias SO_SNDTIMEO_NEW = 67 +alias SO_DETACH_REUSEPORT_BPF = 68 + + +# --- ( Network Related Structs )----------------------------------------------- +@value +@register_passable("trivial") +struct in_addr: + var s_addr: in_addr_t + + +@value +@register_passable("trivial") +struct in6_addr: + var s6_addr: StaticTuple[c_char, 16] + + +@value +@register_passable("trivial") +struct sockaddr: + var sa_family: sa_family_t + var sa_data: StaticTuple[c_char, 14] + + +@value +@register_passable("trivial") +struct sockaddr_in: + var sin_family: sa_family_t + var sin_port: in_port_t + var sin_addr: in_addr + var sin_zero: StaticTuple[c_char, 8] + + +@value +@register_passable("trivial") +struct sockaddr_in6: + var sin6_family: sa_family_t + var sin6_port: in_port_t + var sin6_flowinfo: c_uint + var sin6_addr: in6_addr + var sin6_scope_id: c_uint + + +@value +@register_passable("trivial") +struct addrinfo: + """Struct field ordering can vary based on platform. + For MacOS, I had to swap the order of ai_canonname and ai_addr. + https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer + """ + + var ai_flags: c_int + var ai_family: c_int + var ai_socktype: c_int + var ai_protocol: c_int + var ai_addrlen: socklen_t + var ai_canonname: Pointer[c_char] + var ai_addr: Pointer[sockaddr] + var ai_next: Pointer[addrinfo] + + fn __init__() -> Self: + return Self( + 0, 0, 0, 0, 0, Pointer[c_char](), Pointer[sockaddr](), Pointer[addrinfo]() + ) + + +@value +@register_passable("trivial") +struct addrinfo_unix: + """Struct field ordering can vary based on platform. + For MacOS, I had to swap the order of ai_canonname and ai_addr. + https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer. + """ + + var ai_flags: c_int + var ai_family: c_int + var ai_socktype: c_int + var ai_protocol: c_int + var ai_addrlen: socklen_t + var ai_addr: Pointer[sockaddr] + var ai_canonname: Pointer[c_char] + var ai_next: Pointer[addrinfo] + + fn __init__() -> Self: + return Self( + 0, 0, 0, 0, 0, Pointer[sockaddr](), Pointer[c_char](), Pointer[addrinfo]() + ) + + +# --- ( Network Related Syscalls & Structs )------------------------------------ + + +fn htonl(hostlong: c_uint) -> c_uint: + """Libc POSIX `htonl` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint32_t htonl(uint32_t hostlong). + + Args: hostlong: A 32-bit integer in host byte order. + Returns: The value provided in network byte order. + """ + return external_call["htonl", c_uint, c_uint](hostlong) + + +fn htons(hostshort: c_ushort) -> c_ushort: + """Libc POSIX `htons` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint16_t htons(uint16_t hostshort). + + Args: hostshort: A 16-bit integer in host byte order. + Returns: The value provided in network byte order. + """ + return external_call["htons", c_ushort, c_ushort](hostshort) + + +fn ntohl(netlong: c_uint) -> c_uint: + """Libc POSIX `ntohl` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint32_t ntohl(uint32_t netlong). + + Args: netlong: A 32-bit integer in network byte order. + Returns: The value provided in host byte order. + """ + return external_call["ntohl", c_uint, c_uint](netlong) + + +fn ntohs(netshort: c_ushort) -> c_ushort: + """Libc POSIX `ntohs` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint16_t ntohs(uint16_t netshort). + + Args: netshort: A 16-bit integer in network byte order. + Returns: The value provided in host byte order. + """ + return external_call["ntohs", c_ushort, c_ushort](netshort) + + +fn inet_ntop( + af: c_int, src: Pointer[c_void], dst: Pointer[c_char], size: socklen_t +) -> Pointer[c_char]: + """Libc POSIX `inet_ntop` function + Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html. + Fn signature: const char *inet_ntop(int af, const void *restrict src, char *restrict dst, socklen_t size). + + Args: + af: Address Family see AF_ aliases. + src: A pointer to a binary address. + dst: A pointer to a buffer to store the result. + size: The size of the buffer. + + Returns: + A pointer to the buffer containing the result. + """ + return external_call[ + "inet_ntop", + Pointer[c_char], # FnName, RetType + c_int, + Pointer[c_void], + Pointer[c_char], + socklen_t, # Args + ](af, src, dst, size) + + +fn inet_pton(af: c_int, src: Pointer[c_char], dst: Pointer[c_void]) -> c_int: + """Libc POSIX `inet_pton` function + Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html + Fn signature: int inet_pton(int af, const char *restrict src, void *restrict dst). + + Args: af: Address Family see AF_ aliases. + src: A pointer to a string containing the address. + dst: A pointer to a buffer to store the result. + Returns: 1 on success, 0 if the input is not a valid address, -1 on error. + """ + return external_call[ + "inet_pton", + c_int, # FnName, RetType + c_int, + Pointer[c_char], + Pointer[c_void], # Args + ](af, src, dst) + + +fn inet_addr(cp: Pointer[c_char]) -> in_addr_t: + """Libc POSIX `inet_addr` function + Reference: https://man7.org/linux/man-pages/man3/inet_addr.3p.html + Fn signature: in_addr_t inet_addr(const char *cp). + + Args: cp: A pointer to a string containing the address. + Returns: The address in network byte order. + """ + return external_call["inet_addr", in_addr_t, Pointer[c_char]](cp) + + +fn inet_ntoa(addr: in_addr) -> Pointer[c_char]: + """Libc POSIX `inet_ntoa` function + Reference: https://man7.org/linux/man-pages/man3/inet_addr.3p.html + Fn signature: char *inet_ntoa(struct in_addr in). + + Args: in: A pointer to a string containing the address. + Returns: The address in network byte order. + """ + return external_call["inet_ntoa", Pointer[c_char], in_addr](addr) + + +fn socket(domain: c_int, type: c_int, protocol: c_int) -> c_int: + """Libc POSIX `socket` function + Reference: https://man7.org/linux/man-pages/man3/socket.3p.html + Fn signature: int socket(int domain, int type, int protocol). + + Args: domain: Address Family see AF_ aliases. + type: Socket Type see SOCK_ aliases. + protocol: The protocol to use. + Returns: A File Descriptor or -1 in case of failure. + """ + return external_call[ + "socket", c_int, c_int, c_int, c_int # FnName, RetType # Args + ](domain, type, protocol) + + +fn setsockopt( + socket: c_int, + level: c_int, + option_name: c_int, + option_value: Pointer[c_void], + option_len: socklen_t, +) -> c_int: + """Libc POSIX `setsockopt` function + Reference: https://man7.org/linux/man-pages/man3/setsockopt.3p.html + Fn signature: int setsockopt(int socket, int level, int option_name, const void *option_value, socklen_t option_len). + + Args: + socket: A File Descriptor. + level: The protocol level. + option_name: The option to set. + option_value: A pointer to the value to set. + option_len: The size of the value. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "setsockopt", + c_int, # FnName, RetType + c_int, + c_int, + c_int, + Pointer[c_void], + socklen_t, # Args + ](socket, level, option_name, option_value, option_len) + + +fn getsockopt( + socket: c_int, + level: c_int, + option_name: c_int, + option_value: Pointer[c_void], + option_len: Pointer[socklen_t], +) -> c_int: + """Libc POSIX `getsockopt` function + Reference: https://man7.org/linux/man-pages/man3/getsockopt.3p.html + Fn signature: int getsockopt(int socket, int level, int option_name, void *restrict option_value, socklen_t *restrict option_len). + + Args: socket: A File Descriptor. + level: The protocol level. + option_name: The option to get. + option_value: A pointer to the value to get. + option_len: Pointer to the size of the value. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "getsockopt", + c_int, # FnName, RetType + c_int, + c_int, + c_int, + Pointer[c_void], + Pointer[socklen_t], # Args + ](socket, level, option_name, option_value, option_len) + + +fn getsockname( + socket: c_int, address: Pointer[sockaddr], address_len: Pointer[socklen_t] +) -> c_int: + """Libc POSIX `getsockname` function + Reference: https://man7.org/linux/man-pages/man3/getsockname.3p.html + Fn signature: int getsockname(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len). + + Args: socket: A File Descriptor. + address: A pointer to a buffer to store the address of the peer. + address_len: A pointer to the size of the buffer. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "getsockname", + c_int, # FnName, RetType + c_int, + Pointer[sockaddr], + Pointer[socklen_t], # Args + ](socket, address, address_len) + + +fn getpeername( + sockfd: c_int, addr: Pointer[sockaddr], address_len: Pointer[socklen_t] +) -> c_int: + """Libc POSIX `getpeername` function + Reference: https://man7.org/linux/man-pages/man2/getpeername.2.html + Fn signature: int getpeername(int socket, struct sockaddr *restrict addr, socklen_t *restrict address_len). + + Args: sockfd: A File Descriptor. + addr: A pointer to a buffer to store the address of the peer. + address_len: A pointer to the size of the buffer. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "getpeername", + c_int, # FnName, RetType + c_int, + Pointer[sockaddr], + Pointer[socklen_t], # Args + ](sockfd, addr, address_len) + + +fn bind(socket: c_int, address: Pointer[sockaddr], address_len: socklen_t) -> c_int: + """Libc POSIX `bind` function + Reference: https://man7.org/linux/man-pages/man3/bind.3p.html + Fn signature: int bind(int socket, const struct sockaddr *address, socklen_t address_len). + """ + return external_call[ + "bind", c_int, c_int, Pointer[sockaddr], socklen_t # FnName, RetType # Args + ](socket, address, address_len) + + +fn listen(socket: c_int, backlog: c_int) -> c_int: + """Libc POSIX `listen` function + Reference: https://man7.org/linux/man-pages/man3/listen.3p.html + Fn signature: int listen(int socket, int backlog). + + Args: socket: A File Descriptor. + backlog: The maximum length of the queue of pending connections. + Returns: 0 on success, -1 on error. + """ + return external_call["listen", c_int, c_int, c_int](socket, backlog) + + +fn accept( + socket: c_int, address: Pointer[sockaddr], address_len: Pointer[socklen_t] +) -> c_int: + """Libc POSIX `accept` function + Reference: https://man7.org/linux/man-pages/man3/accept.3p.html + Fn signature: int accept(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len). + + Args: socket: A File Descriptor. + address: A pointer to a buffer to store the address of the peer. + address_len: A pointer to the size of the buffer. + Returns: A File Descriptor or -1 in case of failure. + """ + return external_call[ + "accept", + c_int, # FnName, RetType + c_int, + Pointer[sockaddr], + Pointer[socklen_t], # Args + ](socket, address, address_len) + + +fn connect(socket: c_int, address: Pointer[sockaddr], address_len: socklen_t) -> c_int: + """Libc POSIX `connect` function + Reference: https://man7.org/linux/man-pages/man3/connect.3p.html + Fn signature: int connect(int socket, const struct sockaddr *address, socklen_t address_len). + + Args: socket: A File Descriptor. + address: A pointer to the address to connect to. + address_len: The size of the address. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "connect", c_int, c_int, Pointer[sockaddr], socklen_t # FnName, RetType # Args + ](socket, address, address_len) + + +fn recv( + socket: c_int, buffer: Pointer[c_void], length: c_size_t, flags: c_int +) -> c_ssize_t: + """Libc POSIX `recv` function + Reference: https://man7.org/linux/man-pages/man3/recv.3p.html + Fn signature: ssize_t recv(int socket, void *buffer, size_t length, int flags). + """ + return external_call[ + "recv", + c_ssize_t, # FnName, RetType + c_int, + Pointer[c_void], + c_size_t, + c_int, # Args + ](socket, buffer, length, flags) + + +fn send( + socket: c_int, buffer: Pointer[c_void], length: c_size_t, flags: c_int +) -> c_ssize_t: + """Libc POSIX `send` function + Reference: https://man7.org/linux/man-pages/man3/send.3p.html + Fn signature: ssize_t send(int socket, const void *buffer, size_t length, int flags). + + Args: socket: A File Descriptor. + buffer: A pointer to the buffer to send. + length: The size of the buffer. + flags: Flags to control the behaviour of the function. + Returns: The number of bytes sent or -1 in case of failure. + """ + return external_call[ + "send", + c_ssize_t, # FnName, RetType + c_int, + Pointer[c_void], + c_size_t, + c_int, # Args + ](socket, buffer, length, flags) + + +fn shutdown(socket: c_int, how: c_int) -> c_int: + """Libc POSIX `shutdown` function + Reference: https://man7.org/linux/man-pages/man3/shutdown.3p.html + Fn signature: int shutdown(int socket, int how). + + Args: socket: A File Descriptor. + how: How to shutdown the socket. + Returns: 0 on success, -1 on error. + """ + return external_call["shutdown", c_int, c_int, c_int]( # FnName, RetType # Args + socket, how + ) + + +fn getaddrinfo( + nodename: Pointer[c_char], + servname: Pointer[c_char], + hints: Pointer[addrinfo], + res: Pointer[Pointer[addrinfo]], +) -> c_int: + """Libc POSIX `getaddrinfo` function + Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html + Fn signature: int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res). + """ + return external_call[ + "getaddrinfo", + c_int, # FnName, RetType + Pointer[c_char], + Pointer[c_char], + Pointer[addrinfo], # Args + Pointer[Pointer[addrinfo]], # Args + ](nodename, servname, hints, res) + + +fn getaddrinfo_unix( + nodename: Pointer[c_char], + servname: Pointer[c_char], + hints: Pointer[addrinfo_unix], + res: Pointer[Pointer[addrinfo_unix]], +) -> c_int: + """Libc POSIX `getaddrinfo` function + Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html + Fn signature: int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res). + """ + return external_call[ + "getaddrinfo", + c_int, # FnName, RetType + Pointer[c_char], + Pointer[c_char], + Pointer[addrinfo_unix], # Args + Pointer[Pointer[addrinfo_unix]], # Args + ](nodename, servname, hints, res) + + +fn gai_strerror(ecode: c_int) -> Pointer[c_char]: + """Libc POSIX `gai_strerror` function + Reference: https://man7.org/linux/man-pages/man3/gai_strerror.3p.html + Fn signature: const char *gai_strerror(int ecode). + + Args: ecode: The error code. + Returns: A pointer to a string describing the error. + """ + return external_call[ + "gai_strerror", Pointer[c_char], c_int # FnName, RetType # Args + ](ecode) + + +fn inet_pton(address_family: Int, address: String) -> Int: + var ip_buf_size = 4 + if address_family == AF_INET6: + ip_buf_size = 16 + + var ip_buf = Pointer[c_void].alloc(ip_buf_size) + var conv_status = inet_pton( + rebind[c_int](address_family), to_char_ptr(address), ip_buf + ) + return ip_buf.bitcast[c_uint]().load().to_int() diff --git a/gojo/syscall/types.mojo b/gojo/syscall/types.mojo new file mode 100644 index 0000000..3494d57 --- /dev/null +++ b/gojo/syscall/types.mojo @@ -0,0 +1,63 @@ +@value +struct Str: + var vector: List[c_char] + + fn __init__(inout self, string: String): + self.vector = List[c_char](capacity=len(string) + 1) + for i in range(len(string)): + self.vector.append(ord(string[i])) + self.vector.append(0) + + fn __init__(inout self, size: Int): + self.vector = List[c_char]() + self.vector.resize(size + 1, 0) + + fn __len__(self) -> Int: + for i in range(len(self.vector)): + if self.vector[i] == 0: + return i + return -1 + + fn to_string(self, size: Int) -> String: + var result: String = "" + for i in range(size): + result += chr(self.vector[i].to_int()) + return result + + fn __enter__(owned self: Self) -> Self: + return self ^ + + +fn strlen(s: Pointer[c_char]) -> c_size_t: + """Libc POSIX `strlen` function + Reference: https://man7.org/linux/man-pages/man3/strlen.3p.html + Fn signature: size_t strlen(const char *s). + + Args: s: A pointer to a C string. + Returns: The length of the string. + """ + return external_call["strlen", c_size_t, Pointer[c_char]](s) + + +# Adapted from https://github.com/crisadamo/mojo-Libc . Huge thanks to Cristian! +# C types +alias c_void = UInt8 +alias c_char = UInt8 +alias c_schar = Int8 +alias c_uchar = UInt8 +alias c_short = Int16 +alias c_ushort = UInt16 +alias c_int = Int32 +alias c_uint = UInt32 +alias c_long = Int64 +alias c_ulong = UInt64 +alias c_float = Float32 +alias c_double = Float64 + +# `Int` is known to be machine's width +alias c_size_t = Int +alias c_ssize_t = Int + +alias ptrdiff_t = Int64 +alias intptr_t = Int64 +alias uintptr_t = UInt64 diff --git a/goodies/file.mojo b/goodies/file.mojo index be49b2f..f3b9393 100644 --- a/goodies/file.mojo +++ b/goodies/file.mojo @@ -13,14 +13,18 @@ struct FileWrapper(io.ReadWriteSeeker, io.ByteReader): self.handle = existing.handle ^ fn __del__(owned self): - try: - self.close() - except: + var err = self.close() + if err: # TODO: __del__ can't raise, but there should be some fallback. - print("Failed to close the file.") + print(err.value()) + + fn close(inout self) -> Optional[WrappedError]: + try: + self.handle.close() + except e: + return WrappedError(e) - fn close(inout self) raises: - self.handle.close() + return None fn read(inout self, inout dest: List[Byte]) -> Result[Int]: # Pretty hacky way to force the filehandle read into the defined trait. diff --git a/test_get_addr.mojo b/test_get_addr.mojo new file mode 100644 index 0000000..30035d4 --- /dev/null +++ b/test_get_addr.mojo @@ -0,0 +1,37 @@ +from gojo.net.socket import Socket +from gojo.net.ip import get_ip_address +from gojo.syscall.net import SO_REUSEADDR, PF_UNIX, SO_RCVTIMEO + +# fn main() raises: +# var ip = get_ip_address("localhost") +# print(ip) + +fn main() raises: + # TODO: context manager not working yet + # with Socket() as socket: + # socket.bind("0.0.0.0", 8080) + + var socket = Socket(protocol=PF_UNIX) + socket.bind("0.0.0.0", 8080) + socket.connect(get_ip_address("www.example.com"), 80) + print("File number", socket.file_no()) + var local = socket.get_sock_name() + var remote = socket.get_peer_name() + print("Local address", str(local), socket.local_address) + print("Remote address", str(remote), socket.remote_address) + socket.set_socket_option(SO_REUSEADDR, 1) + print("REUSE_ADDR value", socket.get_socket_option(SO_REUSEADDR)) + var timeout = 30 + # socket.set_timeout(timeout) + # print(socket.get_timeout()) + socket.shutdown() + var err = socket.close() + if err: + raise err.value().error + # var option_value = socket.get_sock_opt(SO_REUSEADDR) + # print(option_value) + # socket.connect(self.ip, self.port) + # socket.send(message) + # var response = socket.receive() # TODO: call receive until all data is fetched, receive should also just return bytes + # socket.shutdown() + # socket.close()