Skip to content

Commit

Permalink
added net, fd, and socket
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Apr 4, 2024
1 parent 7284a9f commit 52dec7f
Show file tree
Hide file tree
Showing 17 changed files with 1,999 additions and 25 deletions.
8 changes: 4 additions & 4 deletions gojo/bufio/bufio.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ struct Reader[R: io.Reader](
var result = writer.write(self.buf[self.read_pos : self.write_pos])
if result.error:
return Result(Int64(result.value), result.error)

var bytes_written = result.value
if bytes_written < 0:
panic(ERR_NEGATIVE_WRITE)
Expand Down Expand Up @@ -885,13 +885,13 @@ struct Writer[W: io.Writer](
while nr < MAX_CONSECUTIVE_EMPTY_READS:
# TODO: should really be using a slice that returns refs and not a copy.
# Read into remaining unused space in the buffer. We need to reserve capacity for the slice otherwise read will never hit EOF.
var sl = self.buf[self.bytes_written:len(self.buf)]
var sl = self.buf[self.bytes_written : len(self.buf)]
sl.reserve(self.buf.capacity)
var result = reader.read(sl)
bytes_read = result.value
err = result.get_error()
_ = copy(self.buf, sl, self.bytes_written)

if bytes_read != 0 or err:
break
nr += 1
Expand All @@ -910,7 +910,7 @@ struct Writer[W: io.Writer](
err = self.flush()
else:
err = None

return Result(total_bytes_written, None)


Expand Down
24 changes: 12 additions & 12 deletions gojo/builtins/bytes.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ fn has_suffix(bytes: List[Byte], suffix: List[Byte]) -> Bool:


fn index_byte(bytes: List[Byte], delim: Byte) -> Int:
"""Return the index of the first occurrence of the byte delim.
"""Return the index of the first occurrence of the byte delim.
Args:
bytes: The List[Byte] struct to search.
delim: The byte to search for.
Args:
bytes: The List[Byte] struct to search.
delim: The byte to search for.
Returns:
The index of the first occurrence of the byte delim.
"""
var i = 0
for i in range(len(bytes)):
if bytes[i] == delim:
return i
Returns:
The index of the first occurrence of the byte delim.
"""
var i = 0
for i in range(len(bytes)):
if bytes[i] == delim:
return i

return -1
return -1


fn to_string(bytes: List[Byte]) -> String:
Expand Down
2 changes: 1 addition & 1 deletion gojo/builtins/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ fn equals(left: List[Bool], right: List[Bool]) -> Bool:
for i in range(len(left)):
if left[i] != right[i]:
return False
return True
return True
1 change: 0 additions & 1 deletion gojo/bytes/reader.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,3 @@ fn new_reader(buffer: String) -> Reader:
"""
return Reader(buffer.as_bytes(), 0, -1)

2 changes: 1 addition & 1 deletion gojo/io/traits.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ trait Closer(Movable):
Specific implementations may document their own behavior.
"""

fn close(inout self) raises:
fn close(inout self) -> Optional[WrappedError]:
...


Expand Down
4 changes: 4 additions & 0 deletions gojo/net/__init__.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Adapted from go's net package
A good chunk of the leg work here came from the lightbug_http project! https://github.com/saviorand/lightbug_http/tree/main
"""
148 changes: 148 additions & 0 deletions gojo/net/address.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
@value
struct NetworkType:
var value: String

alias empty = NetworkType("")
alias tcp = NetworkType("tcp")
alias tcp4 = NetworkType("tcp4")
alias tcp6 = NetworkType("tcp6")
alias udp = NetworkType("udp")
alias udp4 = NetworkType("udp4")
alias udp6 = NetworkType("udp6")
alias ip = NetworkType("ip")
alias ip4 = NetworkType("ip4")
alias ip6 = NetworkType("ip6")
alias unix = NetworkType("unix")


trait Addr(CollectionElement, Stringable):
fn network(self) -> String:
"""Name of the network (for example, "tcp", "udp")."""
...


@value
struct TCPAddr(Addr):
"""Addr struct representing a TCP address.
Args:
ip: IP address.
port: Port number.
zone: IPv6 addressing zone.
"""
var ip: String
var port: Int
var zone: String # IPv6 addressing zone

fn __init__(inout self):
self.ip = String("127.0.0.1")
self.port = 8000
self.zone = ""

fn __init__(inout self, ip: String, port: Int):
self.ip = ip
self.port = port
self.zone = ""

fn __str__(self) -> String:
if self.zone != "":
return join_host_port(String(self.ip) + "%" + self.zone, self.port)
return join_host_port(self.ip, self.port)

fn network(self) -> String:
return NetworkType.tcp.value


fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr:
var host: String = ""
var port: String = ""
var portnum: Int = 0
if (
network == NetworkType.tcp.value
or network == NetworkType.tcp4.value
or network == NetworkType.tcp6.value
or network == NetworkType.udp.value
or network == NetworkType.udp4.value
or network == NetworkType.udp6.value
):
if address != "":
var host_port = split_host_port(address)
host = host_port.host
port = host_port.port
portnum = atol(port.__str__())
elif (
network == NetworkType.ip.value
or network == NetworkType.ip4.value
or network == NetworkType.ip6.value
):
if address != "":
host = address
elif network == NetworkType.unix.value:
raise Error("Unix addresses not supported yet")
else:
raise Error("unsupported network type: " + network)
return TCPAddr(host, portnum)


fn join_host_port(host: String, port: String) -> String:
if host.find(":") != -1: # must be IPv6 literal
return "[" + host + "]:" + port
return host + ":" + port


alias missingPortError = Error("missing port in address")
alias tooManyColonsError = Error("too many colons in address")


struct HostPort(Stringable):
var host: String
var port: Int

fn __init__(inout self, host: String, port: Int):
self.host = host
self.port = port

fn __str__(self) -> String:
return join_host_port(self.host, str(self.port))


fn split_host_port(hostport: String) raises -> HostPort:
var host: String = ""
var port: String = ""
var colon_index = hostport.rfind(":")
var j: Int = 0
var k: Int = 0

if colon_index == -1:
raise missingPortError
if hostport[0] == "[":
var end_bracket_index = hostport.find("]")
if end_bracket_index == -1:
raise Error("missing ']' in address")
if end_bracket_index + 1 == len(hostport):
raise missingPortError
elif end_bracket_index + 1 == colon_index:
host = hostport[1:end_bracket_index]
j = 1
k = end_bracket_index + 1
else:
if hostport[end_bracket_index + 1] == ":":
raise tooManyColonsError
else:
raise missingPortError
else:
host = hostport[:colon_index]
if host.find(":") != -1:
raise tooManyColonsError
if hostport[j:].find("[") != -1:
raise Error("unexpected '[' in address")
if hostport[k:].find("]") != -1:
raise Error("unexpected ']' in address")
port = hostport[colon_index + 1 :]

if port == "":
raise missingPortError
if host == "":
raise Error("missing host")

return HostPort(host, atol(port))
111 changes: 111 additions & 0 deletions gojo/net/fd.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from collections.optional import Optional
import ..io
from ..builtins import Byte, Result, WrappedError
from ..syscall.file import close
from ..syscall.types import (
c_void,
c_uint,
c_char,
c_int,
)
from ..syscall.net import (
sockaddr,
sockaddr_in,
addrinfo,
addrinfo_unix,
socklen_t,
socket,
connect,
recv,
send,
shutdown,
inet_pton,
inet_ntoa,
inet_ntop,
to_char_ptr,
htons,
ntohs,
strlen,
getaddrinfo,
getaddrinfo_unix,
gai_strerror,
c_charptr_to_string,
bind,
listen,
accept,
setsockopt,
getsockopt,
getsockname,
getpeername,
c_charptr_to_string,
AF_INET,
SOCK_STREAM,
SHUT_RDWR,
AI_PASSIVE,
SOL_SOCKET,
SO_REUSEADDR,
SO_RCVTIMEO,
)
from external.libc import Str, c_ssize_t, c_size_t, char_pointer

alias O_RDWR = 0o2


trait FileDescriptorBase(io.Reader, io.Writer, io.Closer):
...


@value
struct FileDescriptor(FileDescriptorBase):
var fd: Int

# This takes ownership of a POSIX file descriptor.
fn __moveinit__(inout self, owned existing: Self):
self.fd = existing.fd

fn __init__(inout self, fd: Int):
self.fd = fd

fn __del__(owned self):
var err = self.close()
if err:
print(err.value())

fn close(inout self) -> Optional[WrappedError]:
"""Mark the file descriptor as closed."""
var close_status = close(self.fd)
if close_status == -1:
return WrappedError("FileDescriptor.close: Failed to close socket")

return None

fn dup(self) -> Self:
"""Duplicate the file descriptor."""
var new_fd = external_call["dup", Int, Int](self.fd)
return Self(new_fd)

fn read(inout self, inout dest: List[Byte]) -> Result[Int]:
"""Receive data from the file descriptor and write it to the buffer provided."""
var ptr = Pointer[UInt8]().alloc(dest.capacity)
var bytes_received = recv(self.fd, ptr, dest.capacity, 0)
if bytes_received == -1:
return Result(0, WrappedError("Failed to receive message from socket."))

var int8_ptr = ptr.bitcast[Int8]()
for i in range(bytes_received):
dest.append(int8_ptr[i])

if bytes_received < dest.capacity:
return Result(bytes_received, WrappedError(io.EOF))

return bytes_received

fn write(inout self, src: List[Byte]) -> Result[Int]:
"""Write data from the buffer to the file descriptor."""
var header_pointer = Pointer[Int8](src.data.value).bitcast[UInt8]()

var bytes_sent = send(self.fd, header_pointer, strlen(header_pointer), 0)
if bytes_sent == -1:
return Result(0, WrappedError("Failed to send message"))

return bytes_sent
Loading

0 comments on commit 52dec7f

Please sign in to comment.