Skip to content

Commit

Permalink
clean up copies of file descriptor and remove copyinit
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Apr 6, 2024
1 parent 99b32d2 commit 42c005b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 13 deletions.
25 changes: 25 additions & 0 deletions gojo/net/dial.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .tcp import TCPAddr, TCPConnection, resolve_internet_addr
from .socket import Socket


@value
struct Dialer():
var local_address: TCPAddr

fn dial(self, network: String, address: String) raises -> TCPConnection:
var tcp_addr = resolve_internet_addr(network, address)
var socket = Socket(local_address=self.local_address)
socket.connect(tcp_addr.ip, tcp_addr.port)
print(String("Connected to ") + socket.remote_address)
return TCPConnection(socket ^)


fn dial_tcp(network: String, local_address: TCPAddr) raises -> TCPConnection:
# TODO: Add conversion of domain name to ip address
return Dialer(local_address).dial(
network, local_address.ip + ":" + str(local_address.port)
)


fn dial_tcp(network: String, ip: String, port: Int) raises -> TCPConnection:
return Dialer(TCPAddr(ip, port)).dial(network, ip + ":" + str(port))
1 change: 0 additions & 1 deletion gojo/net/fd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ trait FileDescriptorBase(io.Reader, io.Writer, io.Closer):
...


@value
struct FileDescriptor(FileDescriptorBase):
var fd: Int
var is_closed: Bool
Expand Down
6 changes: 3 additions & 3 deletions gojo/net/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ alias DEFAULT_BUFFER_SIZE = 4096


trait Conn(io.Writer, io.Reader, io.Closer):
fn __init__(inout self, socket: Socket):
fn __init__(inout self, owned socket: Socket):
...

"""Conn is a generic stream-oriented network connection."""
Expand Down Expand Up @@ -72,8 +72,8 @@ struct Connection(Conn):

var fd: Arc[Socket]

fn __init__(inout self, socket: Socket):
self.fd = Arc(socket)
fn __init__(inout self, owned socket: Socket):
self.fd = Arc(socket ^)

fn read(inout self, inout dest: List[Byte]) -> Result[Int]:
"""Reads data from the underlying file descriptor.
Expand Down
15 changes: 12 additions & 3 deletions gojo/net/socket.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ from .address import Addr, TCPAddr, HostPort
alias SocketClosedError = Error("Socket: Socket is already closed")


@value
struct Socket(FileDescriptorBase):
"""Represents a network file descriptor. Wraps around a file descriptor and provides network functions.
Expand Down Expand Up @@ -135,8 +134,18 @@ struct Socket(FileDescriptorBase):
self._closed = False
self._is_connected = True

fn __enter__(self) -> Self:
return self
fn __moveinit__(inout self, owned existing: Self):
self.sockfd = existing.sockfd ^
self.address_family = existing.address_family
self.socket_type = existing.socket_type
self.protocol = existing.protocol
self.local_address = existing.local_address ^
self.remote_address = existing.remote_address ^
self._closed = existing._closed
self._is_connected = existing._is_connected

# fn __enter__(self) -> Self:
# return self

# fn __exit__(inout self) raises:
# if self._is_connected:
Expand Down
21 changes: 16 additions & 5 deletions gojo/net/tcp.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ struct ListenConfig(CollectionElement):
socket.set_socket_option(SO_REUSEADDR, 1)
socket.listen()
print(String("Listening on ") + socket.local_address)
return TCPListener(socket, self, network, address)
return TCPListener(socket ^, self, network, address)


trait Listener(CollectionElement):
trait Listener(Movable):
# Raising here because a Result[Optional[Connection], Optional[WrappedError]] is funky.
fn accept(self) raises -> Connection:
...
Expand All @@ -84,8 +84,8 @@ struct TCPConnection(Conn):
fn __init__(inout self, connection: Connection):
self._connection = connection

fn __init__(inout self, socket: Socket):
self._connection = Connection(socket)
fn __init__(inout self, owned socket: Socket):
self._connection = Connection(socket ^)

fn __moveinit__(inout self, owned existing: Self):
self._connection = existing._connection ^
Expand Down Expand Up @@ -162,13 +162,24 @@ fn listen_tcp(network: String, ip: String, port: Int) raises -> TCPListener:
return ListenConfig(DEFAULT_TCP_KEEP_ALIVE).listen(network, ip + ":" + str(port))


@value
struct TCPListener(Listener):
var _file_descriptor: Socket
var listen_config: ListenConfig
var network_type: String
var address: String

fn __init__(inout self, owned file_descriptor: Socket, listen_config: ListenConfig, network_type: String, address: String):
self._file_descriptor = file_descriptor ^
self.listen_config = listen_config
self.network_type = network_type
self.address = address

fn __moveinit__(inout self, owned existing: Self):
self._file_descriptor = existing._file_descriptor ^
self.listen_config = existing.listen_config ^
self.network_type = existing.network_type
self.address = existing.address

fn listen(self) raises -> Self:
return self.listen_config.listen(self.network_type, self.address)

Expand Down
31 changes: 30 additions & 1 deletion test_get_addr.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,40 @@ from gojo.net.socket import Socket
from gojo.net.ip import get_ip_address
from gojo.net.tcp import listen_tcp, TCPAddr
from gojo.syscall.net import SO_REUSEADDR, PF_UNIX, SO_RCVTIMEO
from gojo.net.dial import dial_tcp

# fn main() raises:
# var ip = get_ip_address("localhost")
# print(ip)


fn test_dial() raises:
var connection = dial_tcp("tcp", get_ip_address("www.example.com"), 80)
var result = connection.write(String("GET / HTTP/1.1\r\n\r\n").as_bytes())
if result.error:
raise result.unwrap_error().error

if result.value == 0:
print("No bytes sent to peer.")
return

var response = List[Int8](capacity=4096)
result = connection.read(response)
if result.error:
raise result.unwrap_error().error

if result.value == 0:
print("No bytes received from peer.")
return

response.append(0)
print(String(response))

var err = connection.close()
if err:
raise err.value().error


fn test_listener() raises:
var listener = listen_tcp("tcp", TCPAddr("0.0.0.0", 8081))
while True:
Expand Down Expand Up @@ -52,5 +80,6 @@ fn test_stuff() raises:


fn main() raises:
test_stuff()
# test_stuff()
# test_listener()
test_dial()

0 comments on commit 42c005b

Please sign in to comment.