Skip to content

Commit

Permalink
update net package
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Jun 2, 2024
1 parent 1908123 commit 0aa8862
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 225 deletions.
15 changes: 4 additions & 11 deletions gojo/net/fd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ from collections.optional import Optional
import ..io
from ..builtins import Byte
from ..syscall.file import close
from ..syscall.types import c_char
from ..syscall.net import (
recv,
send,
Expand Down Expand Up @@ -52,25 +51,19 @@ struct FileDescriptor(FileDescriptorBase):
# TODO: Need faster approach to copying data from the file descriptor to the buffer.
fn read(inout self, inout dest: List[Byte]) -> (Int, Error):
"""Receive data from the file descriptor and write it to the buffer provided."""
var ptr = Pointer[UInt8]().alloc(dest.capacity)
var bytes_received = recv(self.fd, ptr, dest.capacity, 0)
var bytes_received = recv(self.fd, dest.unsafe_ptr(), dest.capacity, 0)
if bytes_received == -1:
return 0, Error("Failed to receive message from socket.")

var int8_ptr = ptr.bitcast[Int8]()
for i in range(bytes_received):
dest.append(int8_ptr[i])
dest.size += bytes_received

if bytes_received < dest.capacity:
return bytes_received, Error(io.EOF)

return bytes_received, Error()

fn write(inout self, src: List[Byte]) -> (Int, Error):
fn write(inout self, src: Span[Byte]) -> (Int, Error):
"""Write data from the buffer to the file descriptor."""
var header_pointer = Pointer[Int8](src.data.address).bitcast[UInt8]()

var bytes_sent = send(self.fd, header_pointer, strlen(header_pointer), 0)
var bytes_sent = send(self.fd, src.unsafe_ptr(), strlen(src.unsafe_ptr()), 0)
if bytes_sent == -1:
return 0, Error("Failed to send message")

Expand Down
46 changes: 17 additions & 29 deletions gojo/net/ip.mojo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.variant import Variant
from utils.static_tuple import StaticTuple
from sys.info import os_is_linux, os_is_macos
from ..syscall.types import (
from ..syscall import (
c_int,
c_char,
c_void,
Expand Down Expand Up @@ -32,8 +32,8 @@ alias AddrInfo = Variant[addrinfo, addrinfo_unix]
fn get_addr_info(host: String) raises -> AddrInfo:
var status: Int32 = 0
if os_is_macos():
var servinfo = Pointer[addrinfo]().alloc(1)
servinfo.store(addrinfo())
var servinfo = UnsafePointer[addrinfo]().alloc(1)
servinfo[0] = addrinfo()
var hints = addrinfo()
hints.ai_family = AF_INET
hints.ai_socktype = SOCK_STREAM
Expand All @@ -43,27 +43,21 @@ fn get_addr_info(host: String) raises -> AddrInfo:

var status = getaddrinfo(
host_ptr,
Pointer[UInt8](),
Pointer.address_of(hints),
Pointer.address_of(servinfo),
DTypePointer[DType.uint8](),
UnsafePointer.address_of(hints),
UnsafePointer.address_of(servinfo),
)
if status != 0:
print("getaddrinfo failed to execute with status:", status)
var msg_ptr = gai_strerror(c_int(status))
_ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]](
to_char_ptr("gai_strerror: %s"), msg_ptr
)
var msg = c_charptr_to_string(msg_ptr)
print("getaddrinfo error message: ", msg)

if not servinfo:
print("servinfo is null")
raise Error("Failed to get address info. Pointer to addrinfo is null.")

return servinfo.load()
return move_from_pointee(servinfo)
elif os_is_linux():
var servinfo = Pointer[addrinfo_unix]().alloc(1)
servinfo.store(addrinfo_unix())
var servinfo = UnsafePointer[addrinfo_unix]().alloc(1)
servinfo[0] = addrinfo_unix()
var hints = addrinfo_unix()
hints.ai_family = AF_INET
hints.ai_socktype = SOCK_STREAM
Expand All @@ -73,24 +67,18 @@ fn get_addr_info(host: String) raises -> AddrInfo:

var status = getaddrinfo_unix(
host_ptr,
Pointer[UInt8](),
Pointer.address_of(hints),
Pointer.address_of(servinfo),
DTypePointer[DType.uint8](),
UnsafePointer.address_of(hints),
UnsafePointer.address_of(servinfo),
)
if status != 0:
print("getaddrinfo failed to execute with status:", status)
var msg_ptr = gai_strerror(c_int(status))
_ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]](
to_char_ptr("gai_strerror: %s"), msg_ptr
)
var msg = c_charptr_to_string(msg_ptr)
print("getaddrinfo error message: ", msg)

if not servinfo:
print("servinfo is null")
raise Error("Failed to get address info. Pointer to addrinfo is null.")

return servinfo.load()
return move_from_pointee(servinfo)
else:
raise Error("Windows is not supported yet! Sorry!")

Expand All @@ -99,7 +87,7 @@ fn get_ip_address(host: String) raises -> String:
"""Get the IP address of a host."""
# Call getaddrinfo to get the IP address of the host.
var result = get_addr_info(host)
var ai_addr: Pointer[sockaddr]
var ai_addr: UnsafePointer[sockaddr]
var address_family: Int32 = 0
var address_length: UInt32 = 0
if result.isa[addrinfo]():
Expand All @@ -118,7 +106,7 @@ fn get_ip_address(host: String) raises -> String:
raise Error("Failed to get IP address. getaddrinfo was called successfully, but ai_addr is null.")

# Cast sockaddr struct to sockaddr_in struct and convert the binary IP to a string using inet_ntop.
var addr_in = ai_addr.bitcast[sockaddr_in]().load()
var addr_in = move_from_pointee(ai_addr.bitcast[sockaddr_in]())

return convert_binary_ip_to_string(addr_in.sin_addr.s_addr, address_family, address_length).strip()

Expand Down Expand Up @@ -167,7 +155,7 @@ fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32,
return StringRef(string_buf, index)


fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) -> Pointer[sockaddr]:
fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) -> UnsafePointer[sockaddr]:
"""Build a sockaddr pointer from an IP address and port number.
https://learn.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-sockaddr_in.
Expand All @@ -176,4 +164,4 @@ fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) ->
var bin_ip = convert_ip_to_binary(ip_address, address_family)

var ai = sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8]())
return Pointer[sockaddr_in].address_of(ai).bitcast[sockaddr]()
return UnsafePointer[sockaddr_in].address_of(ai).bitcast[sockaddr]()
2 changes: 1 addition & 1 deletion gojo/net/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ struct Connection(Conn):

return bytes_written, err

fn write(inout self, src: List[Byte]) -> (Int, Error):
fn write(inout self, src: Span[Byte]) -> (Int, Error):
"""Writes data to the underlying file descriptor.
Args:
Expand Down
33 changes: 16 additions & 17 deletions gojo/net/socket.mojo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.optional import Optional
from ..builtins import Byte
from ..syscall.file import close
from ..syscall.types import (
from ..syscall import (
c_void,
c_uint,
c_char,
Expand Down Expand Up @@ -168,9 +168,9 @@ struct Socket(FileDescriptorBase):
The return value is a connection where conn is a new socket object usable to send and receive data on the connection,
and address is the address bound to the socket on the other end of the connection.
"""
var their_addr_ptr = Pointer[sockaddr].alloc(1)
var their_addr_ptr = UnsafePointer[sockaddr].alloc(1)
var sin_size = socklen_t(sizeof[socklen_t]())
var new_sockfd = accept(self.sockfd.fd, their_addr_ptr, Pointer[socklen_t].address_of(sin_size))
var new_sockfd = accept(self.sockfd.fd, their_addr_ptr, UnsafePointer[socklen_t].address_of(sin_size))
if new_sockfd == -1:
raise Error("Failed to accept connection")

Expand Down Expand Up @@ -234,16 +234,16 @@ struct Socket(FileDescriptorBase):

# TODO: Add check to see if the socket is bound and error if not.

var local_address_ptr = Pointer[sockaddr].alloc(1)
var local_address_ptr = UnsafePointer[sockaddr].alloc(1)
var local_address_ptr_size = socklen_t(sizeof[sockaddr]())
var status = getsockname(
self.sockfd.fd,
local_address_ptr,
Pointer[socklen_t].address_of(local_address_ptr_size),
UnsafePointer[socklen_t].address_of(local_address_ptr_size),
)
if status == -1:
raise Error("Socket.get_sock_name: Failed to get address of local socket.")
var addr_in = local_address_ptr.bitcast[sockaddr_in]().load()
var addr_in = move_from_pointee(local_address_ptr.bitcast[sockaddr_in]())

return HostPort(
host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16),
Expand All @@ -256,18 +256,18 @@ struct Socket(FileDescriptorBase):
raise SocketClosedError

# TODO: Add check to see if the socket is bound and error if not.
var remote_address_ptr = Pointer[sockaddr].alloc(1)
var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var remote_address_ptr_size = socklen_t(sizeof[sockaddr]())
var status = getpeername(
self.sockfd.fd,
remote_address_ptr,
Pointer[socklen_t].address_of(remote_address_ptr_size),
UnsafePointer[socklen_t].address_of(remote_address_ptr_size),
)
if status == -1:
raise Error("Socket.get_peer_name: Failed to get address of remote socket.")

# Cast sockaddr struct to sockaddr_in to convert binary IP to string.
var addr_in = remote_address_ptr.bitcast[sockaddr_in]().load()
var addr_in = move_from_pointee(remote_address_ptr.bitcast[sockaddr_in]())

return HostPort(
host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16),
Expand All @@ -280,9 +280,9 @@ struct Socket(FileDescriptorBase):
Args:
option_name: The socket option to get.
"""
var option_value_pointer = Pointer[c_void].alloc(1)
var option_value_pointer = UnsafePointer[c_void].alloc(1)
var option_len = socklen_t(sizeof[socklen_t]())
var option_len_pointer = Pointer.address_of(option_len)
var option_len_pointer = UnsafePointer.address_of(option_len)
var status = getsockopt(
self.sockfd.fd,
SOL_SOCKET,
Expand All @@ -293,7 +293,7 @@ struct Socket(FileDescriptorBase):
if status == -1:
raise Error("Socket.get_sock_opt failed with status: " + str(status))

return option_value_pointer.bitcast[Int]().load()
return move_from_pointee(option_value_pointer.bitcast[Int]())

fn set_socket_option(self, option_name: Int, owned option_value: UInt8 = 1) raises:
"""Return the value of the given socket option.
Expand All @@ -302,7 +302,7 @@ struct Socket(FileDescriptorBase):
option_name: The socket option to set.
option_value: The value to set the socket option to.
"""
var option_value_pointer = Pointer[c_void].address_of(option_value)
var option_value_pointer = UnsafePointer[c_void].address_of(option_value)
var option_len = sizeof[socklen_t]()
var status = setsockopt(self.sockfd.fd, SOL_SOCKET, option_name, option_value_pointer, option_len)
if status == -1:
Expand All @@ -324,7 +324,7 @@ struct Socket(FileDescriptorBase):
var remote = self.get_peer_name()
self.remote_address = TCPAddr(remote.host, remote.port)

fn write(inout self: Self, src: List[Byte]) -> (Int, Error):
fn write(inout self: Self, src: Span[Byte]) -> (Int, Error):
"""Send data to the socket. The socket must be connected to a remote socket.
Args:
Expand All @@ -348,7 +348,7 @@ struct Socket(FileDescriptorBase):
src: The data to send.
max_attempts: The maximum number of attempts to send the data.
"""
var header_pointer = src.unsafe_ptr()
var header_pointer = DTypePointer(src.unsafe_ptr())
var total_bytes_sent = 0
var attempts = 0

Expand Down Expand Up @@ -377,11 +377,10 @@ struct Socket(FileDescriptorBase):
address: The IP address to connect to.
port: The port number to connect to.
"""
var header_pointer = Pointer[Int8](src.data.address).bitcast[UInt8]()
self.connect(address, port)
var bytes_written: Int
var err: Error
bytes_written, err = self.write(src)
bytes_written, err = self.write(Span(src))
if err:
raise err
return bytes_written
Expand Down
2 changes: 1 addition & 1 deletion gojo/net/tcp.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct TCPConnection(Conn):

return bytes_written, Error()

fn write(inout self, src: List[Byte]) -> (Int, Error):
fn write(inout self, src: Span[Byte]) -> (Int, Error):
"""Writes data to the underlying file descriptor.
Args:
Expand Down
96 changes: 96 additions & 0 deletions gojo/strings/builder.mojo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from algorithm.functional import vectorize
import ..io
from ..builtins import Byte

Expand Down Expand Up @@ -122,3 +123,98 @@ struct StringBuilder[growth_factor: Float32 = 2](Stringable, Sized, io.Writer, i
src: The string to append.
"""
return self.write(src.as_bytes_slice())


@value
struct VectorizedStringBuilder(Stringable, Sized):
"""
A string builder class that allows for efficient string management and concatenation.
This class is useful when you need to build a string by appending multiple strings
together. The performance increase is not linear. Compared to string concatenation,
I've observed around 20-30x faster for writing and rending ~4KB and up to 2100x-2300x
for ~4MB. This is because it avoids the overhead of creating and destroying many
intermediate strings and performs memcopy operations.
The result is a more efficient when building larger string concatenations. It
is generally not recommended to use this class for small concatenations such as
a few strings like `a + b + c + d` because the overhead of creating the string
builder and appending the strings is not worth the performance gain.
Example:
```
from strings.builder import StringBuilder
var sb = StringBuilder()
sb.write_string("Hello ")
sb.write_string("World!")
print(sb) # Hello World!
```
"""

var data: List[String]
var position: List[Int]
var size: Int

@always_inline
fn __init__(inout self, *, capacity: Int = 4096):
self.data = List[String](capacity=capacity)
self.position = List[Int](0)
self.size = 0

@always_inline
fn __len__(self) -> Int:
"""
Returns the length of the string builder.
Returns:
The length of the string builder.
"""
return self.size

@always_inline
fn __str__(self) -> String:
"""
Converts the string builder to a string.
Returns:
The string representation of the string builder. Returns an empty
string if the string builder is empty.
"""
var copy = DTypePointer[DType.uint8]().alloc(self.size)

@parameter
fn copy_string[simd_width: Int](i: Int):
var elements = len(self.data[i])
if i == 0:
memcpy(copy, self.data[i].unsafe_ptr(), elements)
return

memcpy(copy.offset(self.position[i]), self.data[i].unsafe_ptr(), elements)

vectorize[copy_string, 1](size=len(self.data))
return StringRef(copy, self.size)

@always_inline
fn write(inout self, owned src: String) -> (Int, Error):
"""
Appends a byte Span to the builder buffer.
Args:
src: The byte array to append.
"""
var elements_to_write = len(src)
self.data.append(src^)
self.size += elements_to_write
self.position.append(self.size)

return elements_to_write, Error()

@always_inline
fn write_string(inout self, src: String) -> (Int, Error):
"""
Appends a string to the builder buffer.
Args:
src: The string to append.
"""
return self.write(src)
Loading

0 comments on commit 0aa8862

Please sign in to comment.