From c0d8b32646eedcb76a981fc0cf9b327406f1dff3 Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Thu, 20 Jun 2024 09:44:27 -0500 Subject: [PATCH] Add UDP, update Scanner struct to fix bugged looping. Added scan_runes split function for scanner. (#45) * wip udp * wip udp * reworking listener, connectoin, etc * udp listener and socket working...I think * delete comments * revert dial change * update dial example * fixed examples * wip refactoring structs to use span for read trait. Scanner still needs some work * wip refactoring structs to use span for read trait. Scanner still needs some work * add span and list read funcs * fixed scanner * working state, tests passing * update changelog --- CHANGELOG.md | 19 +- examples/__init__.mojo | 0 examples/scanner/__init__.mojo | 0 examples/scanner/scan_text.mojo | 33 +++ examples/tcp/__init__.mojo | 0 examples/tcp/dial_client.mojo | 43 ++++ examples/tcp/get_request.mojo | 37 +++ examples/tcp/listener_server.mojo | 27 +++ examples/tcp/socket_client.mojo | 35 +++ examples/tcp/socket_server.mojo | 40 ++++ examples/udp/__init__.mojo | 0 examples/udp/dial_client.mojo | 29 +++ examples/udp/listener_server.mojo | 23 ++ examples/udp/socket_client.mojo | 28 +++ examples/udp/socket_server.mojo | 28 +++ gojo/bufio/__init__.mojo | 2 +- gojo/bufio/bufio.mojo | 86 +++---- gojo/bufio/scan.mojo | 197 +++++++--------- gojo/builtins/attributes.mojo | 39 +++- gojo/bytes/buffer.mojo | 30 ++- gojo/bytes/reader.mojo | 66 ++++-- gojo/io/__init__.mojo | 358 +++++++++++++++++++++++++++--- gojo/io/file.mojo | 30 ++- gojo/io/io.mojo | 2 +- gojo/io/traits.mojo | 320 -------------------------- gojo/net/__init__.mojo | 13 +- gojo/net/address.mojo | 80 ++++--- gojo/net/dial.mojo | 45 ---- gojo/net/fd.mojo | 26 ++- gojo/net/ip.mojo | 25 +++ gojo/net/net.mojo | 124 ----------- gojo/net/socket.mojo | 297 +++++++++++++++++++------ gojo/net/tcp.mojo | 249 +++++++++++++-------- gojo/net/udp.mojo | 210 ++++++++++++++++++ gojo/strings/reader.mojo | 48 +++- gojo/syscall/__init__.mojo | 3 +- gojo/syscall/net.mojo | 95 +++++++- tests/test_bufio_scanner.mojo | 94 ++++---- tests/test_get_addr.mojo | 63 +++--- 39 files changed, 1842 insertions(+), 1002 deletions(-) create mode 100644 examples/__init__.mojo create mode 100644 examples/scanner/__init__.mojo create mode 100644 examples/scanner/scan_text.mojo create mode 100644 examples/tcp/__init__.mojo create mode 100644 examples/tcp/dial_client.mojo create mode 100644 examples/tcp/get_request.mojo create mode 100644 examples/tcp/listener_server.mojo create mode 100644 examples/tcp/socket_client.mojo create mode 100644 examples/tcp/socket_server.mojo create mode 100644 examples/udp/__init__.mojo create mode 100644 examples/udp/dial_client.mojo create mode 100644 examples/udp/listener_server.mojo create mode 100644 examples/udp/socket_client.mojo create mode 100644 examples/udp/socket_server.mojo delete mode 100644 gojo/io/traits.mojo delete mode 100644 gojo/net/dial.mojo delete mode 100644 gojo/net/net.mojo create mode 100644 gojo/net/udp.mojo diff --git a/CHANGELOG.md b/CHANGELOG.md index f8fec6d..f22d166 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ - # Change Log + All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) @@ -7,6 +7,23 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] - yyyy-mm-dd +## [0.0.2] - 2024-06-19 + +### Added + +- UDP support in `net` package. +- `examples` package with `tcp` and `udp` examples using `Socket` and their respective `dial` and `listen` functions. +- Added `scan_runes` split function to `bufio.scan` module. +- Added `bufio.Scanner` examples to the `examples` directory. + +### Removed + +- `Listener`, `Dialer`, and `Conn` interfaces have been removed until Trait support improves. For now, call `listen_tcp/listen_udp` and `dial_tcp/dial_udp` functions directly. + +### Changed + +- Incrementally moving towards using `Span` for the `Reader` and `Writer` traits. Added an `_read` function to `Reader` and `_read_at` to `ReaderAt` traits to enable reading into `Span`. The usual implementation is the take a `List[UInt8]` but then to use `_read` and pass a `Span` constructed from the List. + ## [0.0.1] - 2024-06-16 ### Changed diff --git a/examples/__init__.mojo b/examples/__init__.mojo new file mode 100644 index 0000000..e69de29 diff --git a/examples/scanner/__init__.mojo b/examples/scanner/__init__.mojo new file mode 100644 index 0000000..e69de29 diff --git a/examples/scanner/scan_text.mojo b/examples/scanner/scan_text.mojo new file mode 100644 index 0000000..712d1cb --- /dev/null +++ b/examples/scanner/scan_text.mojo @@ -0,0 +1,33 @@ +from gojo.bytes import buffer +from gojo.bufio import Reader, Scanner, scan_words + + +fn print_words(owned text: String): + # Create a reader from a string buffer + var buf = buffer.new_buffer(text^) + var r = Reader(buf^) + + # Create a scanner from the reader + var scanner = Scanner[split=scan_words](r^) + + while scanner.scan(): + print(scanner.current_token()) + + +fn print_lines(owned text: String): + # Create a reader from a string buffer + var buf = buffer.new_buffer(text^) + var r = Reader(buf^) + + # Create a scanner from the reader + var scanner = Scanner(r^) + + while scanner.scan(): + print(scanner.current_token()) + + +fn main(): + var text = String("Testing this string!") + var text2 = String("Testing\nthis\nstring!") + print_words(text^) + print_lines(text2^) diff --git a/examples/tcp/__init__.mojo b/examples/tcp/__init__.mojo new file mode 100644 index 0000000..e69de29 diff --git a/examples/tcp/dial_client.mojo b/examples/tcp/dial_client.mojo new file mode 100644 index 0000000..aa98fd2 --- /dev/null +++ b/examples/tcp/dial_client.mojo @@ -0,0 +1,43 @@ +from gojo.net import Socket, HostPort, dial_tcp, TCPAddr +from gojo.syscall import SocketType +import gojo.io + + +fn main() raises: + # Create UDP Connection + alias message = String("dial") + alias host = "127.0.0.1" + alias port = 8081 + + for _ in range(10): + var connection = dial_tcp("tcp", host, port) + var bytes_written: Int + var err: Error + bytes_written, err = connection.write( + String("GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: close\r\n\r\n").as_bytes() + ) + if err: + raise err + + if bytes_written == 0: + print("No bytes sent to peer.") + return + + # Read the response from the connection + var response = List[UInt8](capacity=4096) + var bytes_read: Int = 0 + bytes_read, err = connection.read(response) + if err and str(err) != io.EOF: + raise err + + if bytes_read == 0: + print("No bytes received from peer.") + return + + response.append(0) + print("Message received:", String(response^)) + + # Cleanup the connection + err = connection.close() + if err: + raise err diff --git a/examples/tcp/get_request.mojo b/examples/tcp/get_request.mojo new file mode 100644 index 0000000..886cf8f --- /dev/null +++ b/examples/tcp/get_request.mojo @@ -0,0 +1,37 @@ +from gojo.net import TCPAddr, get_ip_address, dial_tcp +from gojo.syscall import ProtocolFamily + + +fn main() raises: + # Connect to example.com on port 80 and send a GET request + var connection = dial_tcp("tcp", TCPAddr(get_ip_address("www.example.com"), 80)) + var bytes_written: Int = 0 + var err = Error() + bytes_written, err = connection.write( + String("GET / HTTP/1.1\r\nHost: www.example.com\r\nConnection: close\r\n\r\n").as_bytes() + ) + if err: + raise err + + if bytes_written == 0: + print("No bytes sent to peer.") + return + + # Read the response from the connection + var response = List[UInt8](capacity=4096) + var bytes_read: Int = 0 + bytes_read, err = connection.read(response) + if err: + raise err + + if bytes_read == 0: + print("No bytes received from peer.") + return + + response.append(0) + print(String(response^)) + + # Cleanup the connection + err = connection.close() + if err: + raise err diff --git a/examples/tcp/listener_server.mojo b/examples/tcp/listener_server.mojo new file mode 100644 index 0000000..606e8e4 --- /dev/null +++ b/examples/tcp/listener_server.mojo @@ -0,0 +1,27 @@ +from gojo.net import TCPAddr, get_ip_address, listen_tcp, HostPort +import gojo.io + + +fn main() raises: + var listener = listen_tcp("udp", TCPAddr("127.0.0.1", 12000)) + + while True: + var connection = listener.accept() + + # Read the contents of the message from the client. + var bytes = List[UInt8](capacity=4096) + var bytes_read: Int + var err: Error + bytes_read, err = connection.read(bytes) + if str(err) != io.EOF: + raise err + + bytes.append(0) + var message = String(bytes^) + print("Message Received:", message) + message = message.upper() + + # Send a response back to the client. + var bytes_sent: Int + bytes_sent, err = connection.write(message.as_bytes()) + print("Message sent:", message, bytes_sent) diff --git a/examples/tcp/socket_client.mojo b/examples/tcp/socket_client.mojo new file mode 100644 index 0000000..8b51d46 --- /dev/null +++ b/examples/tcp/socket_client.mojo @@ -0,0 +1,35 @@ +from gojo.net import Socket, HostPort +from gojo.syscall import SocketType +import gojo.io + + +fn main() raises: + # Create TCP Socket + var socket = Socket() + alias message = String("test") + alias host = "127.0.0.1" + alias port = 8082 + + # Bind client to port 8082 + socket.bind(host, port) + + # Send 10 test messages + var err = socket.connect(host, 8081) + if err: + raise err + var bytes_sent: Int + bytes_sent, err = socket.write(message.as_bytes()) + print("Message sent:", message) + + var bytes = List[UInt8](capacity=16) + var bytes_read: Int + bytes_read, err = socket.read(bytes) + if str(err) != io.EOF: + raise err + + bytes.append(0) + var response = String(bytes^) + print("Message received:", response) + + _ = socket.shutdown() + _ = socket.close() diff --git a/examples/tcp/socket_server.mojo b/examples/tcp/socket_server.mojo new file mode 100644 index 0000000..7d5e963 --- /dev/null +++ b/examples/tcp/socket_server.mojo @@ -0,0 +1,40 @@ +from gojo.net import Socket, HostPort +from gojo.syscall import SocketOptions +import gojo.io + + +fn main() raises: + var socket = Socket() + socket.set_socket_option(SocketOptions.SO_REUSEADDR, 1) + alias host = "127.0.0.1" + alias port = 8081 + + # Bind server to port 8081 + socket.bind(host, port) + socket.listen() + print("Listening on", socket.local_address_as_tcp()) + while True: + # Accept connections from clients and serve them. + var connection = socket.accept() + print("Serving", connection.remote_address_as_tcp()) + + # Read the contents of the message from the client. + var bytes = List[UInt8](capacity=4096) + var bytes_read: Int + var err: Error + bytes_read, err = connection.read(bytes) + if str(err) != io.EOF: + raise err + + bytes.append(0) + var message = String(bytes^) + print("Message Received:", message) + message = message.upper() + + # Send a response back to the client. + var bytes_sent: Int + bytes_sent, err = connection.write(message.as_bytes()) + print("Message sent:", message, bytes_sent) + err = connection.close() + if err: + raise err diff --git a/examples/udp/__init__.mojo b/examples/udp/__init__.mojo new file mode 100644 index 0000000..e69de29 diff --git a/examples/udp/dial_client.mojo b/examples/udp/dial_client.mojo new file mode 100644 index 0000000..30f73e7 --- /dev/null +++ b/examples/udp/dial_client.mojo @@ -0,0 +1,29 @@ +from gojo.net import Socket, HostPort, dial_udp, UDPAddr +from gojo.syscall import SocketType +import gojo.io + + +fn main() raises: + # Create UDP Connection + alias message = String("dial") + alias host = "127.0.0.1" + alias port = 12000 + var udp = dial_udp("udp", host, port) + + # Send 10 test messages + for _ in range(10): + var bytes_sent: Int + var err: Error + bytes_sent, err = udp.write_to(message.as_bytes(), host, port) + print("Message sent:", message, bytes_sent) + + var bytes = List[UInt8](capacity=16) + var bytes_received: Int + var remote: HostPort + bytes_received, remote, err = udp.read_from(bytes) + if str(err) != io.EOF: + raise err + + bytes.append(0) + var response = String(bytes^) + print("Message received:", response) diff --git a/examples/udp/listener_server.mojo b/examples/udp/listener_server.mojo new file mode 100644 index 0000000..b62bcfd --- /dev/null +++ b/examples/udp/listener_server.mojo @@ -0,0 +1,23 @@ +from gojo.net import UDPAddr, get_ip_address, listen_udp, HostPort +import gojo.io + + +fn main() raises: + var listener = listen_udp("udp", UDPAddr("127.0.0.1", 12000)) + + while True: + var dest = List[UInt8](capacity=16) + var bytes_read: Int + var remote: HostPort + var err: Error + bytes_read, remote, err = listener.read_from(dest) + if err: + raise err + + dest.append(0) + var message = String(dest^) + print("Message received:", message) + message = message.upper() + var bytes_sent: Int + bytes_sent, err = listener.write_to(message.as_bytes(), UDPAddr(remote.host, remote.port)) + print("Message sent:", message) diff --git a/examples/udp/socket_client.mojo b/examples/udp/socket_client.mojo new file mode 100644 index 0000000..4cc89e6 --- /dev/null +++ b/examples/udp/socket_client.mojo @@ -0,0 +1,28 @@ +from gojo.net import Socket, HostPort +from gojo.syscall import SocketType +import gojo.io + + +fn main() raises: + # Create UDP Socket + var socket = Socket(socket_type=SocketType.SOCK_DGRAM) + alias message = String("test") + alias host = "127.0.0.1" + alias port = 12000 + + # Send 10 test messages + for _ in range(10): + var bytes_sent: Int + var err: Error + bytes_sent, err = socket.send_to(message.as_bytes(), host, port) + print("Message sent:", message) + + var bytes: List[UInt8] + var remote: HostPort + bytes, remote, err = socket.receive_from(1024) + if str(err) != io.EOF: + raise err + + bytes.append(0) + var response = String(bytes^) + print("Message received:", response) diff --git a/examples/udp/socket_server.mojo b/examples/udp/socket_server.mojo new file mode 100644 index 0000000..0ea45ca --- /dev/null +++ b/examples/udp/socket_server.mojo @@ -0,0 +1,28 @@ +from gojo.net import Socket, HostPort +from gojo.syscall import SocketType +import gojo.io + + +fn main() raises: + var socket = Socket(socket_type=SocketType.SOCK_DGRAM) + alias host = "127.0.0.1" + alias port = 12000 + + socket.bind(host, port) + print("Listening on", socket.local_address_as_udp()) + while True: + var bytes: List[UInt8] + var remote: HostPort + var err: Error + bytes, remote, err = socket.receive_from(1024) + if str(err) != io.EOF: + raise err + + bytes.append(0) + var message = String(bytes^) + print("Message Received:", message) + message = message.upper() + + var bytes_sent: Int + bytes_sent, err = socket.send_to(message.as_bytes(), remote.host, remote.port) + print("Message sent:", message) diff --git a/gojo/bufio/__init__.mojo b/gojo/bufio/__init__.mojo index c501992..cd20874 100644 --- a/gojo/bufio/__init__.mojo +++ b/gojo/bufio/__init__.mojo @@ -1,2 +1,2 @@ from .bufio import Reader, Writer, ReadWriter -from .scan import Scanner, scan_words, scan_bytes, scan_lines +from .scan import Scanner, scan_words, scan_bytes, scan_lines, scan_runes diff --git a/gojo/bufio/bufio.mojo b/gojo/bufio/bufio.mojo index 3806850..072d87e 100644 --- a/gojo/bufio/bufio.mojo +++ b/gojo/bufio/bufio.mojo @@ -108,11 +108,8 @@ struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): if self.read_pos > 0: var data_to_slide = self.as_bytes_slice()[self.read_pos : self.write_pos] # TODO: Temp copying of elements until I figure out a better pattern or slice refs are added - for i in range(len(self.buf)): - if i > len(self.buf): - self.buf[i] = data_to_slide[i] - else: - self.buf.append(data_to_slide[i]) + for i in range(len(data_to_slide)): + self.buf[i] = data_to_slide[i] # self.buf.reserve(current_capacity) self.write_pos -= self.read_pos @@ -126,20 +123,15 @@ struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): # Read new data: try a limited number of times. var i: Int = MAX_CONSECUTIVE_EMPTY_READS while i > 0: - # TODO: Using temp until slicing can return a Reference - var temp = List[UInt8](capacity=io.BUFFER_SIZE) + # TODO: Using temp until slicing can return a Reference, does reading directly into a Span of self.buf work? + # Maybe we need to read into the end of the buffer. + var span = self.as_bytes_slice() var bytes_read: Int var err: Error - bytes_read, err = self.reader.read(temp) + bytes_read, err = self.reader._read(span, len(self.buf)) if bytes_read < 0: panic(ERR_NEGATIVE_READ) - # TODO: Temp copying of elements until I figure out a better pattern or slice refs are added - for i in range(len(temp)): - if i + self.write_pos > len(temp): - self.buf[i + self.write_pos] = temp[i] - else: - self.buf.append(temp[i]) self.write_pos += bytes_read if err: @@ -230,7 +222,7 @@ struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): if remain == 0: return number_of_bytes, Error() - fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): """Reads data into dest. It returns the number of bytes read into dest. The bytes are taken from at most one Read on the underlying [Reader], @@ -238,22 +230,20 @@ struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): To read exactly len(src) bytes, use io.ReadFull(b, src). If the underlying [Reader] can return a non-zero count with io.EOF, then this Read method can do so as well; see the [io.Reader] docs.""" - var space_available = dest.capacity - len(dest) - if space_available == 0: + # TODO: How do we check the capacity of a Span? Or UnsafePointer? + if capacity == 0: if self.buffered() > 0: return 0, Error() return 0, self.read_error() var bytes_read: Int = 0 if self.read_pos == self.write_pos: - if space_available >= len(self.buf): + if capacity >= len(self.buf): # Large read, empty buffer. # Read directly into dest to avoid copy. var bytes_read: Int - var err: Error - bytes_read, err = self.reader.read(dest) + bytes_read, self.err = self.reader._read(dest, capacity) - self.err = err if bytes_read < 0: panic(ERR_NEGATIVE_READ) @@ -267,10 +257,9 @@ struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): # Do not use self.fill, which will loop. self.read_pos = 0 self.write_pos = 0 - var buf = List[UInt8](self.as_bytes_slice()) # TODO: I'm hoping this reads into self.data directly lol + var buf = self.as_bytes_slice() # TODO: I'm hoping this reads into self.data directly lol var bytes_read: Int - var err: Error - bytes_read, err = self.reader.read(buf) + bytes_read, self.err = self.reader._read(buf, len(buf)) if bytes_read < 0: panic(ERR_NEGATIVE_READ) @@ -281,14 +270,38 @@ struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): self.write_pos += bytes_read # copy as much as we can - # Note: if the slice panics here, it is probably because - # the underlying reader returned a bad count. See issue 49795. - bytes_read = copy(dest, self.as_bytes_slice()[self.read_pos : self.write_pos]) + var source = self.as_bytes_slice()[self.read_pos : self.write_pos] + bytes_read = 0 + var start = len(dest) + + for i in range(len(source)): + dest[i + start] = source[i] + bytes_read += 1 + dest._len += bytes_read self.read_pos += bytes_read - self.last_byte = int(self.as_bytes_slice()[self.read_pos - 1]) + self.last_byte = int(self.buf[self.read_pos - 1]) self.last_rune_size = -1 return bytes_read, Error() + @always_inline + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + """Reads data into dest. + It returns the number of bytes read into dest. + The bytes are taken from at most one Read on the underlying [Reader], + hence n may be less than len(src). + To read exactly len(src) bytes, use io.ReadFull(b, src). + If the underlying [Reader] can return a non-zero count with io.EOF, + then this Read method can do so as well; see the [io.Reader] docs.""" + + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + @always_inline fn read_byte(inout self) -> (UInt8, Error): """Reads and returns a single byte from the internal buffer. If no byte is available, returns an error.""" @@ -929,20 +942,9 @@ struct Writer[W: io.Writer, size: Int = io.BUFFER_SIZE]( var nr = 0 while nr < MAX_CONSECUTIVE_EMPTY_READS: - # TODO: should really be using a slice that returns refs and not a copy. - # Read into remaining unused space in the buffer. We need to reserve capacity for the slice otherwise read will never hit EOF. - var sl = List[UInt8](self.as_bytes_slice()[self.bytes_written : len(self.buf)]) - sl.reserve(self.buf.capacity) - bytes_read, err = reader.read(sl) - if bytes_read > 0: - # TODO: Temp copying of elements until I figure out a better pattern or slice refs are added - var bytes_read = 0 - for i in range(len(sl)): - if i + self.bytes_written > len(sl): - self.buf[i + self.bytes_written] = sl[i] - else: - self.buf.append(sl[i]) - bytes_read += 1 + # Read into remaining unused space in the buffer. + var buf = self.as_bytes_slice()[self.bytes_written : len(self.buf)] + bytes_read, err = reader._read(buf, len(buf)) if bytes_read != 0 or err: break diff --git a/gojo/bufio/scan.mojo b/gojo/bufio/scan.mojo index dc4418b..8606069 100644 --- a/gojo/bufio/scan.mojo +++ b/gojo/bufio/scan.mojo @@ -1,5 +1,5 @@ import ..io -from ..builtins import copy, panic, Error +from ..builtins import copy, panic from ..builtins.bytes import index_byte from .bufio import MAX_CONSECUTIVE_EMPTY_READS @@ -7,7 +7,7 @@ from .bufio import MAX_CONSECUTIVE_EMPTY_READS alias MAX_INT: Int = 2147483647 -struct Scanner[R: io.Reader](): +struct Scanner[R: io.Reader, split: SplitFunction = scan_lines](): # The function to split the tokens. """Scanner provides a convenient Interface for reading data such as a file of newline-delimited lines of text. Successive calls to the [Scanner.Scan] method will step through the 'tokens' of a file, skipping @@ -25,12 +25,9 @@ struct Scanner[R: io.Reader](): on a reader, should use [bufio.Reader] instead.""" var reader: R # The reader provided by the client. - var split: SplitFunction # The function to split the tokens. var max_token_size: Int # Maximum size of a token; modified by tests. var token: List[UInt8] # Last token returned by split. - var data: UnsafePointer[UInt8] # contents are the bytes buf[off : len(buf)] - var size: Int - var capacity: Int + var buf: List[UInt8] # buffer used as argument to split. var start: Int # First non-processed byte in buf. var end: Int # End of data in buf. var empties: Int # Count of successive empty tokens. @@ -41,12 +38,9 @@ struct Scanner[R: io.Reader](): fn __init__( inout self, owned reader: R, - split: SplitFunction = scan_lines, max_token_size: Int = MAX_SCAN_TOKEN_SIZE, token: List[UInt8] = List[UInt8](capacity=io.BUFFER_SIZE), - data: UnsafePointer[UInt8] = UnsafePointer[UInt8](), - size: Int = 0, - capacity: Int = io.BUFFER_SIZE, + buf: List[UInt8] = List[UInt8](capacity=io.BUFFER_SIZE), start: Int = 0, end: Int = 0, empties: Int = 0, @@ -54,12 +48,9 @@ struct Scanner[R: io.Reader](): done: Bool = False, ): self.reader = reader^ - self.split = split self.max_token_size = max_token_size self.token = token - self.data = data - self.size = size - self.capacity = capacity + self.buf = buf self.start = start self.end = end self.empties = empties @@ -67,33 +58,11 @@ struct Scanner[R: io.Reader](): self.done = done self.err = Error() - @always_inline - fn _resize(inout self, capacity: Int) -> None: - """ - Resizes the string builder buffer. - - Args: - capacity: The new capacity of the string builder buffer. - """ - var new_data = UnsafePointer[UInt8]().alloc(capacity) - memcpy(new_data, self.data, self.size) - self.data.free() - self.data = new_data - self.capacity = capacity - - return None - - @always_inline - fn __del__(owned self): - if self.data: - self.data.free() - @always_inline fn as_bytes_slice(self: Reference[Self]) -> Span[UInt8, self.is_mutable, self.lifetime]: """Returns the internal buffer data as a Span[UInt8].""" - return Span[UInt8, self.is_mutable, self.lifetime](unsafe_ptr=self[].data, len=self[].size) + return Span[UInt8, self.is_mutable, self.lifetime](self[].buf) - @always_inline fn current_token_as_bytes(self) -> List[UInt8]: """Returns the most recent token generated by a call to [Scanner.Scan]. The underlying array may point to data that will be overwritten @@ -101,15 +70,14 @@ struct Scanner[R: io.Reader](): """ return self.token - @always_inline fn current_token(self) -> String: """Returns the most recent token generated by a call to [Scanner.Scan] as a newly allocated string holding its bytes.""" var copy = self.token copy.append(0) - return String(copy) + return String(copy^) - fn scan(inout self) raises -> Bool: + fn scan(inout self) -> Bool: """Advances the [Scanner] to the next token, which will then be available through the [Scanner.current_token_as_bytes] or [Scanner.current_token] method. It returns False when there are no more tokens, either by reaching the end of the input or an error. @@ -135,7 +103,7 @@ struct Scanner[R: io.Reader](): var at_eof = False if self.err: at_eof = True - advance, token, err = self.split(self.as_bytes_slice()[self.start : self.end], at_eof) + advance, token, err = split(self.as_bytes_slice()[self.start : self.end], at_eof) if err: if str(err) == str(ERR_FINAL_TOKEN): self.token = token @@ -174,24 +142,27 @@ struct Scanner[R: io.Reader](): # Must read more data. # First, shift data to beginning of buffer if there's lots of empty space # or space is needed. - if self.start > 0 and (self.end == self.size or self.start > int(self.size / 2)): - self.data = self.as_bytes_slice()[self.start : self.end].unsafe_ptr() + if self.start > 0 and (self.end == len(self.buf) or self.start > int(len(self.buf) / 2)): + _ = copy(self.buf, self.as_bytes_slice()[self.start : self.end]) self.end -= self.start self.start = 0 # Is the buffer full? If so, resize. - if self.end == self.size: + if self.end == len(self.buf): # Guarantee no overflow in the multiplication below. - if self.size >= self.max_token_size or self.size > int(MAX_INT / 2): - self.set_err(ERR_TOO_LONG) + if len(self.buf) >= self.max_token_size or len(self.buf) > int(MAX_INT / 2): + self.set_err((ERR_TOO_LONG)) return False - var new_size = self.size * 2 + var new_size = len(self.buf) * 2 if new_size == 0: new_size = START_BUF_SIZE + # Make a new List[UInt8] buffer and copy the elements in new_size = min(new_size, self.max_token_size) - self._resize(new_size) + var new_buf = List[UInt8](capacity=new_size) + _ = copy(new_buf, self.buf[self.start : self.end]) + self.buf = new_buf self.end -= self.start self.start = 0 @@ -200,14 +171,13 @@ struct Scanner[R: io.Reader](): # be extra careful: Scanner is for safe, simple jobs. var loop = 0 while True: - var bytes_read: Int - var sl = List[UInt8](self.as_bytes_slice()[self.end : self.size]) - var err: Error + var buf = self.as_bytes_slice()[self.end :] # Catch any reader errors and set the internal error field to that err instead of bubbling it up. - bytes_read, err = self.reader.read(sl) - self.data = sl.steal_data() - if bytes_read < 0 or self.size - self.end < bytes_read: + var bytes_read: Int + var err: Error + bytes_read, err = self.reader._read(buf, self.buf.capacity - self.end) + if bytes_read < 0 or len(buf) - self.end < bytes_read: self.set_err(ERR_BAD_READ_COUNT) break @@ -225,7 +195,6 @@ struct Scanner[R: io.Reader](): self.set_err(Error(io.ERR_NO_PROGRESS)) break - @always_inline fn set_err(inout self, err: Error): """Set the internal error field to the provided error. @@ -239,7 +208,6 @@ struct Scanner[R: io.Reader](): else: self.err = err - @always_inline fn advance(inout self, n: Int) -> Bool: """Consumes n bytes of the buffer. It reports whether the advance was legal. @@ -250,11 +218,11 @@ struct Scanner[R: io.Reader](): True if the advance was legal, False otherwise. """ if n < 0: - self.set_err(ERR_NEGATIVE_ADVANCE) + self.set_err(Error(str(ERR_NEGATIVE_ADVANCE))) return False if n > self.end - self.start: - self.set_err(ERR_ADVANCE_TOO_FAR) + self.set_err(Error(str(ERR_ADVANCE_TOO_FAR))) return False self.start += n @@ -285,17 +253,18 @@ struct Scanner[R: io.Reader](): # The function is never called with an empty data slice unless at_eof # is True. If at_eof is True, however, data may be non-empty and, # as always, holds unprocessed text. -alias SplitFunction = fn (data: List[UInt8], at_eof: Bool) -> ( +alias SplitFunction = fn (data: Span[UInt8], at_eof: Bool) -> ( Int, List[UInt8], Error, ) -# # Errors returned by Scanner. +# Errors returned by Scanner. alias ERR_TOO_LONG = Error("bufio.Scanner: token too long") alias ERR_NEGATIVE_ADVANCE = Error("bufio.Scanner: SplitFunction returns negative advance count") alias ERR_ADVANCE_TOO_FAR = Error("bufio.Scanner: SplitFunction returns advance count beyond input") alias ERR_BAD_READ_COUNT = Error("bufio.Scanner: Read returned impossible count") + # ERR_FINAL_TOKEN is a special sentinel error value. It is Intended to be # returned by a split function to indicate that the scanning should stop # with no error. If the token being delivered with this error is not nil, @@ -324,55 +293,50 @@ fn new_scanner[R: io.Reader](owned reader: R) -> Scanner[R]: ###### split functions ###### -fn scan_bytes(data: List[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): - """Split function for a [Scanner] that returns each byte as a token.""" - if at_eof and data.capacity == 0: - return 0, List[UInt8](), Error() - - return 1, data[0:1], Error() +fn scan_bytes(data: Span[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): + """Returns each byte as a token. + Args: + data: The data to split. + at_eof: Whether the data is at the end of the file. -# var errorRune = List[UInt8](string(utf8.RuneError)) - -# # ScanRunes is a split function for a [Scanner] that returns each -# # UTF-8-encoded rune as a token. The sequence of runes returned is -# # equivalent to that from a range loop over the input as a string, which -# # means that erroneous UTF-8 encodings translate to U+FFFD = "\xef\xbf\xbd". -# # Because of the Scan Interface, this makes it impossible for the client to -# # distinguish correctly encoded replacement runes from encoding errors. -# fn ScanRunes(data List[UInt8], at_eof Bool) (advance Int, token List[UInt8], err error): -# if at_eof and data.capacity == 0: -# return 0, nil, nil + Returns: + The number of bytes to advance the input, token in bytes, and an error if one occurred. + """ + if at_eof and len(data) == 0: + return 0, List[UInt8](), Error() + return 1, List[UInt8](data[0:1]), Error() -# # Fast path 1: ASCII. -# if data[0] < utf8.RuneSelf: -# return 1, data[0:1], nil +fn scan_runes(data: Span[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): + """Returns each UTF-8-encoded rune as a token. -# # Fast path 2: Correct UTF-8 decode without error. -# _, width := utf8.DecodeRune(data) -# if width > 1: -# # It's a valid encoding. Width cannot be one for a correctly encoded -# # non-ASCII rune. -# return width, data[0:width], nil + Args: + data: The data to split. + at_eof: Whether the data is at the end of the file. + Returns: + The number of bytes to advance the input, token in bytes, and an error if one occurred. + """ + if at_eof and len(data) == 0: + return 0, List[UInt8](), Error() -# # We know it's an error: we have width==1 and implicitly r==utf8.RuneError. -# # Is the error because there wasn't a full rune to be decoded? -# # FullRune distinguishes correctly between erroneous and incomplete encodings. -# if !at_eof and !utf8.FullRune(data): -# # Incomplete; get more bytes. -# return 0, nil, nil + # Number of bytes of the current character + var char_length = int( + (DTypePointer[DType.uint8](data.unsafe_ptr()).load() >> 7 == 0).cast[DType.uint8]() * 1 + + countl_zero(~DTypePointer[DType.uint8](data.unsafe_ptr()).load()) + ) + # Copy N bytes into new pointer and construct List. + var sp = UnsafePointer[UInt8].alloc(char_length) + memcpy(sp, data.unsafe_ptr(), char_length) + var result = List[UInt8](unsafe_pointer=sp, size=char_length, capacity=char_length) -# # We have a real UTF-8 encoding error. Return a properly encoded error rune -# # but advance only one byte. This matches the behavior of a range loop over -# # an incorrectly encoded string. -# return 1, errorRune, nil + return char_length, result, Error() -fn drop_carriage_return(data: List[UInt8]) -> List[UInt8]: +fn drop_carriage_return(data: Span[UInt8]) -> List[UInt8]: """Drops a terminal \r from the data. Args: @@ -382,15 +346,14 @@ fn drop_carriage_return(data: List[UInt8]) -> List[UInt8]: The stripped data. """ # In the case of a \r ending without a \n, indexing on -1 doesn't work as it finds a null terminator instead of \r. - if data.capacity > 0 and data[data.capacity - 1] == ord("\r"): - return data[0 : data.capacity - 1] + if len(data) > 0 and data[-1] == ord("\r"): + return data[:-1] return data -fn scan_lines(data: List[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): - """Split function for a [Scanner] that returns each line of - text, stripped of any trailing end-of-line marker. The returned line may +fn scan_lines(data: Span[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): + """Returns each line of text, stripped of any trailing end-of-line marker. The returned line may be empty. The end-of-line marker is one optional carriage return followed by one mandatory newline. The last non-empty line of input will be returned even if it has no newline. @@ -398,10 +361,11 @@ fn scan_lines(data: List[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): Args: data: The data to split. at_eof: Whether the data is at the end of the file. + Returns: The number of bytes to advance the input. """ - if at_eof and data.capacity == 0: + if at_eof and len(data) == 0: return 0, List[UInt8](), Error() var i = index_byte(data, ord("\n")) @@ -411,7 +375,7 @@ fn scan_lines(data: List[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): # If we're at EOF, we have a final, non-terminated line. Return it. # if at_eof: - return data.capacity, drop_carriage_return(data), Error() + return len(data), drop_carriage_return(data), Error() # Request more data. # return 0 @@ -425,16 +389,21 @@ fn is_space(r: UInt8) -> Bool: # TODO: Handle runes and utf8 decoding. For now, just assuming single byte length. -fn scan_words(data: List[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): - """Split function for a [Scanner] that returns each - space-separated word of text, with surrounding spaces deleted. It will - never return an empty string. The definition of space is set by - unicode.IsSpace. +fn scan_words(data: Span[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): + """Returns each space-separated word of text, with surrounding spaces deleted. It will + never return an empty string. + + Args: + data: The data to split. + at_eof: Whether the data is at the end of the file. + + Returns: + The number of bytes to advance the input, token in bytes, and an error if one occurred. """ # Skip leading spaces. var start = 0 var width = 0 - while start < data.capacity: + while start < len(data): width = len(data[0]) if not is_space(data[0]): break @@ -445,16 +414,16 @@ fn scan_words(data: List[UInt8], at_eof: Bool) -> (Int, List[UInt8], Error): var i = 0 width = 0 start = 0 - while i < data.capacity: + while i < len(data): width = len(data[i]) if is_space(data[i]): - return i + width, data[start:i], Error() + return i + width, List[UInt8](data[start:i]), Error() i += width # If we're at EOF, we have a final, non-empty, non-terminated word. Return it. - if at_eof and data.capacity > start: - return data.capacity, data[start:], Error() + if at_eof and len(data) > start: + return len(data), List[UInt8](data[start:]), Error() # Request more data. return start, List[UInt8](), Error() diff --git a/gojo/builtins/attributes.mojo b/gojo/builtins/attributes.mojo index d830186..7ce0b15 100644 --- a/gojo/builtins/attributes.mojo +++ b/gojo/builtins/attributes.mojo @@ -25,12 +25,34 @@ fn copy[T: CollectionElement](inout target: List[T], source: List[T], start: Int return count -fn copy[T: CollectionElement](inout target: List[T], source: Span[T], start: Int = 0) -> Int: +fn copy[T: CollectionElement](inout target_span: Span[T, True], source_span: Span[T], start: Int = 0) -> Int: """Copies the contents of source into target at the same index. Returns the number of bytes copied. Added a start parameter to specify the index to start copying into. Args: - target: The buffer to copy into. + target_span: The buffer to copy into. + source_span: The buffer to copy from. + start: The index to start copying into. + + Returns: + The number of bytes copied. + """ + var count = 0 + + for i in range(len(source_span)): + target_span[i + start] = source_span[i] + count += 1 + + target_span._len += count + return count + + +fn copy[T: CollectionElement](inout target_span: Span[T, True], source: InlineList[T], start: Int = 0) -> Int: + """Copies the contents of source into target at the same index. Returns the number of bytes copied. + Added a start parameter to specify the index to start copying into. + + Args: + target_span: The buffer to copy into. source: The buffer to copy from. start: The index to start copying into. @@ -40,15 +62,20 @@ fn copy[T: CollectionElement](inout target: List[T], source: Span[T], start: Int var count = 0 for i in range(len(source)): - if i + start > len(target): - target[i + start] = source[i] - else: - target.append(source[i]) + target_span[i + start] = source[i] count += 1 + target_span._len += count return count +fn test(inout dest: List[UInt8]): + var source = List[UInt8](1, 2, 3) + var target = Span[UInt8](dest) + + _ = copy(target, Span(source), start=0) + + fn copy[T: CollectionElement](inout list: InlineList[T], source: Span[T], start: Int = 0) -> Int: """Copies the contents of source into target at the same index. Returns the number of bytes copied. Added a start parameter to specify the index to start copying into. diff --git a/gojo/bytes/buffer.mojo b/gojo/bytes/buffer.mojo index 69cd706..bbd2a93 100644 --- a/gojo/bytes/buffer.mojo +++ b/gojo/bytes/buffer.mojo @@ -240,7 +240,7 @@ struct Buffer( self.last_read = OP_INVALID @always_inline - fn read(inout self, inout dest: List[Byte]) -> (Int, Error): + fn _read(inout self, inout dest: Span[Byte, True], capacity: Int) -> (Int, Error): """Reads the next len(dest) bytes from the buffer or until the buffer is drained. The return value n is the number of bytes read. If the buffer has no data to return, err is io.EOF (unless len(dest) is zero); @@ -256,12 +256,14 @@ struct Buffer( if self.empty(): # Buffer is empty, reset to recover space. self.reset() - if dest.capacity == 0: + # TODO: How to check if the span's pointer has 0 capacity? We want to return early if the span can't receive any data. + if capacity == 0: return 0, Error() return 0, Error(io.EOF) # Copy the data of the internal buffer from offset to len(buf) into the destination buffer at the given index. - var bytes_read = copy(dest, self.as_bytes_slice()[self.offset : self._size], dest.size) + var bytes_read = copy(dest, self.as_bytes_slice()[self.offset :]) + dest._len += bytes_read self.offset += bytes_read if bytes_read > 0: @@ -269,6 +271,28 @@ struct Buffer( return bytes_read, Error() + @always_inline + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): + """Reads the next len(dest) bytes from the buffer or until the buffer + is drained. The return value n is the number of bytes read. If the + buffer has no data to return, err is io.EOF (unless len(dest) is zero); + otherwise it is nil. + + Args: + dest: The buffer to read into. + + Returns: + The number of bytes read from the buffer. + """ + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + @always_inline fn read_byte(inout self) -> (Byte, Error): """Reads and returns the next byte from the buffer. diff --git a/gojo/bytes/reader.mojo b/gojo/bytes/reader.mojo index 640082a..54a1b7f 100644 --- a/gojo/bytes/reader.mojo +++ b/gojo/bytes/reader.mojo @@ -1,4 +1,4 @@ -from ..builtins import copy, Byte, panic +from ..builtins import copy, panic import ..io @@ -27,7 +27,7 @@ struct Reader( var prev_rune: Int # index of previous rune; or < 0 @always_inline - fn __init__(inout self, owned buffer: List[Byte]): + fn __init__(inout self, owned buffer: List[UInt8]): """Initializes a new [Reader.Reader] struct.""" self.capacity = buffer.capacity self.size = buffer.size @@ -67,12 +67,13 @@ struct Reader( return Span[UInt8, self.is_mutable, self.lifetime](unsafe_ptr=self[].data, len=self[].size) @always_inline - fn read(inout self, inout dest: List[Byte]) -> (Int, Error): - """Reads from the internal buffer into the dest List[Byte] struct. + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): + """Reads from the internal buffer into the dest List[UInt8] struct. Implements the [io.Reader] Interface. Args: - dest: The destination List[Byte] struct to read into. + dest: The destination Span[UInt8] struct to read into. + capacity: The capacity of the destination buffer. Returns: Int: The number of bytes read into dest.""" @@ -80,22 +81,40 @@ struct Reader( if self.index >= self.size: return 0, Error(io.EOF) - var unread_bytes = self.as_bytes_slice()[self.index : self.size] - # Copy the data of the internal buffer from offset to len(buf) into the destination buffer at the given index. self.prev_rune = -1 - var bytes_written = copy(dest, unread_bytes) + var bytes_written = copy(dest, self.as_bytes_slice()[self.index : self.size], len(dest)) + dest._len += bytes_written self.index += bytes_written return bytes_written, Error() @always_inline - fn read_at(self, inout dest: List[Byte], off: Int) -> (Int, Error): + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + """Reads from the internal buffer into the dest List[UInt8] struct. + Implements the [io.Reader] Interface. + + Args: + dest: The destination List[UInt8] struct to read into. + + Returns: + Int: The number of bytes read into dest.""" + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + + @always_inline + fn _read_at(self, inout dest: Span[UInt8, True], off: Int, capacity: Int) -> (Int, Error): """Reads len(dest) bytes into dest beginning at byte offset off. Implements the [io.ReaderAt] Interface. Args: - dest: The destination List[Byte] struct to read into. + dest: The destination List[UInt8] struct to read into. off: The offset to start reading from. Returns: @@ -116,7 +135,28 @@ struct Reader( return bytes_written, Error() @always_inline - fn read_byte(inout self) -> (Byte, Error): + fn read_at(self, inout dest: List[UInt8], off: Int) -> (Int, Error): + """Reads len(dest) bytes into dest beginning at byte offset off. + Implements the [io.ReaderAt] Interface. + + Args: + dest: The destination List[UInt8] struct to read into. + off: The offset to start reading from. + + Returns: + Int: The number of bytes read into dest. + """ + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read_at(span, off, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + + @always_inline + fn read_byte(inout self) -> (UInt8, Error): """Reads and returns a single byte from the internal buffer. Implements the [io.ByteReader] Interface.""" self.prev_rune = -1 if self.index >= self.size: @@ -221,7 +261,7 @@ struct Reader( return write_count, Error() @always_inline - fn reset(inout self, owned buffer: List[Byte]): + fn reset(inout self, owned buffer: List[UInt8]): """Resets the [Reader.Reader] to be reading from buffer. Args: @@ -234,7 +274,7 @@ struct Reader( self.prev_rune = -1 -fn new_reader(owned buffer: List[Byte]) -> Reader: +fn new_reader(owned buffer: List[UInt8]) -> Reader: """Returns a new [Reader.Reader] reading from b. Args: diff --git a/gojo/io/__init__.mojo b/gojo/io/__init__.mojo index 7738bd3..216bb1e 100644 --- a/gojo/io/__init__.mojo +++ b/gojo/io/__init__.mojo @@ -1,36 +1,328 @@ -from .traits import ( - Reader, - Writer, - Seeker, - Closer, - ReadWriter, - ReadCloser, - WriteCloser, - ReadWriteCloser, - ReadSeeker, - ReadSeekCloser, - WriteSeeker, - ReadWriteSeeker, - ReaderFrom, - WriterReadFrom, - WriterTo, - ReaderWriteTo, - ReaderAt, - WriterAt, - ByteReader, - ByteScanner, - ByteWriter, - RuneReader, - RuneScanner, - StringWriter, - SEEK_START, - SEEK_CURRENT, - SEEK_END, - ERR_SHORT_WRITE, - ERR_NO_PROGRESS, - ERR_SHORT_BUFFER, - EOF, -) from .io import write_string, read_at_least, read_full, read_all, BUFFER_SIZE from .file import FileWrapper from .std import STDWriter + + +alias Rune = Int32 + +# Package io provides basic interfaces to I/O primitives. +# Its primary job is to wrap existing implementations of such primitives, +# such as those in package os, into shared public interfaces that +# abstract the fntionality, plus some other related primitives. +# +# Because these interfaces and primitives wrap lower-level operations with +# various implementations, unless otherwise informed clients should not +# assume they are safe for parallel execution. +# Seek whence values. +alias SEEK_START = 0 # seek relative to the origin of the file +alias SEEK_CURRENT = 1 # seek relative to the current offset +alias SEEK_END = 2 # seek relative to the end + +# ERR_SHORT_WRITE means that a write accepted fewer bytes than requested +# but failed to return an explicit error. +alias ERR_SHORT_WRITE = "short write" + +# ERR_INVALID_WRITE means that a write returned an impossible count. +alias ERR_INVALID_WRITE = "invalid write result" + +# ERR_SHORT_BUFFER means that a read required a longer buffer than was provided. +alias ERR_SHORT_BUFFER = "short buffer" + +# EOF is the error returned by Read when no more input is available. +# (Read must return EOF itself, not an error wrapping EOF, +# because callers will test for EOF using ==.) +# fntions should return EOF only to signal a graceful end of input. +# If the EOF occurs unexpectedly in a structured data stream, +# the appropriate error is either [ERR_UNEXPECTED_EOF] or some other error +# giving more detail. +alias EOF = "EOF" + +# ERR_UNEXPECTED_EOF means that EOF was encountered in the +# middle of reading a fixed-size block or data structure. +alias ERR_UNEXPECTED_EOF = "unexpected EOF" + +# ERR_NO_PROGRESS is returned by some clients of a [Reader] when +# many calls to Read have failed to return any data or error, +# usually the sign of a broken [Reader] implementation. +alias ERR_NO_PROGRESS = "multiple Read calls return no data or error" + + +trait Reader(Movable): + """Reader is the trait that wraps the basic Read method. + + Read reads up to len(p) bytes into p. It returns the number of bytes + read (0 <= n <= len(p)) and any error encountered. Even if Read + returns n < len(p), it may use all of p as scratch space during the call. + If some data is available but not len(p) bytes, Read conventionally + returns what is available instead of waiting for more. + + When Read encounters an error or end-of-file condition after + successfully reading n > 0 bytes, it returns the number of + bytes read. It may return the (non-nil) error from the same call + or return the error (and n == 0) from a subsequent call. + An instance of this general case is that a Reader returning + a non-zero number of bytes at the end of the input stream may + return either err == EOF or err == nil. The next Read should + return 0, EOF. + + Callers should always process the n > 0 bytes returned before + considering the error err. Doing so correctly handles I/O errors + that happen after reading some bytes and also both of the + allowed EOF behaviors. + + If len(p) == 0, Read should always return n == 0. It may return a + non-nil error if some error condition is known, such as EOF. + + Implementations of Read are discouraged from returning a + zero byte count with a nil error, except when len(p) == 0. + Callers should treat a return of 0 and nil as indicating that + nothing happened; in particular it does not indicate EOF. + + Implementations must not retain p.""" + + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + ... + + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): + ... + + +trait Writer(Movable): + """Writer is the trait that wraps the basic Write method. + + Write writes len(p) bytes from p to the underlying data stream. + It returns the number of bytes written from p (0 <= n <= len(p)) + and any error encountered that caused the write to stop early. + Write must return a non-nil error if it returns n < len(p). + Write must not modify the slice data, even temporarily. + + Implementations must not retain p. + """ + + fn write(inout self, src: List[UInt8]) -> (Int, Error): + ... + + +trait Closer(Movable): + """ + Closer is the trait that wraps the basic Close method. + + The behavior of Close after the first call is undefined. + Specific implementations may document their own behavior. + """ + + fn close(inout self) -> Error: + ... + + +trait Seeker(Movable): + """ + Seeker is the trait that wraps the basic Seek method. + + Seek sets the offset for the next Read or Write to offset, + interpreted according to whence: + [SEEK_START] means relative to the start of the file, + [SEEK_CURRENT] means relative to the current offset, and + [SEEK_END] means relative to the end + (for example, offset = -2 specifies the penultimate byte of the file). + Seek returns the new offset relative to the start of the + file or an error, if any. + + Seeking to an offset before the start of the file is an error. + Seeking to any positive offset may be allowed, but if the new offset exceeds + the size of the underlying object the behavior of subsequent I/O operations + is implementation-dependent. + """ + + fn seek(inout self, offset: Int, whence: Int) -> (Int, Error): + ... + + +trait ReadWriter(Reader, Writer): + ... + + +trait ReadCloser(Reader, Closer): + ... + + +trait WriteCloser(Writer, Closer): + ... + + +trait ReadWriteCloser(Reader, Writer, Closer): + ... + + +trait ReadSeeker(Reader, Seeker): + ... + + +trait ReadSeekCloser(Reader, Seeker, Closer): + ... + + +trait WriteSeeker(Writer, Seeker): + ... + + +trait ReadWriteSeeker(Reader, Writer, Seeker): + ... + + +trait ReaderFrom: + """ReaderFrom is the trait that wraps the ReadFrom method. + + ReadFrom reads data from r until EOF or error. + The return value n is the number of bytes read. + Any error except EOF encountered during the read is also returned. + + The [copy] function uses [ReaderFrom] if available.""" + + fn read_from[R: Reader](inout self, inout reader: R) -> (Int, Error): + ... + + +trait WriterReadFrom(Writer, ReaderFrom): + ... + + +trait WriterTo: + """WriterTo is the trait that wraps the WriteTo method. + + WriteTo writes data to w until there's no more data to write or + when an error occurs. The return value n is the number of bytes + written. Any error encountered during the write is also returned. + + The copy function uses WriterTo if available.""" + + fn write_to[W: Writer](inout self, inout writer: W) -> (Int, Error): + ... + + +trait ReaderWriteTo(Reader, WriterTo): + ... + + +trait ReaderAt: + """ReaderAt is the trait that wraps the basic ReadAt method. + + ReadAt reads len(p) bytes into p starting at offset off in the + underlying input source. It returns the number of bytes + read (0 <= n <= len(p)) and any error encountered. + + When ReadAt returns n < len(p), it returns a non-nil error + explaining why more bytes were not returned. In this respect, + ReadAt is stricter than Read. + + Even if ReadAt returns n < len(p), it may use all of p as scratch + space during the call. If some data is available but not len(p) bytes, + ReadAt blocks until either all the data is available or an error occurs. + In this respect ReadAt is different from Read. + + If the n = len(p) bytes returned by ReadAt are at the end of the + input source, ReadAt may return either err == EOF or err == nil. + + If ReadAt is reading from an input source with a seek offset, + ReadAt should not affect nor be affected by the underlying + seek offset. + + Clients of ReadAt can execute parallel ReadAt calls on the + same input source. + + Implementations must not retain p.""" + + fn read_at(self, inout dest: List[UInt8], off: Int) -> (Int, Error): + ... + + fn _read_at(self, inout dest: Span[UInt8, True], off: Int, capacity: Int) -> (Int, Error): + ... + + +trait WriterAt: + """WriterAt is the trait that wraps the basic WriteAt method. + + WriteAt writes len(p) bytes from p to the underlying data stream + at offset off. It returns the number of bytes written from p (0 <= n <= len(p)) + and any error encountered that caused the write to stop early. + WriteAt must return a non-nil error if it returns n < len(p). + + If WriteAt is writing to a destination with a seek offset, + WriteAt should not affect nor be affected by the underlying + seek offset. + + Clients of WriteAt can execute parallel WriteAt calls on the same + destination if the ranges do not overlap. + + Implementations must not retain p.""" + + fn write_at(self, src: Span[UInt8], off: Int) -> (Int, Error): + ... + + +trait ByteReader: + """ByteReader is the trait that wraps the read_byte method. + + read_byte reads and returns the next byte from the input or + any error encountered. If read_byte returns an error, no input + byte was consumed, and the returned byte value is undefined. + + read_byte provides an efficient trait for byte-at-time + processing. A [Reader] that does not implement ByteReader + can be wrapped using bufio.NewReader to add this method.""" + + fn read_byte(inout self) -> (UInt8, Error): + ... + + +trait ByteScanner(ByteReader): + """ByteScanner is the trait that adds the unread_byte method to the + basic read_byte method. + + unread_byte causes the next call to read_byte to return the last byte read. + If the last operation was not a successful call to read_byte, unread_byte may + return an error, unread the last byte read (or the byte prior to the + last-unread byte), or (in implementations that support the [Seeker] trait) + seek to one byte before the current offset.""" + + fn unread_byte(inout self) -> Error: + ... + + +trait ByteWriter: + """ByteWriter is the trait that wraps the write_byte method.""" + + fn write_byte(inout self, byte: UInt8) -> (Int, Error): + ... + + +trait RuneReader: + """RuneReader is the trait that wraps the read_rune method. + + read_rune reads a single encoded Unicode character + and returns the rune and its size in bytes. If no character is + available, err will be set.""" + + fn read_rune(inout self) -> (Rune, Int): + ... + + +trait RuneScanner(RuneReader): + """RuneScanner is the trait that adds the unread_rune method to the + basic read_rune method. + + unread_rune causes the next call to read_rune to return the last rune read. + If the last operation was not a successful call to read_rune, unread_rune may + return an error, unread the last rune read (or the rune prior to the + last-unread rune), or (in implementations that support the [Seeker] trait) + seek to the start of the rune before the current offset.""" + + fn unread_rune(inout self) -> Rune: + ... + + +trait StringWriter: + """StringWriter is the trait that wraps the WriteString method.""" + + fn write_string(inout self, src: String) -> (Int, Error): + ... diff --git a/gojo/io/file.mojo b/gojo/io/file.mojo index 445f088..a904b3c 100644 --- a/gojo/io/file.mojo +++ b/gojo/io/file.mojo @@ -30,13 +30,37 @@ struct FileWrapper(FileDescriptorBase, io.ByteReader): return Error() + @always_inline + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): + """Read from the file handle into dest's pointer. + Pretty hacky way to force the filehandle read into the defined trait, and it's unsafe since we're + reading directly into the pointer. + """ + # var bytes_to_read = dest.capacity - len(dest) + var bytes_read: Int + var result: List[UInt8] + try: + result = self.handle.read_bytes() + bytes_read = len(result) + # TODO: Need to raise an Issue for this. Reading with pointer does not return an accurate count of bytes_read :( + # bytes_read = int(self.handle.read(DTypePointer[DType.uint8](dest.unsafe_ptr()) + dest.size)) + except e: + return 0, e + + _ = copy(dest, Span(result), len(dest)) + + if bytes_read == 0: + return bytes_read, Error(io.EOF) + + return bytes_read, Error() + @always_inline fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): """Read from the file handle into dest's pointer. Pretty hacky way to force the filehandle read into the defined trait, and it's unsafe since we're reading directly into the pointer. """ - var bytes_to_read = dest.capacity - len(dest) + # var bytes_to_read = dest.capacity - len(dest) var bytes_read: Int var result: List[UInt8] try: @@ -47,9 +71,9 @@ struct FileWrapper(FileDescriptorBase, io.ByteReader): except e: return 0, e - _ = copy(dest, result, dest.size) + _ = copy(dest, result, len(dest)) - if bytes_read == 0 or bytes_read < bytes_to_read: + if bytes_read == 0: return bytes_read, Error(io.EOF) return bytes_read, Error() diff --git a/gojo/io/io.mojo b/gojo/io/io.mojo index c79982f..064b06d 100644 --- a/gojo/io/io.mojo +++ b/gojo/io/io.mojo @@ -1,5 +1,4 @@ from ..builtins import copy, Byte, panic -from .traits import ERR_UNEXPECTED_EOF alias BUFFER_SIZE = 4096 @@ -405,6 +404,7 @@ fn read_full[R: Reader](inout reader: R, inout dest: List[Byte]) -> (Int, Error) # } +# TODO: read directly into dest fn read_all[R: Reader](inout reader: R) -> (List[Byte], Error): """Reads from r until an error or EOF and returns the data it read. A successful call returns err == nil, not err == EOF. Because ReadAll is diff --git a/gojo/io/traits.mojo b/gojo/io/traits.mojo deleted file mode 100644 index d499545..0000000 --- a/gojo/io/traits.mojo +++ /dev/null @@ -1,320 +0,0 @@ -from collections.optional import Optional -from ..builtins import Byte - -alias Rune = Int32 - -# Package io provides basic interfaces to I/O primitives. -# Its primary job is to wrap existing implementations of such primitives, -# such as those in package os, into shared public interfaces that -# abstract the fntionality, plus some other related primitives. -# -# Because these interfaces and primitives wrap lower-level operations with -# various implementations, unless otherwise informed clients should not -# assume they are safe for parallel execution. -# Seek whence values. -alias SEEK_START = 0 # seek relative to the origin of the file -alias SEEK_CURRENT = 1 # seek relative to the current offset -alias SEEK_END = 2 # seek relative to the end - -# ERR_SHORT_WRITE means that a write accepted fewer bytes than requested -# but failed to return an explicit error. -alias ERR_SHORT_WRITE = "short write" - -# ERR_INVALID_WRITE means that a write returned an impossible count. -alias ERR_INVALID_WRITE = "invalid write result" - -# ERR_SHORT_BUFFER means that a read required a longer buffer than was provided. -alias ERR_SHORT_BUFFER = "short buffer" - -# EOF is the error returned by Read when no more input is available. -# (Read must return EOF itself, not an error wrapping EOF, -# because callers will test for EOF using ==.) -# fntions should return EOF only to signal a graceful end of input. -# If the EOF occurs unexpectedly in a structured data stream, -# the appropriate error is either [ERR_UNEXPECTED_EOF] or some other error -# giving more detail. -alias EOF = "EOF" - -# ERR_UNEXPECTED_EOF means that EOF was encountered in the -# middle of reading a fixed-size block or data structure. -alias ERR_UNEXPECTED_EOF = "unexpected EOF" - -# ERR_NO_PROGRESS is returned by some clients of a [Reader] when -# many calls to Read have failed to return any data or error, -# usually the sign of a broken [Reader] implementation. -alias ERR_NO_PROGRESS = "multiple Read calls return no data or error" - - -trait Reader(Movable): - """Reader is the trait that wraps the basic Read method. - - Read reads up to len(p) bytes into p. It returns the number of bytes - read (0 <= n <= len(p)) and any error encountered. Even if Read - returns n < len(p), it may use all of p as scratch space during the call. - If some data is available but not len(p) bytes, Read conventionally - returns what is available instead of waiting for more. - - When Read encounters an error or end-of-file condition after - successfully reading n > 0 bytes, it returns the number of - bytes read. It may return the (non-nil) error from the same call - or return the error (and n == 0) from a subsequent call. - An instance of this general case is that a Reader returning - a non-zero number of bytes at the end of the input stream may - return either err == EOF or err == nil. The next Read should - return 0, EOF. - - Callers should always process the n > 0 bytes returned before - considering the error err. Doing so correctly handles I/O errors - that happen after reading some bytes and also both of the - allowed EOF behaviors. - - If len(p) == 0, Read should always return n == 0. It may return a - non-nil error if some error condition is known, such as EOF. - - Implementations of Read are discouraged from returning a - zero byte count with a nil error, except when len(p) == 0. - Callers should treat a return of 0 and nil as indicating that - nothing happened; in particular it does not indicate EOF. - - Implementations must not retain p.""" - - fn read(inout self, inout dest: List[Byte]) -> (Int, Error): - ... - - -trait Writer(Movable): - """Writer is the trait that wraps the basic Write method. - - Write writes len(p) bytes from p to the underlying data stream. - It returns the number of bytes written from p (0 <= n <= len(p)) - and any error encountered that caused the write to stop early. - Write must return a non-nil error if it returns n < len(p). - Write must not modify the slice data, even temporarily. - - Implementations must not retain p. - """ - - fn write(inout self, src: List[Byte]) -> (Int, Error): - ... - - -trait Closer(Movable): - """ - Closer is the trait that wraps the basic Close method. - - The behavior of Close after the first call is undefined. - Specific implementations may document their own behavior. - """ - - fn close(inout self) -> Error: - ... - - -trait Seeker(Movable): - """ - Seeker is the trait that wraps the basic Seek method. - - Seek sets the offset for the next Read or Write to offset, - interpreted according to whence: - [SEEK_START] means relative to the start of the file, - [SEEK_CURRENT] means relative to the current offset, and - [SEEK_END] means relative to the end - (for example, offset = -2 specifies the penultimate byte of the file). - Seek returns the new offset relative to the start of the - file or an error, if any. - - Seeking to an offset before the start of the file is an error. - Seeking to any positive offset may be allowed, but if the new offset exceeds - the size of the underlying object the behavior of subsequent I/O operations - is implementation-dependent. - """ - - fn seek(inout self, offset: Int, whence: Int) -> (Int, Error): - ... - - -trait ReadWriter(Reader, Writer): - ... - - -trait ReadCloser(Reader, Closer): - ... - - -trait WriteCloser(Writer, Closer): - ... - - -trait ReadWriteCloser(Reader, Writer, Closer): - ... - - -trait ReadSeeker(Reader, Seeker): - ... - - -trait ReadSeekCloser(Reader, Seeker, Closer): - ... - - -trait WriteSeeker(Writer, Seeker): - ... - - -trait ReadWriteSeeker(Reader, Writer, Seeker): - ... - - -trait ReaderFrom: - """ReaderFrom is the trait that wraps the ReadFrom method. - - ReadFrom reads data from r until EOF or error. - The return value n is the number of bytes read. - Any error except EOF encountered during the read is also returned. - - The [copy] function uses [ReaderFrom] if available.""" - - fn read_from[R: Reader](inout self, inout reader: R) -> (Int, Error): - ... - - -trait WriterReadFrom(Writer, ReaderFrom): - ... - - -trait WriterTo: - """WriterTo is the trait that wraps the WriteTo method. - - WriteTo writes data to w until there's no more data to write or - when an error occurs. The return value n is the number of bytes - written. Any error encountered during the write is also returned. - - The copy function uses WriterTo if available.""" - - fn write_to[W: Writer](inout self, inout writer: W) -> (Int, Error): - ... - - -trait ReaderWriteTo(Reader, WriterTo): - ... - - -trait ReaderAt: - """ReaderAt is the trait that wraps the basic ReadAt method. - - ReadAt reads len(p) bytes into p starting at offset off in the - underlying input source. It returns the number of bytes - read (0 <= n <= len(p)) and any error encountered. - - When ReadAt returns n < len(p), it returns a non-nil error - explaining why more bytes were not returned. In this respect, - ReadAt is stricter than Read. - - Even if ReadAt returns n < len(p), it may use all of p as scratch - space during the call. If some data is available but not len(p) bytes, - ReadAt blocks until either all the data is available or an error occurs. - In this respect ReadAt is different from Read. - - If the n = len(p) bytes returned by ReadAt are at the end of the - input source, ReadAt may return either err == EOF or err == nil. - - If ReadAt is reading from an input source with a seek offset, - ReadAt should not affect nor be affected by the underlying - seek offset. - - Clients of ReadAt can execute parallel ReadAt calls on the - same input source. - - Implementations must not retain p.""" - - fn read_at(self, inout dest: List[Byte], off: Int) -> (Int, Error): - ... - - -trait WriterAt: - """WriterAt is the trait that wraps the basic WriteAt method. - - WriteAt writes len(p) bytes from p to the underlying data stream - at offset off. It returns the number of bytes written from p (0 <= n <= len(p)) - and any error encountered that caused the write to stop early. - WriteAt must return a non-nil error if it returns n < len(p). - - If WriteAt is writing to a destination with a seek offset, - WriteAt should not affect nor be affected by the underlying - seek offset. - - Clients of WriteAt can execute parallel WriteAt calls on the same - destination if the ranges do not overlap. - - Implementations must not retain p.""" - - fn write_at(self, src: Span[Byte], off: Int) -> (Int, Error): - ... - - -trait ByteReader: - """ByteReader is the trait that wraps the read_byte method. - - read_byte reads and returns the next byte from the input or - any error encountered. If read_byte returns an error, no input - byte was consumed, and the returned byte value is undefined. - - read_byte provides an efficient trait for byte-at-time - processing. A [Reader] that does not implement ByteReader - can be wrapped using bufio.NewReader to add this method.""" - - fn read_byte(inout self) -> (Byte, Error): - ... - - -trait ByteScanner(ByteReader): - """ByteScanner is the trait that adds the unread_byte method to the - basic read_byte method. - - unread_byte causes the next call to read_byte to return the last byte read. - If the last operation was not a successful call to read_byte, unread_byte may - return an error, unread the last byte read (or the byte prior to the - last-unread byte), or (in implementations that support the [Seeker] trait) - seek to one byte before the current offset.""" - - fn unread_byte(inout self) -> Error: - ... - - -trait ByteWriter: - """ByteWriter is the trait that wraps the write_byte method.""" - - fn write_byte(inout self, byte: Byte) -> (Int, Error): - ... - - -trait RuneReader: - """RuneReader is the trait that wraps the read_rune method. - - read_rune reads a single encoded Unicode character - and returns the rune and its size in bytes. If no character is - available, err will be set.""" - - fn read_rune(inout self) -> (Rune, Int): - ... - - -trait RuneScanner(RuneReader): - """RuneScanner is the trait that adds the unread_rune method to the - basic read_rune method. - - unread_rune causes the next call to read_rune to return the last rune read. - If the last operation was not a successful call to read_rune, unread_rune may - return an error, unread the last rune read (or the rune prior to the - last-unread rune), or (in implementations that support the [Seeker] trait) - seek to the start of the rune before the current offset.""" - - fn unread_rune(inout self) -> Rune: - ... - - -trait StringWriter: - """StringWriter is the trait that wraps the WriteString method.""" - - fn write_string(inout self, src: String) -> (Int, Error): - ... diff --git a/gojo/net/__init__.mojo b/gojo/net/__init__.mojo index 138d235..12b01d4 100644 --- a/gojo/net/__init__.mojo +++ b/gojo/net/__init__.mojo @@ -5,8 +5,13 @@ A good chunk of the leg work here came from the lightbug_http project! https://g from .fd import FileDescriptor from .socket import Socket -from .tcp import TCPConnection, TCPListener, listen_tcp -from .address import TCPAddr, NetworkType, Addr +from .tcp import TCPConnection, TCPListener, listen_tcp, dial_tcp, TCPAddr +from .udp import UDPAddr, UDPConnection, listen_udp, dial_udp +from .address import NetworkType, Addr, HostPort from .ip import get_ip_address, get_addr_info -from .dial import dial_tcp, Dialer -from .net import Connection, Conn + + +# Time in nanoseconds +alias Duration = Int +alias DEFAULT_BUFFER_SIZE = 4096 +alias DEFAULT_TCP_KEEP_ALIVE = Duration(15 * 1000 * 1000 * 1000) # 15 seconds diff --git a/gojo/net/address.mojo b/gojo/net/address.mojo index 9bf5a50..9278d9c 100644 --- a/gojo/net/address.mojo +++ b/gojo/net/address.mojo @@ -22,7 +22,7 @@ trait Addr(CollectionElement, Stringable): @value -struct TCPAddr(Addr): +struct BaseAddr: """Addr struct representing a TCP address. Args: @@ -35,29 +35,32 @@ struct TCPAddr(Addr): var port: Int var zone: String # IPv6 addressing zone - fn __init__(inout self): - self.ip = String("127.0.0.1") - self.port = 8000 - self.zone = "" - - fn __init__(inout self, ip: String, port: Int): + fn __init__(inout self, ip: String = "", port: Int = 0, zone: String = ""): self.ip = ip self.port = port - self.zone = "" + self.zone = zone + + fn __init__(inout self, other: TCPAddr): + self.ip = other.ip + self.port = other.port + self.zone = other.zone + + fn __init__(inout self, other: UDPAddr): + self.ip = other.ip + self.port = other.port + self.zone = other.zone fn __str__(self) -> String: if self.zone != "": - return join_host_port(str(self.ip) + "%" + self.zone, str(self.port)) + return join_host_port(self.ip + "%" + self.zone, str(self.port)) return join_host_port(self.ip, str(self.port)) - fn network(self) -> String: - return NetworkType.tcp.value - -fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: +fn resolve_internet_addr(network: String, address: String) -> (TCPAddr, Error): var host: String = "" var port: String = "" var portnum: Int = 0 + var err = Error() if ( network == NetworkType.tcp.value or network == NetworkType.tcp4.value @@ -67,29 +70,33 @@ fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: or network == NetworkType.udp6.value ): if address != "": - var host_port = split_host_port(address) - host = host_port.host - port = str(host_port.port) - portnum = atol(port.__str__()) + var result = split_host_port(address) + if result[1]: + return TCPAddr(), result[1] + + host = result[0].host + port = str(result[0].port) + portnum = result[0].port elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value: if address != "": host = address elif network == NetworkType.unix.value: - raise Error("Unix addresses not supported yet") + return TCPAddr(), Error("Unix addresses not supported yet") else: - raise Error("unsupported network type: " + network) - return TCPAddr(host, portnum) + return TCPAddr(), Error("unsupported network type: " + network) + return TCPAddr(host, portnum), err -alias missingPortError = Error("missing port in address") -alias tooManyColonsError = Error("too many colons in address") +alias MISSING_PORT_ERROR = Error("missing port in address") +alias TOO_MANY_COLONS_ERROR = Error("too many colons in address") +@value struct HostPort(Stringable): var host: String var port: Int - fn __init__(inout self, host: String, port: Int): + fn __init__(inout self, host: String = "", port: Int = 0): self.host = host self.port = port @@ -103,7 +110,7 @@ fn join_host_port(host: String, port: String) -> String: return host + ":" + port -fn split_host_port(hostport: String) raises -> HostPort: +fn split_host_port(hostport: String) -> (HostPort, Error): var host: String = "" var port: String = "" var colon_index = hostport.rfind(":") @@ -111,35 +118,38 @@ fn split_host_port(hostport: String) raises -> HostPort: var k: Int = 0 if colon_index == -1: - raise missingPortError + return HostPort(), MISSING_PORT_ERROR if hostport[0] == "[": var end_bracket_index = hostport.find("]") if end_bracket_index == -1: - raise Error("missing ']' in address") + return HostPort(), Error("missing ']' in address") if end_bracket_index + 1 == len(hostport): - raise missingPortError + return HostPort(), MISSING_PORT_ERROR elif end_bracket_index + 1 == colon_index: host = hostport[1:end_bracket_index] j = 1 k = end_bracket_index + 1 else: if hostport[end_bracket_index + 1] == ":": - raise tooManyColonsError + return HostPort(), TOO_MANY_COLONS_ERROR else: - raise missingPortError + return HostPort(), MISSING_PORT_ERROR else: host = hostport[:colon_index] if host.find(":") != -1: - raise tooManyColonsError + return HostPort(), TOO_MANY_COLONS_ERROR if hostport[j:].find("[") != -1: - raise Error("unexpected '[' in address") + return HostPort(), Error("unexpected '[' in address") if hostport[k:].find("]") != -1: - raise Error("unexpected ']' in address") + return HostPort(), Error("unexpected ']' in address") port = hostport[colon_index + 1 :] if port == "": - raise missingPortError + return HostPort(), MISSING_PORT_ERROR if host == "": - raise Error("missing host") + return HostPort(), Error("missing host") - return HostPort(host, atol(port)) + try: + return HostPort(host, atol(port)), Error() + except e: + return HostPort(), e diff --git a/gojo/net/dial.mojo b/gojo/net/dial.mojo deleted file mode 100644 index 98442af..0000000 --- a/gojo/net/dial.mojo +++ /dev/null @@ -1,45 +0,0 @@ -from .tcp import TCPAddr, TCPConnection, resolve_internet_addr -from .socket import Socket -from .address import split_host_port - - -@value -struct Dialer: - var local_address: TCPAddr - - @always_inline - fn dial(self, network: String, address: String) raises -> TCPConnection: - var tcp_addr = resolve_internet_addr(network, address) - var socket = Socket(local_address=self.local_address) - socket.connect(tcp_addr.ip, tcp_addr.port) - return TCPConnection(socket^) - - -fn dial_tcp(network: String, remote_address: TCPAddr) raises -> TCPConnection: - """Connects to the address on the named network. - - The network must be "tcp", "tcp4", or "tcp6". - Args: - network: The network type. - remote_address: The remote address to connect to. - - Returns: - The TCP connection. - """ - # TODO: Add conversion of domain name to ip address - return Dialer(remote_address).dial(network, remote_address.ip + ":" + str(remote_address.port)) - - -fn dial_tcp(network: String, remote_address: String) raises -> TCPConnection: - """Connects to the address on the named network. - - The network must be "tcp", "tcp4", or "tcp6". - Args: - network: The network type. - remote_address: The remote address to connect to. - - Returns: - The TCP connection. - """ - var address = split_host_port(remote_address) - return Dialer(TCPAddr(address.host, address.port)).dial(network, remote_address) diff --git a/gojo/net/fd.mojo b/gojo/net/fd.mojo index 7484699..a9d9d4c 100644 --- a/gojo/net/fd.mojo +++ b/gojo/net/fd.mojo @@ -41,23 +41,35 @@ struct FileDescriptor(FileDescriptorBase): return Error() @always_inline - fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): """Receive data from the file descriptor and write it to the buffer provided.""" var bytes_received = recv( self.fd, - dest.unsafe_ptr() + dest.size, - dest.capacity - dest.size, + dest.unsafe_ptr() + len(dest), + capacity - len(dest), 0, ) + if bytes_received == 0: + return bytes_received, Error(io.EOF) + if bytes_received == -1: return 0, Error("Failed to receive message from socket.") - dest.size += bytes_received - - if bytes_received < dest.capacity: - return bytes_received, Error(io.EOF) + dest._len += bytes_received return bytes_received, Error() + @always_inline + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + """Receive data from the file descriptor and write it to the buffer provided.""" + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + @always_inline fn write(inout self, src: List[UInt8]) -> (Int, Error): """Write data from the buffer to the file descriptor.""" diff --git a/gojo/net/ip.mojo b/gojo/net/ip.mojo index b3ab46c..4af7748 100644 --- a/gojo/net/ip.mojo +++ b/gojo/net/ip.mojo @@ -23,6 +23,7 @@ from ..syscall import ( getaddrinfo_unix, gai_strerror, ) +from .address import HostPort alias AddrInfo = Variant[addrinfo, addrinfo_unix] @@ -159,3 +160,27 @@ fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) -> var ai = sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8]()) return UnsafePointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() + + +fn convert_sockaddr_to_host_port(sockaddr: UnsafePointer[sockaddr]) -> (HostPort, Error): + """Casts a sockaddr pointer to a sockaddr_in pointer and converts the binary IP and port to a string and int respectively. + + Args: + sockaddr: The sockaddr pointer to convert. + + Returns: + A tuple containing the HostPort and an Error if any occurred,. + """ + if not sockaddr: + return HostPort(), Error("sockaddr is null, nothing to convert.") + + # Cast sockaddr struct to sockaddr_in to convert binary IP to string. + var addr_in = move_from_pointee(sockaddr.bitcast[sockaddr_in]()) + + return ( + HostPort( + host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AddressFamily.AF_INET, 16), + port=convert_binary_port_to_int(addr_in.sin_port), + ), + Error(), + ) diff --git a/gojo/net/net.mojo b/gojo/net/net.mojo deleted file mode 100644 index f4c4d21..0000000 --- a/gojo/net/net.mojo +++ /dev/null @@ -1,124 +0,0 @@ -import ..io -from .socket import Socket -from .address import Addr, TCPAddr - -alias DEFAULT_BUFFER_SIZE = 4096 - - -trait Conn(io.Writer, io.Reader, io.Closer): - fn __init__(inout self, owned socket: Socket): - ... - - """Conn is a generic stream-oriented network connection.""" - - fn local_address(self) -> TCPAddr: - """Returns the local network address, if known.""" - ... - - fn remote_address(self) -> TCPAddr: - """Returns the local network address, if known.""" - ... - - # fn set_deadline(self, t: time.Time) -> Error: - # """Sets the read and write deadlines associated - # with the connection. It is equivalent to calling both - # SetReadDeadline and SetWriteDeadline. - - # A deadline is an absolute time after which I/O operations - # fail instead of blocking. The deadline applies to all future - # and pending I/O, not just the immediately following call to - # read or write. After a deadline has been exceeded, the - # connection can be refreshed by setting a deadline in the future. - - # If the deadline is exceeded a call to read or write or to other - # I/O methods will return an error that wraps os.ErrDeadlineExceeded. - # This can be tested using errors.Is(err, os.ErrDeadlineExceeded). - # The error's Timeout method will return true, but note that there - # are other possible errors for which the Timeout method will - # return true even if the deadline has not been exceeded. - - # An idle timeout can be implemented by repeatedly extending - # the deadline after successful read or write calls. - - # A zero value for t means I/O operations will not time out.""" - # ... - - # fn set_read_deadline(self, t: time.Time) -> Error: - # """Sets the deadline for future read calls - # and any currently-blocked read call. - # A zero value for t means read will not time out.""" - # ... - - # fn set_write_deadline(self, t: time.Time) -> Error: - # """Sets the deadline for future write calls - # and any currently-blocked write call. - # Even if write times out, it may return n > 0, indicating that - # some of the data was successfully written. - # A zero value for t means write will not time out.""" - # ... - - -struct Connection(Conn): - """Connection is a concrete generic stream-oriented network connection. - It is used as the internal connection for structs like TCPConnection. - - Args: - fd: The file descriptor of the connection. - """ - - var fd: Socket - - @always_inline - fn __init__(inout self, owned socket: Socket): - self.fd = socket^ - - @always_inline - fn __moveinit__(inout self, owned existing: Self): - self.fd = existing.fd^ - - @always_inline - fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): - """Reads data from the underlying file descriptor. - - Args: - dest: The buffer to read data into. - - Returns: - The number of bytes read, or an error if one occurred. - """ - return self.fd.read(dest) - - @always_inline - fn write(inout self, src: List[UInt8]) -> (Int, Error): - """Writes data to the underlying file descriptor. - - Args: - src: The buffer to read data into. - - Returns: - The number of bytes written, or an error if one occurred. - """ - return self.fd.write(src) - - @always_inline - fn close(inout self) -> Error: - """Closes the underlying file descriptor. - - Returns: - An error if one occurred, or None if the file descriptor was closed successfully. - """ - return self.fd.close() - - @always_inline - fn local_address(self) -> TCPAddr: - """Returns the local network address. - The Addr returned is shared by all invocations of local_address, so do not modify it. - """ - return self.fd.local_address - - @always_inline - fn remote_address(self) -> TCPAddr: - """Returns the remote network address. - The Addr returned is shared by all invocations of remote_address, so do not modify it. - """ - return self.fd.remote_address diff --git a/gojo/net/socket.mojo b/gojo/net/socket.mojo index cceb388..cafc451 100644 --- a/gojo/net/socket.mojo +++ b/gojo/net/socket.mojo @@ -11,7 +11,9 @@ from ..syscall import ( socket, connect, recv, + recvfrom, send, + sendto, shutdown, inet_pton, inet_ntoa, @@ -41,8 +43,9 @@ from .ip import ( convert_binary_ip_to_string, build_sockaddr_pointer, convert_binary_port_to_int, + convert_sockaddr_to_host_port, ) -from .address import Addr, TCPAddr, HostPort +from .address import Addr, BaseAddr, HostPort alias SocketClosedError = Error("Socket: Socket is already closed") @@ -58,21 +61,21 @@ struct Socket(FileDescriptorBase): protocol: The protocol. """ - var sockfd: FileDescriptor + var fd: FileDescriptor var address_family: Int - var socket_type: UInt8 + var socket_type: Int32 var protocol: UInt8 - var local_address: TCPAddr - var remote_address: TCPAddr + var local_address: BaseAddr + var remote_address: BaseAddr var _closed: Bool var _is_connected: Bool fn __init__( inout self, - local_address: TCPAddr = TCPAddr(), - remote_address: TCPAddr = TCPAddr(), + local_address: BaseAddr = BaseAddr(), + remote_address: BaseAddr = BaseAddr(), address_family: Int = AddressFamily.AF_INET, - socket_type: UInt8 = SocketType.SOCK_STREAM, + socket_type: Int32 = SocketType.SOCK_STREAM, protocol: UInt8 = 0, ) raises: """Create a new socket object. @@ -88,10 +91,10 @@ struct Socket(FileDescriptorBase): self.socket_type = socket_type self.protocol = protocol - var fd = socket(address_family, SocketType.SOCK_STREAM, 0) + var fd = socket(address_family, socket_type, 0) if fd == -1: raise Error("Socket creation error") - self.sockfd = FileDescriptor(int(fd)) + self.fd = FileDescriptor(int(fd)) self.local_address = local_address self.remote_address = remote_address self._closed = False @@ -101,10 +104,10 @@ struct Socket(FileDescriptorBase): inout self, fd: Int32, address_family: Int, - socket_type: UInt8, + socket_type: Int32, protocol: UInt8, - local_address: TCPAddr = TCPAddr(), - remote_address: TCPAddr = TCPAddr(), + local_address: BaseAddr = BaseAddr(), + remote_address: BaseAddr = BaseAddr(), ): """ Create a new socket object when you already have a socket file descriptor. Typically through socket.accept(). @@ -114,10 +117,10 @@ struct Socket(FileDescriptorBase): address_family: The address family of the socket. socket_type: The socket type. protocol: The protocol. - local_address: Local address of socket. - remote_address: Remote address of port. + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). """ - self.sockfd = FileDescriptor(int(fd)) + self.fd = FileDescriptor(int(fd)) self.address_family = address_family self.socket_type = socket_type self.protocol = protocol @@ -127,7 +130,7 @@ struct Socket(FileDescriptorBase): self._is_connected = True fn __moveinit__(inout self, owned existing: Self): - self.sockfd = existing.sockfd^ + self.fd = existing.fd^ self.address_family = existing.address_family self.socket_type = existing.socket_type self.protocol = existing.protocol @@ -143,41 +146,64 @@ struct Socket(FileDescriptorBase): # if self._is_connected: # self.shutdown() # if not self._closed: - # self.close() + # var err = self.close() + # if err: + # raise err fn __del__(owned self): if self._is_connected: self.shutdown() if not self._closed: var err = self.close() - _ = self.sockfd.fd + _ = self.fd.fd if err: print("Failed to close socket during deletion:", str(err)) @always_inline - fn accept(self) raises -> Self: + fn local_address_as_udp(self) -> UDPAddr: + return UDPAddr(self.local_address) + + @always_inline + fn local_address_as_tcp(self) -> TCPAddr: + return TCPAddr(self.local_address) + + @always_inline + fn remote_address_as_udp(self) -> UDPAddr: + return UDPAddr(self.remote_address) + + @always_inline + fn remote_address_as_tcp(self) -> TCPAddr: + return TCPAddr(self.remote_address) + + @always_inline + fn accept(self) raises -> Socket: """Accept a connection. The socket must be bound to an address and listening for connections. The return value is a connection where conn is a new socket object usable to send and receive data on the connection, and address is the address bound to the socket on the other end of the connection. """ - var their_addr_ptr = UnsafePointer[sockaddr].alloc(1) + var remote_address_ptr = UnsafePointer[sockaddr].alloc(1) var sin_size = socklen_t(sizeof[socklen_t]()) - var new_sockfd = accept( - self.sockfd.fd, - their_addr_ptr, + var new_fd = accept( + self.fd.fd, + remote_address_ptr, UnsafePointer[socklen_t].address_of(sin_size), ) - if new_sockfd == -1: + if new_fd == -1: raise Error("Failed to accept connection") - var remote = self.get_peer_name() - return Self( - new_sockfd, + var remote: HostPort + var err: Error + remote, err = convert_sockaddr_to_host_port(remote_address_ptr) + if err: + raise err + + return Socket( + new_fd, self.address_family, self.socket_type, self.protocol, self.local_address, - TCPAddr(remote.host, remote.port), + BaseAddr(remote.host, remote.port), ) fn listen(self, backlog: Int = 0) raises: @@ -189,7 +215,7 @@ struct Socket(FileDescriptorBase): var queued = backlog if backlog < 0: queued = 0 - if listen(self.sockfd.fd, queued) == -1: + if listen(self.fd.fd, queued) == -1: raise Error("Failed to listen for connections") @always_inline @@ -199,7 +225,7 @@ struct Socket(FileDescriptorBase): When a socket is created with Socket(), it exists in a name space (address family) but has no address assigned to it. bind() assigns the address specified by addr to the socket referred to - by the file descriptor sockfd. addrlen specifies the size, in + by the file descriptor fd. addrlen specifies the size, in bytes, of the address structure pointed to by addr. Traditionally, this operation is called 'assigning a name to a socket'. @@ -210,17 +236,17 @@ struct Socket(FileDescriptorBase): """ var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family) - if bind(self.sockfd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: - _ = shutdown(self.sockfd.fd, SHUT_RDWR) + if bind(self.fd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: + _ = shutdown(self.fd.fd, SHUT_RDWR) raise Error("Binding socket failed. Wait a few seconds and try again?") var local = self.get_sock_name() - self.local_address = TCPAddr(local.host, local.port) + self.local_address = BaseAddr(local.host, local.port) @always_inline fn file_no(self) -> Int32: """Return the file descriptor of the socket.""" - return self.sockfd.fd + return self.fd.fd @always_inline fn get_sock_name(self) raises -> HostPort: @@ -233,7 +259,7 @@ struct Socket(FileDescriptorBase): var local_address_ptr = UnsafePointer[sockaddr].alloc(1) var local_address_ptr_size = socklen_t(sizeof[sockaddr]()) var status = getsockname( - self.sockfd.fd, + self.fd.fd, local_address_ptr, UnsafePointer[socklen_t].address_of(local_address_ptr_size), ) @@ -246,29 +272,29 @@ struct Socket(FileDescriptorBase): port=convert_binary_port_to_int(addr_in.sin_port), ) - fn get_peer_name(self) raises -> HostPort: + fn get_peer_name(self) -> (HostPort, Error): """Return the address of the peer connected to the socket.""" if self._closed: - raise SocketClosedError + return HostPort(), SocketClosedError # TODO: Add check to see if the socket is bound and error if not. var remote_address_ptr = UnsafePointer[sockaddr].alloc(1) var remote_address_ptr_size = socklen_t(sizeof[sockaddr]()) var status = getpeername( - self.sockfd.fd, + self.fd.fd, remote_address_ptr, UnsafePointer[socklen_t].address_of(remote_address_ptr_size), ) if status == -1: - raise Error("Socket.get_peer_name: Failed to get address of remote socket.") + return HostPort(), Error("Socket.get_peer_name: Failed to get address of remote socket.") - # Cast sockaddr struct to sockaddr_in to convert binary IP to string. - var addr_in = move_from_pointee(remote_address_ptr.bitcast[sockaddr_in]()) + var remote: HostPort + var err: Error + remote, err = convert_sockaddr_to_host_port(remote_address_ptr) + if err: + return HostPort(), err - return HostPort( - host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AddressFamily.AF_INET, 16), - port=convert_binary_port_to_int(addr_in.sin_port), - ) + return remote, Error() fn get_socket_option(self, option_name: Int) raises -> Int: """Return the value of the given socket option. @@ -280,7 +306,7 @@ struct Socket(FileDescriptorBase): var option_len = socklen_t(sizeof[socklen_t]()) var option_len_pointer = UnsafePointer.address_of(option_len) var status = getsockopt( - self.sockfd.fd, + self.fd.fd, SOL_SOCKET, option_name, option_value_pointer, @@ -301,7 +327,7 @@ struct Socket(FileDescriptorBase): var option_value_pointer = UnsafePointer[c_void].address_of(option_value) var option_len = sizeof[socklen_t]() var status = setsockopt( - self.sockfd.fd, + self.fd.fd, SOL_SOCKET, option_name, option_value_pointer, @@ -310,7 +336,7 @@ struct Socket(FileDescriptorBase): if status == -1: raise Error("Socket.set_sock_opt failed with status: " + str(status)) - fn connect(inout self, address: String, port: Int) raises: + fn connect(inout self, address: String, port: Int) -> Error: """Connect to a remote socket at address. Args: @@ -319,12 +345,18 @@ struct Socket(FileDescriptorBase): """ var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family) - if connect(self.sockfd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: + if connect(self.fd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: self.shutdown() - raise Error("Socket.connect: Failed to connect to the remote socket at: " + address + ":" + str(port)) + return Error("Socket.connect: Failed to connect to the remote socket at: " + address + ":" + str(port)) - var remote = self.get_peer_name() - self.remote_address = TCPAddr(remote.host, remote.port) + var remote: HostPort + var err: Error + remote, err = self.get_peer_name() + if err: + return err + + self.remote_address = BaseAddr(remote.host, remote.port) + return Error() @always_inline fn write(inout self: Self, src: List[UInt8]) -> (Int, Error): @@ -336,7 +368,7 @@ struct Socket(FileDescriptorBase): Returns: The number of bytes sent. """ - return self.sockfd.write(src) + return self.fd.write(src) fn send_all(self, src: List[UInt8], max_attempts: Int = 3) raises: """Send data to the socket. The socket must be connected to a remote socket. @@ -356,7 +388,7 @@ struct Socket(FileDescriptorBase): raise Error("Failed to send message after " + str(max_attempts) + " attempts.") var bytes_sent = send( - self.sockfd.fd, + self.fd.fd, data.offset(total_bytes_sent), bytes_to_send - total_bytes_sent, 0, @@ -366,7 +398,8 @@ struct Socket(FileDescriptorBase): total_bytes_sent += bytes_sent attempts += 1 - fn send_to(inout self, src: List[UInt8], address: String, port: Int) raises -> Int: + @always_inline + fn send_to(inout self, src: List[UInt8], address: String, port: Int) -> (Int, Error): """Send data to the a remote address by connecting to the remote socket before sending. The socket must be not already be connected to a remote socket. @@ -375,33 +408,156 @@ struct Socket(FileDescriptorBase): address: The IP address to connect to. port: The port number to connect to. """ - self.connect(address, port) - var bytes_written: Int - var err: Error - bytes_written, err = self.write(Span(src)) - if err: - raise err - return bytes_written + var bytes_sent = sendto( + self.fd.fd, + src.unsafe_ptr(), + len(src), + 0, + build_sockaddr_pointer(address, port, self.address_family), + sizeof[sockaddr_in](), + ) + + if bytes_sent == -1: + return 0, Error("Socket.send_to: Failed to send message to remote socket at: " + address + ":" + str(port)) + + return bytes_sent, Error() + + @always_inline + fn receive(inout self, size: Int = io.BUFFER_SIZE) -> (List[UInt8], Error): + """Receive data from the socket into the buffer with capacity of `size` bytes. + + Args: + size: The size of the buffer to receive data into. + + Returns: + The buffer with the received data, and an error if one occurred. + """ + var buffer = UnsafePointer[UInt8].alloc(size) + var bytes_received = recv( + self.fd.fd, + buffer, + size, + 0, + ) + if bytes_received == -1: + return List[UInt8](), Error("Socket.receive: Failed to receive message from socket.") + + var bytes = List[UInt8](unsafe_pointer=buffer, size=bytes_received, capacity=size) + if bytes_received < bytes.capacity: + return bytes, Error(io.EOF) + + return bytes, Error() + + @always_inline + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): + """Receive data from the socket into the buffer dest. Equivalent to recv_into(). + + Args: + dest: The buffer to read data into. + capacity: The capacity of the buffer. + + Returns: + The number of bytes read, and an error if one occurred. + """ + return self.fd._read(dest, capacity) @always_inline fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): - """Receive data from the socket.""" + """Receive data from the socket into the buffer dest. Equivalent to recv_into(). + + Args: + dest: The buffer to read data into. + + Returns: + The number of bytes read, and an error if one occurred. + """ + var span = Span(dest) + var bytes_read: Int - var err = Error() - bytes_read, err = self.sockfd.read(dest) + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read return bytes_read, err + @always_inline + fn receive_from(inout self, size: Int = io.BUFFER_SIZE) -> (List[UInt8], HostPort, Error): + """Receive data from the socket into the buffer dest. + + Args: + size: The size of the buffer to receive data into. + + Returns: + The number of bytes read, the remote address, and an error if one occurred. + """ + var remote_address_ptr = UnsafePointer[sockaddr].alloc(1) + var remote_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var buffer = UnsafePointer[UInt8].alloc(size) + var bytes_received = recvfrom( + self.fd.fd, + buffer, + size, + 0, + remote_address_ptr, + UnsafePointer[socklen_t].address_of(remote_address_ptr_size), + ) + + if bytes_received == -1: + return List[UInt8](), HostPort(), Error("Failed to read from socket, received a -1 response.") + + var remote: HostPort + var err: Error + remote, err = convert_sockaddr_to_host_port(remote_address_ptr) + if err: + return List[UInt8](), HostPort(), err + + var bytes = List[UInt8](unsafe_pointer=buffer, size=bytes_received, capacity=size) + if bytes_received < bytes.capacity: + return bytes, remote, Error(io.EOF) + + return bytes, remote, Error() + + @always_inline + fn receive_from_into(inout self, inout dest: List[UInt8]) -> (Int, HostPort, Error): + """Receive data from the socket into the buffer dest.""" + var remote_address_ptr = UnsafePointer[sockaddr].alloc(1) + var remote_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var bytes_read = recvfrom( + self.fd.fd, + dest.unsafe_ptr() + dest.size, + dest.capacity - dest.size, + 0, + remote_address_ptr, + UnsafePointer[socklen_t].address_of(remote_address_ptr_size), + ) + dest.size += bytes_read + + if bytes_read == -1: + return 0, HostPort(), Error("Socket.receive_from_into: Failed to read from socket, received a -1 response.") + + var remote: HostPort + var err: Error + remote, err = convert_sockaddr_to_host_port(remote_address_ptr) + if err: + return 0, HostPort(), err + + if bytes_read < dest.capacity: + return bytes_read, remote, Error(io.EOF) + + return bytes_read, remote, Error() + + @always_inline fn shutdown(self): - _ = shutdown(self.sockfd.fd, SHUT_RDWR) + _ = shutdown(self.fd.fd, SHUT_RDWR) + @always_inline fn close(inout self) -> Error: """Mark the socket closed. Once that happens, all future operations on the socket object will fail. The remote end will receive no more data (after queued data is flushed). """ self.shutdown() - var err = self.sockfd.close() + var err = self.fd.close() if err: return err @@ -409,11 +565,11 @@ struct Socket(FileDescriptorBase): return Error() # TODO: Trying to set timeout fails, but some other options don't? - # fn get_timeout(self) raises -> Seconds: + # fn get_timeout(self) raises -> Int: # """Return the timeout value for the socket.""" # return self.get_socket_option(SocketOptions.SO_RCVTIMEO) - # fn set_timeout(self, owned duration: Seconds) raises: + # fn set_timeout(self, owned duration: Int) raises: # """Set the timeout value for the socket. # Args: @@ -421,5 +577,6 @@ struct Socket(FileDescriptorBase): # """ # self.set_socket_option(SocketOptions.SO_RCVTIMEO, duration) + @always_inline fn send_file(self, file: FileHandle, offset: Int = 0) raises: self.send_all(file.read_bytes()) diff --git a/gojo/net/tcp.mojo b/gojo/net/tcp.mojo index b353eaa..b459826 100644 --- a/gojo/net/tcp.mojo +++ b/gojo/net/tcp.mojo @@ -1,106 +1,99 @@ +from collections import InlineList from ..syscall import SocketOptions -from .net import Connection, Conn -from .address import TCPAddr, NetworkType, split_host_port +from .address import NetworkType, split_host_port, join_host_port, BaseAddr, resolve_internet_addr, HostPort from .socket import Socket -# Time in nanoseconds -alias Duration = Int -alias DEFAULT_BUFFER_SIZE = 4096 -alias DEFAULT_TCP_KEEP_ALIVE = Duration(15 * 1000 * 1000 * 1000) # 15 seconds - - -fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: - var host: String = "" - var port: String = "" - var portnum: Int = 0 - if ( - network == NetworkType.tcp.value - or network == NetworkType.tcp4.value - or network == NetworkType.tcp6.value - or network == NetworkType.udp.value - or network == NetworkType.udp4.value - or network == NetworkType.udp6.value - ): - if address != "": - var host_port = split_host_port(address) - host = host_port.host - port = str(host_port.port) - portnum = atol(port.__str__()) - elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value: - if address != "": - host = address - elif network == NetworkType.unix.value: - raise Error("Unix addresses not supported yet") - else: - raise Error("unsupported network type: " + network) - return TCPAddr(host, portnum) - - -# TODO: For now listener is paired with TCP until we need to support -# more than one type of Connection or Listener @value -struct ListenConfig(CollectionElement): - var keep_alive: Duration +struct TCPAddr(Addr): + """Addr struct representing a TCP address. - fn listen(self, network: String, address: String) raises -> TCPListener: - var tcp_addr = resolve_internet_addr(network, address) - var socket = Socket(local_address=tcp_addr) - socket.bind(tcp_addr.ip, tcp_addr.port) - socket.set_socket_option(SocketOptions.SO_REUSEADDR, 1) - socket.listen() - print(str("Listening on ") + str(socket.local_address)) - return TCPListener(socket^, self, network, address) + Args: + ip: IP address. + port: Port number. + zone: IPv6 addressing zone. + """ + var ip: String + var port: Int + var zone: String # IPv6 addressing zone -trait Listener(Movable): - # Raising here because a Result[Optional[Connection], Error] is funky. - fn accept(self) raises -> Connection: - ... + fn __init__(inout self, ip: String = "127.0.0.1", port: Int = 8000, zone: String = ""): + self.ip = ip + self.port = port + self.zone = zone - fn close(inout self) -> Error: - ... + fn __init__(inout self, addr: BaseAddr): + self.ip = addr.ip + self.port = addr.port + self.zone = addr.zone + + fn __str__(self) -> String: + if self.zone != "": + return join_host_port(str(self.ip) + "%" + self.zone, str(self.port)) + return join_host_port(self.ip, str(self.port)) - fn addr(self) raises -> TCPAddr: - ... + fn network(self) -> String: + return NetworkType.tcp.value -struct TCPConnection(Conn): +struct TCPConnection(Movable): """TCPConn is an implementation of the Conn interface for TCP network connections. Args: connection: The underlying Connection. """ - var _connection: Connection - - fn __init__(inout self, owned connection: Connection): - self._connection = connection^ + var socket: Socket + @always_inline fn __init__(inout self, owned socket: Socket): - self._connection = Connection(socket^) + self.socket = socket^ + @always_inline fn __moveinit__(inout self, owned existing: Self): - self._connection = existing._connection^ + self.socket = existing.socket^ - fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + @always_inline + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): """Reads data from the underlying file descriptor. Args: dest: The buffer to read data into. + capacity: The capacity of the destination buffer. Returns: The number of bytes read, or an error if one occurred. """ var bytes_read: Int - var err: Error - bytes_read, err = self._connection.read(dest) + var err = Error() + bytes_read, err = self.socket._read(dest, capacity) if err: if str(err) != io.EOF: return bytes_read, err - return bytes_read, Error() + return bytes_read, err + @always_inline + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + """Reads data from the underlying file descriptor. + + Args: + dest: The buffer to read data into. + + Returns: + The number of bytes read, or an error if one occurred. + """ + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + + @always_inline fn write(inout self, src: List[UInt8]) -> (Int, Error): """Writes data to the underlying file descriptor. @@ -110,16 +103,18 @@ struct TCPConnection(Conn): Returns: The number of bytes written, or an error if one occurred. """ - return self._connection.write(src) + return self.socket.write(src) + @always_inline fn close(inout self) -> Error: """Closes the underlying file descriptor. Returns: An error if one occurred, or None if the file descriptor was closed successfully. """ - return self._connection.close() + return self.socket.close() + @always_inline fn local_address(self) -> TCPAddr: """Returns the local network address. The Addr returned is shared by all invocations of local_address, so do not modify it. @@ -127,8 +122,9 @@ struct TCPConnection(Conn): Returns: The local network address. """ - return self._connection.local_address() + return self.socket.local_address_as_tcp() + @always_inline fn remote_address(self) -> TCPAddr: """Returns the remote network address. The Addr returned is shared by all invocations of remote_address, so do not modify it. @@ -136,7 +132,7 @@ struct TCPConnection(Conn): Returns: The remote network address. """ - return self._connection.remote_address() + return self.socket.remote_address_as_tcp() fn listen_tcp(network: String, local_address: TCPAddr) raises -> TCPListener: @@ -146,7 +142,12 @@ fn listen_tcp(network: String, local_address: TCPAddr) raises -> TCPListener: network: The network type. local_address: The local address to listen on. """ - return ListenConfig(DEFAULT_TCP_KEEP_ALIVE).listen(network, local_address.ip + ":" + str(local_address.port)) + var socket = Socket() + socket.bind(local_address.ip, local_address.port) + socket.set_socket_option(SocketOptions.SO_REUSEADDR, 1) + socket.listen() + # print(str("Listening on ") + str(socket.local_address_as_tcp())) + return TCPListener(socket^, network, local_address) fn listen_tcp(network: String, local_address: String) raises -> TCPListener: @@ -156,44 +157,106 @@ fn listen_tcp(network: String, local_address: String) raises -> TCPListener: network: The network type. local_address: The address to listen on. The format is "host:port". """ - return ListenConfig(DEFAULT_TCP_KEEP_ALIVE).listen(network, local_address) + var tcp_addr: TCPAddr + var err: Error + tcp_addr, err = resolve_internet_addr(network, local_address) + if err: + raise err + return listen_tcp(network, tcp_addr) + + +fn listen_tcp(network: String, host: String, port: Int) raises -> TCPListener: + """Creates a new TCP listener. + + Args: + network: The network type. + host: The address to listen on, in ipv4 format. + port: The port to listen on. + """ + return listen_tcp(network, TCPAddr(host, port)) -struct TCPListener(Listener): - var _file_descriptor: Socket - var listen_config: ListenConfig +struct TCPListener: + var socket: Socket var network_type: String - var address: String + var address: TCPAddr fn __init__( inout self, - owned file_descriptor: Socket, - listen_config: ListenConfig, + owned socket: Socket, network_type: String, - address: String, + address: TCPAddr, ): - self._file_descriptor = file_descriptor^ - self.listen_config = listen_config + self.socket = socket^ self.network_type = network_type self.address = address fn __moveinit__(inout self, owned existing: Self): - self._file_descriptor = existing._file_descriptor^ - self.listen_config = existing.listen_config^ + self.socket = existing.socket^ self.network_type = existing.network_type^ self.address = existing.address^ - fn listen(self) raises -> Self: - return self.listen_config.listen(self.network_type, self.address) + fn accept(self) raises -> TCPConnection: + return TCPConnection(self.socket.accept()) + + fn close(inout self) -> Error: + return self.socket.close() - fn accept(self) raises -> Connection: - return Connection(self._file_descriptor.accept()) - fn accept_tcp(self) raises -> TCPConnection: - return TCPConnection(self._file_descriptor.accept()) +alias TCP_NETWORK_TYPES = InlineList[String, 3]("tcp", "tcp4", "tcp6") - fn close(inout self) -> Error: - return self._file_descriptor.close() - fn addr(self) raises -> TCPAddr: - return resolve_internet_addr(self.network_type, self.address) +fn dial_tcp(network: String, remote_address: TCPAddr) raises -> TCPConnection: + """Connects to the address on the named network. + + The network must be "tcp", "tcp4", or "tcp6". + Args: + network: The network type. + remote_address: The remote address to connect to. + + Returns: + The TCP connection. + """ + # TODO: Add conversion of domain name to ip address + if network not in TCP_NETWORK_TYPES: + raise Error("unsupported network type: " + network) + + var socket = Socket() + var err = socket.connect(remote_address.ip, remote_address.port) + if err: + raise err + return TCPConnection(socket^) + + +fn dial_tcp(network: String, remote_address: String) raises -> TCPConnection: + """Connects to the address on the named network. + + The network must be "tcp", "tcp4", or "tcp6". + Args: + network: The network type. + remote_address: The remote address to connect to. (The format is "host:port"). + + Returns: + The TCP connection. + """ + var remote: HostPort + var err: Error + remote, err = split_host_port(remote_address) + if err: + raise err + return dial_tcp(network, TCPAddr(remote.host, remote.port)) + + +fn dial_tcp(network: String, host: String, port: Int) raises -> TCPConnection: + """Connects to the address on the named network. + + The network must be "tcp", "tcp4", or "tcp6". + Args: + network: The network type. + host: The remote address to connect to in ipv4 format. + port: The remote port. + + Returns: + The TCP connection. + """ + return dial_tcp(network, TCPAddr(host, port)) diff --git a/gojo/net/udp.mojo b/gojo/net/udp.mojo new file mode 100644 index 0000000..61cb4cb --- /dev/null +++ b/gojo/net/udp.mojo @@ -0,0 +1,210 @@ +from collections import InlineList +from ..syscall import SocketOptions, SocketType +from .address import NetworkType, split_host_port, join_host_port, BaseAddr, resolve_internet_addr +from .socket import Socket + + +# TODO: Change ip to list of bytes +@value +struct UDPAddr(Addr): + """Represents the address of a UDP end point. + + Args: + ip: IP address. + port: Port number. + zone: IPv6 addressing zone. + """ + + var ip: String + var port: Int + var zone: String # IPv6 addressing zone + + fn __init__(inout self, ip: String = "127.0.0.1", port: Int = 8000, zone: String = ""): + self.ip = ip + self.port = port + self.zone = zone + + fn __init__(inout self, addr: BaseAddr): + self.ip = addr.ip + self.port = addr.port + self.zone = addr.zone + + fn __str__(self) -> String: + if self.zone != "": + return join_host_port(str(self.ip) + "%" + self.zone, str(self.port)) + return join_host_port(self.ip, str(self.port)) + + fn network(self) -> String: + return NetworkType.udp.value + + +struct UDPConnection(Movable): + """Implementation of the Conn interface for TCP network connections.""" + + var socket: Socket + + fn __init__(inout self, owned socket: Socket): + self.socket = socket^ + + fn __moveinit__(inout self, owned existing: Self): + self.socket = existing.socket^ + + fn read_from(inout self, inout dest: List[UInt8]) -> (Int, HostPort, Error): + """Reads data from the underlying file descriptor. + + Args: + dest: The buffer to read data into. + + Returns: + The number of bytes read, or an error if one occurred. + """ + var bytes_read: Int + var remote: HostPort + var err = Error() + bytes_read, remote, err = self.socket.receive_from_into(dest) + if err: + if str(err) != io.EOF: + return bytes_read, remote, err + + return bytes_read, remote, err + + fn write_to(inout self, src: List[UInt8], address: UDPAddr) -> (Int, Error): + """Writes data to the underlying file descriptor. + + Args: + src: The buffer to read data into. + address: The remote peer address. + + Returns: + The number of bytes written, or an error if one occurred. + """ + return self.socket.send_to(src, address.ip, address.port) + + fn write_to(inout self, src: List[UInt8], host: String, port: Int) -> (Int, Error): + """Writes data to the underlying file descriptor. + + Args: + src: The buffer to read data into. + host: The remote peer address in IPv4 format. + port: The remote peer port. + + Returns: + The number of bytes written, or an error if one occurred. + """ + return self.socket.send_to(src, host, port) + + fn close(inout self) -> Error: + """Closes the underlying file descriptor. + + Returns: + An error if one occurred, or None if the file descriptor was closed successfully. + """ + return self.socket.close() + + fn local_address(self) -> UDPAddr: + """Returns the local network address. + The Addr returned is shared by all invocations of local_address, so do not modify it. + + Returns: + The local network address. + """ + return self.socket.local_address_as_udp() + + fn remote_address(self) -> UDPAddr: + """Returns the remote network address. + The Addr returned is shared by all invocations of remote_address, so do not modify it. + + Returns: + The remote network address. + """ + return self.socket.remote_address_as_udp() + + +fn listen_udp(network: String, local_address: UDPAddr) raises -> UDPConnection: + """Creates a new UDP listener. + + Args: + network: The network type. + local_address: The local address to listen on. + """ + var socket = Socket(socket_type=SocketType.SOCK_DGRAM) + socket.bind(local_address.ip, local_address.port) + # print(str("Listening on ") + str(socket.local_address_as_udp())) + return UDPConnection(socket^) + + +fn listen_udp(network: String, local_address: String) raises -> UDPConnection: + """Creates a new UDP listener. + + Args: + network: The network type. + local_address: The address to listen on. The format is "host:port". + """ + var result = split_host_port(local_address) + return listen_udp(network, UDPAddr(result[0].host, result[0].port)) + + +fn listen_udp(network: String, host: String, port: Int) raises -> UDPConnection: + """Creates a new UDP listener. + + Args: + network: The network type. + host: The address to listen on in ipv4 format. + port: The port number. + """ + return listen_udp(network, UDPAddr(host, port)) + + +alias UDP_NETWORK_TYPES = InlineList[String, 3]("udp", "udp4", "udp6") + + +fn dial_udp(network: String, local_address: UDPAddr) raises -> UDPConnection: + """Connects to the address on the named network. + + The network must be "udp", "udp4", or "udp6". + Args: + network: The network type. + local_address: The local address. + + Returns: + The TCP connection. + """ + # TODO: Add conversion of domain name to ip address + if network not in UDP_NETWORK_TYPES: + raise Error("unsupported network type: " + network) + + var socket = Socket(local_address=BaseAddr(local_address), socket_type=SocketType.SOCK_DGRAM) + return UDPConnection(socket^) + + +fn dial_udp(network: String, local_address: String) raises -> UDPConnection: + """Connects to the address on the named network. + + The network must be "udp", "udp4", or "udp6". + Args: + network: The network type. + local_address: The local address to connect to. (The format is "host:port"). + + Returns: + The TCP connection. + """ + var result = split_host_port(local_address) + if result[1]: + raise result[1] + + return dial_udp(network, UDPAddr(result[0].host, result[0].port)) + + +fn dial_udp(network: String, host: String, port: Int) raises -> UDPConnection: + """Connects to the address on the named network. + + The network must be "udp", "udp4", or "udp6". + Args: + network: The network type. + host: The remote host in ipv4 format. + port: The remote port. + + Returns: + The TCP connection. + """ + return dial_udp(network, UDPAddr(host, port)) diff --git a/gojo/strings/reader.mojo b/gojo/strings/reader.mojo index 9df939d..61292bf 100644 --- a/gojo/strings/reader.mojo +++ b/gojo/strings/reader.mojo @@ -51,12 +51,13 @@ struct Reader( return len(self.string) @always_inline - fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + fn _read(inout self, inout dest: Span[UInt8, True], capacity: Int) -> (Int, Error): """Reads from the underlying string into the provided List[UInt8] object. Implements the [io.Reader] trait. Args: dest: The destination List[UInt8] object to read into. + capacity: The capacity of the destination List[UInt8] object. Returns: The number of bytes read into dest. @@ -69,13 +70,35 @@ struct Reader( self.read_pos += bytes_written return bytes_written, Error() - fn read_at(self, inout dest: List[UInt8], off: Int) -> (Int, Error): + @always_inline + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + """Reads from the underlying string into the provided List[UInt8] object. + Implements the [io.Reader] trait. + + Args: + dest: The destination List[UInt8] object to read into. + + Returns: + The number of bytes read into dest. + """ + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read(span, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + + @always_inline + fn _read_at(self, inout dest: Span[UInt8, True], off: Int, capacity: Int) -> (Int, Error): """Reads from the Reader into the dest List[UInt8] starting at the offset off. It returns the number of bytes read into dest and an error if any. Args: dest: The destination List[UInt8] object to read into. off: The byte offset to start reading from. + capacity: The capacity of the destination List[UInt8] object. Returns: The number of bytes read into dest. @@ -94,6 +117,27 @@ struct Reader( return copied_elements_count, error + @always_inline + fn read_at(self, inout dest: List[UInt8], off: Int) -> (Int, Error): + """Reads from the Reader into the dest List[UInt8] starting at the offset off. + It returns the number of bytes read into dest and an error if any. + + Args: + dest: The destination List[UInt8] object to read into. + off: The byte offset to start reading from. + + Returns: + The number of bytes read into dest. + """ + var span = Span(dest) + + var bytes_read: Int + var err: Error + bytes_read, err = self._read_at(span, off, dest.capacity) + dest.size += bytes_read + + return bytes_read, err + @always_inline fn read_byte(inout self) -> (UInt8, Error): """Reads the next byte from the underlying string.""" diff --git a/gojo/syscall/__init__.mojo b/gojo/syscall/__init__.mojo index 219df7d..c89fef0 100644 --- a/gojo/syscall/__init__.mojo +++ b/gojo/syscall/__init__.mojo @@ -6,7 +6,9 @@ from .net import ( SocketOptions, AddressInformation, send, + sendto, recv, + recvfrom, open, addrinfo, addrinfo_unix, @@ -22,7 +24,6 @@ from .net import ( getaddrinfo, getaddrinfo_unix, gai_strerror, - c_charptr_to_string, shutdown, inet_ntoa, bind, diff --git a/gojo/syscall/net.mojo b/gojo/syscall/net.mojo index d6174c5..676b2da 100644 --- a/gojo/syscall/net.mojo +++ b/gojo/syscall/net.mojo @@ -714,17 +714,73 @@ fn recv( """Libc POSIX `recv` function Reference: https://man7.org/linux/man-pages/man3/recv.3p.html Fn signature: ssize_t recv(int socket, void *buffer, size_t length, int flags). + + Args: + socket: Specifies the socket file descriptor. + buffer: Points to the buffer where the message should be stored. + length: Specifies the length in bytes of the buffer pointed to by the buffer argument. + flags: Specifies the type of message reception. + + Returns: + The number of bytes received or -1 in case of failure. + + Valid Flags: + MSG_PEEK: Peeks at an incoming message. The data is treated as unread and the next recvfrom() or similar function shall still return this data. + MSG_OOB: Requests out-of-band data. The significance and semantics of out-of-band data are protocol-specific. + MSG_WAITALL: On SOCK_STREAM sockets this requests that the function block until the full amount of data can be returned. The function may return the smaller amount of data if the socket is a message-based socket, if a signal is caught, if the connection is terminated, if MSG_PEEK was specified, or if an error is pending for the socket. """ return external_call[ "recv", - c_ssize_t, # FnName, RetType + c_ssize_t, c_int, UnsafePointer[UInt8], c_size_t, - c_int, # Args + c_int, ](socket, buffer, length, flags) +fn recvfrom( + socket: c_int, + buffer: UnsafePointer[UInt8], + length: c_size_t, + flags: c_int, + address: UnsafePointer[sockaddr], + address_len: UnsafePointer[socklen_t], +) -> c_ssize_t: + """Libc POSIX `recvfrom` function + Reference: https://man7.org/linux/man-pages/man3/recvfrom.3p.html + Fn signature: ssize_t recvfrom(int socket, void *restrict buffer, size_t length, + int flags, struct sockaddr *restrict address, + socklen_t *restrict address_len). + + Args: + socket: Specifies the socket file descriptor. + buffer: Points to the buffer where the message should be stored. + length: Specifies the length in bytes of the buffer pointed to by the buffer argument. + flags: Specifies the type of message reception. + address: A null pointer, or points to a sockaddr structure in which the sending address is to be stored. + address_len: Either a null pointer, if address is a null pointer, or a pointer to a socklen_t object which on input specifies the length of the supplied sockaddr structure, and on output specifies the length of the stored address. + + Returns: + The number of bytes received or -1 in case of failure. + + Valid Flags: + MSG_PEEK: Peeks at an incoming message. The data is treated as unread and the next recvfrom() or similar function shall still return this data. + MSG_OOB: Requests out-of-band data. The significance and semantics of out-of-band data are protocol-specific. + MSG_WAITALL: On SOCK_STREAM sockets this requests that the function block until the full amount of data can be returned. The function may return the smaller amount of data if the socket is a message-based socket, if a signal is caught, if the connection is terminated, if MSG_PEEK was specified, or if an error is pending for the socket. + """ + return external_call[ + "recvfrom", + c_ssize_t, + c_int, + UnsafePointer[UInt8], + c_size_t, + c_int, + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], + ](socket, buffer, length, flags, address, address_len) + + fn send( socket: c_int, buffer: UnsafePointer[UInt8], @@ -751,6 +807,41 @@ fn send( ](socket, buffer, length, flags) +fn sendto( + socket: c_int, + message: UnsafePointer[UInt8], + length: c_size_t, + flags: c_int, + dest_addr: UnsafePointer[sockaddr], + dest_len: socklen_t, +) -> c_ssize_t: + """Libc POSIX `sendto` function + Reference: https://man7.org/linux/man-pages/man3/sendto.3p.html + Fn signature: ssize_t sendto(int socket, const void *message, size_t length, + int flags, const struct sockaddr *dest_addr, + socklen_t dest_len). + + Args: + socket: Specifies the socket file descriptor. + message: Points to a buffer containing the message to be sent. + length: Specifies the size of the message in bytes. + flags: Specifies the type of message transmission. + dest_addr: Points to a sockaddr structure containing the destination address. + dest_len: Specifies the length of the sockaddr. + + Returns: + The number of bytes sent or -1 in case of failure. + + Valid Flags: + MSG_EOR: Terminates a record (if supported by the protocol). + MSG_OOB: Sends out-of-band data on sockets that support out-of-band data. The significance and semantics of out-of-band data are protocol-specific. + MSG_NOSIGNAL: Requests not to send the SIGPIPE signal if an attempt to send is made on a stream-oriented socket that is no longer connected. The [EPIPE] error shall still be returned. + """ + return external_call[ + "sendto", c_ssize_t, c_int, UnsafePointer[UInt8], c_size_t, c_int, UnsafePointer[sockaddr], socklen_t + ](socket, message, length, flags, dest_addr, dest_len) + + fn shutdown(socket: c_int, how: c_int) -> c_int: """Libc POSIX `shutdown` function Reference: https://man7.org/linux/man-pages/man3/shutdown.3p.html diff --git a/tests/test_bufio_scanner.mojo b/tests/test_bufio_scanner.mojo index ff2c4da..7a035b0 100644 --- a/tests/test_bufio_scanner.mojo +++ b/tests/test_bufio_scanner.mojo @@ -1,11 +1,11 @@ from tests.wrapper import MojoTest from gojo.bytes import buffer from gojo.io import FileWrapper -from gojo.bufio import Reader, Scanner, scan_words, scan_bytes +from gojo.bufio import Reader, Scanner, scan_words, scan_bytes, scan_runes -fn test_scan_words() raises: - var test = MojoTest("Testing scan_words") +fn test_scan_words(): + var test = MojoTest("Testing bufio.scan_words") # Create a reader from a string buffer var s: String = "Testing this string!" @@ -13,13 +13,9 @@ fn test_scan_words() raises: var r = Reader(buf^) # Create a scanner from the reader - var scanner = Scanner(r^) - scanner.split = scan_words + var scanner = Scanner[split=scan_words](r^) - var expected_results = List[String]() - expected_results.append("Testing") - expected_results.append("this") - expected_results.append("string!") + var expected_results = List[String]("Testing", "this", "string!") var i = 0 while scanner.scan(): @@ -27,8 +23,8 @@ fn test_scan_words() raises: i += 1 -fn test_scan_lines() raises: - var test = MojoTest("Testing scan_lines") +fn test_scan_lines(): + var test = MojoTest("Testing bufio.scan_lines") # Create a reader from a string buffer var s: String = "Testing\nthis\nstring!" @@ -38,10 +34,7 @@ fn test_scan_lines() raises: # Create a scanner from the reader var scanner = Scanner(r^) - var expected_results = List[String]() - expected_results.append("Testing") - expected_results.append("this") - expected_results.append("string!") + var expected_results = List[String]("Testing", "this", "string!") var i = 0 while scanner.scan(): @@ -49,7 +42,7 @@ fn test_scan_lines() raises: i += 1 -fn scan_no_newline_test(test_case: String, result_lines: List[String], test: MojoTest) raises: +fn scan_no_newline_test(test_case: String, result_lines: List[String], test: MojoTest): # Create a reader from a string buffer var buf = buffer.new_buffer(test_case) var r = Reader(buf^) @@ -62,56 +55,42 @@ fn scan_no_newline_test(test_case: String, result_lines: List[String], test: Moj i += 1 -fn test_scan_lines_no_newline() raises: +fn test_scan_lines_no_newline(): var test = MojoTest("Testing bufio.scan_lines with no final newline") var test_case = "abcdefghijklmn\nopqrstuvwxyz" - var result_lines = List[String]() - result_lines.append("abcdefghijklmn") - result_lines.append("opqrstuvwxyz") + var result_lines = List[String]("abcdefghijklmn", "opqrstuvwxyz") scan_no_newline_test(test_case, result_lines, test) -fn test_scan_lines_cr_no_newline() raises: +fn test_scan_lines_cr_no_newline(): var test = MojoTest("Testing bufio.scan_lines with no final newline but carriage return") var test_case = "abcdefghijklmn\nopqrstuvwxyz\r" - var result_lines = List[String]() - result_lines.append("abcdefghijklmn") - result_lines.append("opqrstuvwxyz") + var result_lines = List[String]("abcdefghijklmn", "opqrstuvwxyz") scan_no_newline_test(test_case, result_lines, test) -fn test_scan_lines_empty_final_line() raises: +fn test_scan_lines_empty_final_line(): var test = MojoTest("Testing bufio.scan_lines with an empty final line") var test_case = "abcdefghijklmn\nopqrstuvwxyz\n\n" - var result_lines = List[String]() - result_lines.append("abcdefghijklmn") - result_lines.append("opqrstuvwxyz") - result_lines.append("") + var result_lines = List[String]("abcdefghijklmn", "opqrstuvwxyz", "") scan_no_newline_test(test_case, result_lines, test) -fn test_scan_lines_cr_empty_final_line() raises: +fn test_scan_lines_cr_empty_final_line(): var test = MojoTest("Testing bufio.scan_lines with an empty final line and carriage return") var test_case = "abcdefghijklmn\nopqrstuvwxyz\n\r" - var result_lines = List[String]() - result_lines.append("abcdefghijklmn") - result_lines.append("opqrstuvwxyz") - result_lines.append("") + var result_lines = List[String]("abcdefghijklmn", "opqrstuvwxyz", "") scan_no_newline_test(test_case, result_lines, test) -fn test_scan_bytes() raises: - var test = MojoTest("Testing scan_bytes") +fn test_scan_bytes(): + var test = MojoTest("Testing bufio.scan_bytes") - var test_cases = List[String]() - test_cases.append("") - test_cases.append("a") - test_cases.append("abc") - test_cases.append("abc def\n\t\tgh ") + var test_cases = List[String]("", "a", "abc", "abc def\n\t\tgh ") for i in range(len(test_cases)): var test_case = test_cases[i] @@ -120,13 +99,11 @@ fn test_scan_bytes() raises: var reader = Reader(buf^) # Create a scanner from the reader - var scanner = Scanner(reader^) - scanner.split = scan_bytes + var scanner = Scanner[split=scan_bytes](reader^) var j = 0 - while scanner.scan(): - test.assert_equal(String(scanner.current_token_as_bytes()), test_case[j]) + test.assert_equal(scanner.current_token(), test_case[j]) j += 1 @@ -136,12 +113,26 @@ fn test_file_wrapper_scanner() raises: # Create a scanner from the reader var scanner = Scanner(file^) - var expected_results = List[String]() - expected_results.append("11111") - expected_results.append("22222") - expected_results.append("33333") - expected_results.append("44444") - expected_results.append("55555") + var expected_results = List[String]("11111", "22222", "33333", "44444", "55555") + var i = 0 + + while scanner.scan(): + test.assert_equal(scanner.current_token(), expected_results[i]) + i += 1 + + +fn test_scan_runes(): + var test = MojoTest("Testing bufio.scan_runes") + + # Create a reader from a string buffer + var s: String = "🔪🔥🔪" + var buf = buffer.new_buffer(s) + var r = Reader(buf^) + + # Create a scanner from the reader + var scanner = Scanner[split=scan_runes](r^) + + var expected_results = List[String]("🔪", "🔥", "🔪") var i = 0 while scanner.scan(): @@ -158,3 +149,4 @@ fn main() raises: test_scan_lines_cr_empty_final_line() test_scan_bytes() test_file_wrapper_scanner() + test_scan_runes() diff --git a/tests/test_get_addr.mojo b/tests/test_get_addr.mojo index 94d3520..2f481d2 100644 --- a/tests/test_get_addr.mojo +++ b/tests/test_get_addr.mojo @@ -40,43 +40,44 @@ fn test_listener() raises: var listener = listen_tcp("tcp", TCPAddr("0.0.0.0", 8081)) while True: var conn = listener.accept() + print("Accepted connection from", conn.remote_address()) var err = conn.close() if err: raise err -fn test_stuff() raises: - # TODO: context manager not working yet - # with Socket() as socket: - # socket.bind("0.0.0.0", 8080) +# fn test_stuff() raises: +# # TODO: context manager not working yet +# # with Socket() as socket: +# # socket.bind("0.0.0.0", 8080) - var socket = Socket(protocol=ProtocolFamily.PF_UNIX) - socket.bind("0.0.0.0", 8080) - socket.connect(get_ip_address("www.example.com"), 80) - print("File number", socket.file_no()) - var local = socket.get_sock_name() - var remote = socket.get_peer_name() - print("Local address", str(local), socket.local_address) - print("Remote address", str(remote), socket.remote_address) - socket.set_socket_option(SocketOptions.SO_REUSEADDR, 1) - print("REUSE_ADDR value", socket.get_socket_option(SocketOptions.SO_REUSEADDR)) - var timeout = 30 - # socket.set_timeout(timeout) - # print(socket.get_timeout()) - socket.shutdown() - print("closing") - var err = socket.close() - print("closed") - if err: - print("err returned") - raise err - # var option_value = socket.get_sock_opt(SocketOptions.SO_REUSEADDR) - # print(option_value) - # socket.connect(self.ip, self.port) - # socket.send(message) - # var response = socket.receive() # TODO: call receive until all data is fetched, receive should also just return bytes - # socket.shutdown() - # socket.close() +# var socket = Socket(protocol=ProtocolFamily.PF_UNIX) +# socket.bind("0.0.0.0", 8080) +# socket.connect(get_ip_address("www.example.com"), 80) +# print("File number", socket.file_no()) +# var local = socket.get_sock_name() +# var remote = socket.get_peer_name() +# print("Local address", str(local), socket.local_address) +# print("Remote address", str(remote), socket.remote_address) +# socket.set_socket_option(SocketOptions.SO_REUSEADDR, 1) +# print("REUSE_ADDR value", socket.get_socket_option(SocketOptions.SO_REUSEADDR)) +# var timeout = 30 +# # socket.set_timeout(timeout) +# # print(socket.get_timeout()) +# socket.shutdown() +# print("closing") +# var err = socket.close() +# print("closed") +# if err: +# print("err returned") +# raise err +# # var option_value = socket.get_sock_opt(SocketOptions.SO_REUSEADDR) +# # print(option_value) +# # socket.connect(self.ip, self.port) +# # socket.send(message) +# # var response = socket.receive() # TODO: call receive until all data is fetched, receive should also just return bytes +# # socket.shutdown() +# # socket.close() fn main() raises: