Skip to content

Commit bf047f1

Browse files
committedAug 14, 2024
add cancel support to EvLoop.Fd
1 parent f438d90 commit bf047f1

16 files changed

+207
-147
lines changed
 

‎build.zig

+1-1
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ fn setup_libexeobj_step(step: *LibExeObjStep) void {
701701

702702
// compile
703703
if (step.kind == .obj)
704-
step.use_stage1 = true; // required by async/await (.zig)
704+
step.use_stage1 = true; // required by coroutine (.zig)
705705

706706
step.single_threaded = true;
707707

‎src/CacheMsg.zig

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ const cc = @import("cc.zig");
55
const dns = @import("dns.zig");
66
const log = @import("log.zig");
77
const Rc = @import("Rc.zig");
8-
const ListNode = @import("ListNode.zig");
8+
const Node = @import("Node.zig");
99
const Bytes = cc.Bytes;
1010

1111
// =======================================================
1212

1313
const CacheMsg = @This();
1414

1515
next: ?*CacheMsg = null, // for hashmap
16-
list_node: ListNode = undefined,
16+
list_node: Node = undefined,
1717
update_time: c.time_t,
1818
hashv: c_uint,
1919
ttl: i32,
@@ -73,7 +73,7 @@ pub fn free(self: *CacheMsg) void {
7373
g.allocator.free(self.mem());
7474
}
7575

76-
pub fn from_list_node(node: *ListNode) *CacheMsg {
76+
pub fn from_list_node(node: *Node) *CacheMsg {
7777
return @fieldParentPtr(CacheMsg, "list_node", node);
7878
}
7979

‎src/EvLoop.zig

+62-18
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ pub const Fd = struct {
9696
write_frame: ?anyframe = null, // waiting for writable event
9797
fd: c_int,
9898
rc: Rc = .{},
99+
canceled: bool = false,
99100

100101
/// ownership of `fd` is transferred to `fdobj`
101102
pub fn new(fd: c_int) *Fd {
@@ -139,6 +140,31 @@ pub const Fd = struct {
139140

140141
return events;
141142
}
143+
144+
pub fn cancel(self: *Fd) void {
145+
if (self.canceled)
146+
return;
147+
148+
self.canceled = true;
149+
// cc.set_errno(c.ECANCELED);
150+
151+
if (self.read_frame) |frame|
152+
co.do_resume(frame);
153+
if (self.write_frame) |frame|
154+
co.do_resume(frame);
155+
156+
assert(self.read_frame == null);
157+
assert(self.write_frame == null);
158+
}
159+
160+
/// return `true` if canceled and set errno to `ECANCELED`
161+
pub fn is_canceled(self: *const Fd) bool {
162+
if (self.canceled) {
163+
cc.set_errno(c.ECANCELED);
164+
return true;
165+
}
166+
return false;
167+
}
142168
};
143169

144170
// =============================================================
@@ -256,6 +282,7 @@ fn del(self: *EvLoop, fdobj: *const Fd) bool {
256282
// ======================================================================
257283

258284
fn set_frame(fdobj: *Fd, comptime field_name: []const u8, frame: anyframe) void {
285+
assert(!fdobj.canceled);
259286
assert(@field(fdobj, field_name) == null);
260287
@field(fdobj, field_name) = frame;
261288
}
@@ -401,26 +428,34 @@ pub fn run(self: *EvLoop) void {
401428

402429
// ========================================================================
403430

404-
// socket API (non-blocking + async)
431+
// socket API (non-blocking + coroutine)
405432

406433
comptime {
407434
assert(c.EAGAIN == c.EWOULDBLOCK);
408435
}
409436

410-
/// used for external modules, not for this module:
411-
/// because the async-call chains consume at least 24 bytes per level (x86_64)
412-
pub fn wait_readable(self: *EvLoop, fdobj: *Fd) void {
437+
/// return `null` if fdobj is canceled. \
438+
/// used for external modules, not for this module: \
439+
/// because the coroutine chains consume at least 24 bytes per level (x86_64).
440+
pub fn wait_readable(self: *EvLoop, fdobj: *Fd) ?void {
413441
self.add_readable(fdobj, @frame());
414442
suspend {}
415443
self.del_readable(fdobj, @frame());
444+
445+
if (fdobj.is_canceled())
446+
return null;
416447
}
417448

418-
/// used for external modules, not for this module:
419-
/// because the async-call chains consume at least 24 bytes per level (x86_64)
420-
pub fn wait_writable(self: *EvLoop, fdobj: *Fd) void {
449+
/// return `null` if fdobj is canceled. \
450+
/// used for external modules, not for this module: \
451+
/// because the coroutine chains consume at least 24 bytes per level (x86_64).
452+
pub fn wait_writable(self: *EvLoop, fdobj: *Fd) ?void {
421453
self.add_writable(fdobj, @frame());
422454
suspend {}
423455
self.del_writable(fdobj, @frame());
456+
457+
if (fdobj.is_canceled())
458+
return null;
424459
}
425460

426461
pub fn connect(self: *EvLoop, fdobj: *Fd, addr: *const cc.SockAddr) ?void {
@@ -432,6 +467,9 @@ pub fn connect(self: *EvLoop, fdobj: *Fd, addr: *const cc.SockAddr) ?void {
432467
suspend {}
433468
self.del_writable(fdobj, @frame());
434469

470+
if (fdobj.is_canceled())
471+
return null;
472+
435473
if (net.getsockopt_int(fdobj.fd, c.SOL_SOCKET, c.SO_ERROR, "SO_ERROR")) |err| {
436474
if (err == 0) return;
437475
cc.set_errno(err);
@@ -444,7 +482,7 @@ pub fn connect(self: *EvLoop, fdobj: *Fd, addr: *const cc.SockAddr) ?void {
444482
}
445483

446484
pub fn accept(self: *EvLoop, fdobj: *Fd, src_addr: ?*cc.SockAddr) ?c_int {
447-
while (true) {
485+
while (!fdobj.is_canceled()) {
448486
return cc.accept4(fdobj.fd, src_addr, c.SOCK_NONBLOCK | c.SOCK_CLOEXEC) orelse {
449487
if (cc.errno() != c.EAGAIN)
450488
return null;
@@ -455,11 +493,11 @@ pub fn accept(self: *EvLoop, fdobj: *Fd, src_addr: ?*cc.SockAddr) ?c_int {
455493

456494
continue;
457495
};
458-
}
496+
} else return null;
459497
}
460498

461499
pub fn read(self: *EvLoop, fdobj: *Fd, buf: []u8) ?usize {
462-
while (true) {
500+
while (!fdobj.is_canceled()) {
463501
return cc.read(fdobj.fd, buf) orelse {
464502
if (cc.errno() != c.EAGAIN)
465503
return null;
@@ -470,11 +508,11 @@ pub fn read(self: *EvLoop, fdobj: *Fd, buf: []u8) ?usize {
470508

471509
continue;
472510
};
473-
}
511+
} else return null;
474512
}
475513

476514
pub fn recvfrom(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int, src_addr: *cc.SockAddr) ?usize {
477-
while (true) {
515+
while (!fdobj.is_canceled()) {
478516
return cc.recvfrom(fdobj.fd, buf, flags, src_addr) orelse {
479517
if (cc.errno() != c.EAGAIN)
480518
return null;
@@ -485,11 +523,11 @@ pub fn recvfrom(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int, src_addr: *c
485523

486524
continue;
487525
};
488-
}
526+
} else return null;
489527
}
490528

491529
pub fn recv(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int) ?usize {
492-
while (true) {
530+
while (!fdobj.is_canceled()) {
493531
return cc.recv(fdobj.fd, buf, flags) orelse {
494532
if (cc.errno() != c.EAGAIN)
495533
return null;
@@ -500,16 +538,18 @@ pub fn recv(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int) ?usize {
500538

501539
continue;
502540
};
503-
}
541+
} else return null;
504542
}
505543

506-
const ReadErr = error{ eof, other };
544+
const ReadErr = error{ eof, errno };
507545

508-
pub fn recv_exactly(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int) ReadErr!void {
546+
pub fn recv_full(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int) ReadErr!void {
509547
var nread: usize = 0;
510548
while (nread < buf.len) {
549+
if (fdobj.is_canceled())
550+
return ReadErr.errno; // ECANCELED
511551
const n = self.recv(fdobj, buf[nread..], flags) orelse
512-
return ReadErr.other;
552+
return ReadErr.errno;
513553
if (n == 0)
514554
return ReadErr.eof;
515555
nread += n;
@@ -525,6 +565,8 @@ pub fn recv_exactly(self: *EvLoop, fdobj: *Fd, buf: []u8, flags: c_int) ReadErr!
525565
pub fn send(self: *EvLoop, fdobj: *Fd, data: []const u8, flags: c_int) ?void {
526566
var nsend: usize = 0;
527567
while (nsend < data.len) {
568+
if (fdobj.is_canceled())
569+
return null;
528570
const n = cc.send(fdobj.fd, data[nsend..], flags) orelse b: {
529571
if (cc.errno() != c.EAGAIN)
530572
return null;
@@ -544,6 +586,8 @@ pub fn send(self: *EvLoop, fdobj: *Fd, data: []const u8, flags: c_int) ?void {
544586
pub fn sendmsg(self: *EvLoop, fdobj: *Fd, msg: *const cc.msghdr_t, flags: c_int) ?void {
545587
var remain_len: usize = msg.calc_len();
546588
while (remain_len > 0) {
589+
if (fdobj.is_canceled())
590+
return null;
547591
const n = cc.sendmsg(fdobj.fd, msg, flags) orelse b: {
548592
if (cc.errno() != c.EAGAIN)
549593
return null;

‎src/ListNode.zig ‎src/Node.zig

+25-25
Original file line numberDiff line numberDiff line change
@@ -5,55 +5,55 @@ const assert = std.debug.assert;
55

66
// =====================================================
77

8-
const ListNode = @This();
8+
const Node = @This();
99

10-
prev: *ListNode,
11-
next: *ListNode,
10+
prev: *Node,
11+
next: *Node,
1212

1313
// =================== `list_head(sentinel)` ===================
1414

1515
/// empty list (sentinel node)
16-
pub fn init(list: *ListNode) void {
16+
pub fn init(list: *Node) void {
1717
list.prev = list;
1818
list.next = list;
1919
}
2020

2121
/// first node
22-
pub inline fn head(list: *const ListNode) *ListNode {
22+
pub inline fn head(list: *const Node) *Node {
2323
return list.next;
2424
}
2525

2626
/// last node
27-
pub inline fn tail(list: *const ListNode) *ListNode {
27+
pub inline fn tail(list: *const Node) *Node {
2828
return list.prev;
2929
}
3030

3131
/// is sentinel node
32-
pub inline fn is_empty(list: *const ListNode) bool {
32+
pub inline fn is_empty(list: *const Node) bool {
3333
return list.head() == list;
3434
}
3535

3636
/// `unlink(node)` and/or `free(node)` is safe
37-
pub fn iterator(list: *const ListNode) Iterator {
37+
pub fn iterator(list: *const Node) Iterator {
3838
return .{
3939
.sentinel = list,
4040
.node = list.head(),
4141
};
4242
}
4343

4444
/// `unlink(node)` and/or `free(node)` is safe
45-
pub fn reverse_iterator(list: *const ListNode) ReverseIterator {
45+
pub fn reverse_iterator(list: *const Node) ReverseIterator {
4646
return .{
4747
.sentinel = list,
4848
.node = list.tail(),
4949
};
5050
}
5151

5252
pub const Iterator = struct {
53-
sentinel: *const ListNode,
54-
node: *ListNode,
53+
sentinel: *const Node,
54+
node: *Node,
5555

56-
pub fn next(it: *Iterator) ?*ListNode {
56+
pub fn next(it: *Iterator) ?*Node {
5757
const node = it.node;
5858
if (node != it.sentinel) {
5959
it.node = node.next;
@@ -64,10 +64,10 @@ pub const Iterator = struct {
6464
};
6565

6666
pub const ReverseIterator = struct {
67-
sentinel: *const ListNode,
68-
node: *ListNode,
67+
sentinel: *const Node,
68+
node: *Node,
6969

70-
pub fn next(it: *ReverseIterator) ?*ListNode {
70+
pub fn next(it: *ReverseIterator) ?*Node {
7171
const node = it.node;
7272
if (node != it.sentinel) {
7373
it.node = node.prev;
@@ -79,31 +79,31 @@ pub const ReverseIterator = struct {
7979

8080
// =================== `node` ===================
8181

82-
pub fn link_to_head(list: *ListNode, node: *ListNode) void {
82+
pub fn link_to_head(list: *Node, node: *Node) void {
8383
return node.link(list, list.head());
8484
}
8585

86-
pub fn link_to_tail(list: *ListNode, node: *ListNode) void {
86+
pub fn link_to_tail(list: *Node, node: *Node) void {
8787
return node.link(list.tail(), list);
8888
}
8989

9090
/// assume that the `node` is linked to the `list`
91-
pub fn move_to_head(list: *ListNode, node: *ListNode) void {
91+
pub fn move_to_head(list: *Node, node: *Node) void {
9292
if (node != list.head()) {
9393
node.unlink();
9494
list.link_to_head(node);
9595
}
9696
}
9797

9898
/// assume that the `node` is linked to the `list`
99-
pub fn move_to_tail(list: *ListNode, node: *ListNode) void {
99+
pub fn move_to_tail(list: *Node, node: *Node) void {
100100
if (node != list.tail()) {
101101
node.unlink();
102102
list.link_to_tail(node);
103103
}
104104
}
105105

106-
fn link(node: *ListNode, prev: *ListNode, next: *ListNode) void {
106+
fn link(node: *Node, prev: *Node, next: *Node) void {
107107
prev.next = node;
108108
node.prev = prev;
109109
node.next = next;
@@ -112,7 +112,7 @@ fn link(node: *ListNode, prev: *ListNode, next: *ListNode) void {
112112

113113
/// `node.prev` and `node.next` are unmodified, use `node.init()` if needed.
114114
/// `list_head.unlink()` is not allowed unless `list_head` is an empty list.
115-
pub fn unlink(node: *const ListNode) void {
115+
pub fn unlink(node: *const Node) void {
116116
node.prev.next = node.next;
117117
node.next.prev = node.prev;
118118
}
@@ -121,15 +121,15 @@ pub fn unlink(node: *const ListNode) void {
121121

122122
const Object = struct {
123123
id: u32,
124-
node: ListNode,
124+
node: Node,
125125

126-
pub fn from_node(node: *ListNode) *Object {
126+
pub fn from_node(node: *Node) *Object {
127127
return @fieldParentPtr(Object, "node", node);
128128
}
129129
};
130130

131131
pub fn @"test: linked list"() !void {
132-
var list: ListNode = undefined;
132+
var list: Node = undefined;
133133
list.init();
134134

135135
defer {
@@ -195,7 +195,7 @@ pub fn @"test: linked list"() !void {
195195
}
196196

197197
// link_to_head
198-
var l: ListNode = undefined;
198+
var l: Node = undefined;
199199
l.init();
200200

201201
defer {

‎src/Upstream.zig

+50-30
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const Tag = @import("tag.zig").Tag;
1313
const DynStr = @import("DynStr.zig");
1414
const EvLoop = @import("EvLoop.zig");
1515
const RcMsg = @import("RcMsg.zig");
16+
const Node = @import("Node.zig");
1617
const flags_op = @import("flags_op.zig");
1718
const assert = std.debug.assert;
1819

@@ -29,7 +30,8 @@ comptime {
2930
const Upstream = @This();
3031

3132
// runtime info
32-
ctx: ?*anyopaque = null,
33+
data: ?*anyopaque = null,
34+
data2: ?*anyopaque = null, // ssl session for resumption
3335

3436
// config info
3537
host: ?cc.ConstStr, // DoT SNI
@@ -70,7 +72,7 @@ fn init(tag: Tag, proto: Proto, addr: *const cc.SockAddr, host: []const u8, ip:
7072
}
7173

7274
fn deinit(self: *const Upstream) void {
73-
assert(self.ctx == null);
75+
assert(self.data == null and self.data2 == null);
7476

7577
if (self.host) |host|
7678
g.allocator.free(cc.strslice_c(host));
@@ -100,20 +102,38 @@ fn send(self: *Upstream, qmsg: *RcMsg) void {
100102

101103
// ======================================================
102104

105+
const TypedNode = struct {
106+
type: enum { udp, tcp }, // `struct UDP` or `struct TCP`
107+
node: Node,
108+
109+
pub fn from_node(node: *Node) *TypedNode {
110+
return @fieldParentPtr(TypedNode, "node", node);
111+
}
112+
};
113+
114+
const UDP = struct {
115+
typed_node: TypedNode,
116+
fdobj: *EvLoop.Fd,
117+
create_time: c.time_t,
118+
query_rtime: u16, // relative to create_time
119+
query_count: u16,
120+
reply_count: u16,
121+
};
122+
103123
fn udp_get_fdobj(self: *const Upstream) ?*EvLoop.Fd {
104124
assert(self.proto == .udpi or self.proto == .udp);
105-
return cc.ptrcast(?*EvLoop.Fd, self.ctx);
125+
return cc.ptrcast(?*EvLoop.Fd, self.data);
106126
}
107127

108128
fn udp_set_fdobj(self: *Upstream, fdobj: ?*EvLoop.Fd) void {
109129
assert(self.proto == .udpi or self.proto == .udp);
110-
self.ctx = fdobj;
130+
self.data = fdobj;
111131
}
112132

113133
fn udp_send(self: *Upstream, qmsg: *RcMsg) void {
114134
const fd = if (self.udp_get_fdobj()) |fdobj| fdobj.fd else b: {
115135
const fd = net.new_sock(self.addr.family(), .udp) orelse return;
116-
co.create(udp_recv, .{ self, fd });
136+
co.start(udp_recv, .{ self, fd });
117137
assert(self.udp_get_fdobj() != null);
118138
break :b fd;
119139
};
@@ -179,7 +199,7 @@ fn udp_recv(self: *Upstream, fd: c_int) void {
179199
log.warn(@src(), "recv(%s) failed: (%d) %m", .{ self.url, cc.errno() });
180200
return;
181201
}
182-
g.evloop.wait_readable(fdobj);
202+
g.evloop.wait_readable(fdobj) orelse return;
183203
continue;
184204
};
185205
} else return;
@@ -203,7 +223,8 @@ fn udp_on_eol(self: *Upstream) void {
203223
if (fdobj.read_frame) |frame| {
204224
co.do_resume(frame);
205225
} else {
206-
// this coroutine may be sending a response to the tcp client (suspended)
226+
// this coroutine may be sending a response to the tcp client (suspending)
227+
// TODO: change to nosuspend send_reply
207228
}
208229
}
209230

@@ -279,26 +300,26 @@ pub const TLS = struct {
279300

280301
const TCP = struct {
281302
upstream: *const Upstream,
282-
fdobj: ?*EvLoop.Fd = null,
303+
fdobj: ?*EvLoop.Fd = null, // tcp connection
304+
tls: TLS_ = .{}, // tls connection (DoT)
283305
send_list: MsgQueue = .{}, // qmsg to be sent
284306
ack_list: std.AutoHashMapUnmanaged(u16, *RcMsg) = .{}, // qmsg to be ack
285307
pending_n: u16 = 0, // outstanding queries: send_list + ack_list
286308
healthy: bool = false, // current connection processed at least one query ?
287-
tls: TLS_ = .{}, // for DoT upstream
288309

289310
const TLS_ = if (has_tls) TLS else struct {};
290311

291312
/// must <= u16_max
292313
const PENDING_MAX = std.math.maxInt(u16);
293314

294315
const MsgQueue = struct {
295-
head: ?*Node = null,
296-
tail: ?*Node = null,
316+
head: ?*MsgNode = null,
317+
tail: ?*MsgNode = null,
297318
waiter: ?anyframe = null,
298319

299-
const Node = struct {
320+
const MsgNode = struct {
300321
msg: *RcMsg,
301-
next: *Node,
322+
next: *MsgNode,
302323
};
303324

304325
/// `null`: cancel wait
@@ -312,7 +333,7 @@ const TCP = struct {
312333
return;
313334
}
314335

315-
const node = g.allocator.create(Node) catch unreachable;
336+
const node = g.allocator.create(MsgNode) catch unreachable;
316337
node.* = .{
317338
.msg = msg,
318339
.next = undefined,
@@ -342,7 +363,7 @@ const TCP = struct {
342363
}
343364

344365
/// `null`: cancel wait
345-
pub fn pop(self: *MsgQueue, blocking: bool) ?*RcMsg {
366+
pub fn pop(self: *MsgQueue, suspending: bool) ?*RcMsg {
346367
if (self.head) |node| {
347368
defer g.allocator.destroy(node);
348369
if (node == self.tail) {
@@ -354,7 +375,7 @@ const TCP = struct {
354375
}
355376
return node.msg;
356377
} else {
357-
if (!blocking)
378+
if (!suspending)
358379
return null;
359380
self.waiter = @frame();
360381
suspend {}
@@ -404,8 +425,7 @@ const TCP = struct {
404425
self.start();
405426
}
406427

407-
/// [async] used to send qmsg to upstream
408-
/// pop from send_list && add to ack_list
428+
/// [suspending] pop from send_list && add to ack_list
409429
fn pop_qmsg(self: *TCP, fdobj: *EvLoop.Fd) ?*RcMsg {
410430
if (!self.fdobj_ok(fdobj)) return null;
411431

@@ -477,7 +497,7 @@ const TCP = struct {
477497
assert(self.ack_list.count() == 0);
478498

479499
self.healthy = false;
480-
co.create(TCP.send, .{self});
500+
co.start(TCP.send, .{self});
481501
}
482502

483503
fn send(self: *TCP) void {
@@ -494,7 +514,7 @@ const TCP = struct {
494514

495515
self.do_connect(fdobj) orelse return;
496516

497-
co.create(recv, .{self});
517+
co.start(recv, .{self});
498518

499519
while (self.pop_qmsg(fdobj)) |qmsg| {
500520
self.do_send(fdobj, qmsg) orelse return;
@@ -574,11 +594,11 @@ const TCP = struct {
574594
var err: c_int = undefined;
575595
cc.SSL_connect(self.ssl(), &err) orelse switch (err) {
576596
c.WOLFSSL_ERROR_WANT_READ => {
577-
g.evloop.wait_readable(fdobj);
597+
g.evloop.wait_readable(fdobj) orelse return null;
578598
continue;
579599
},
580600
c.WOLFSSL_ERROR_WANT_WRITE => {
581-
g.evloop.wait_writable(fdobj);
601+
g.evloop.wait_writable(fdobj) orelse return null;
582602
continue;
583603
},
584604
else => {
@@ -624,7 +644,7 @@ const TCP = struct {
624644
.msg_iov = &iov,
625645
.msg_iovlen = iov.len,
626646
};
627-
g.evloop.sendmsg(fdobj, &msg, 0) orelse break :e null; // async
647+
g.evloop.sendmsg(fdobj, &msg, 0) orelse break :e null;
628648
} else if (has_tls) {
629649
// merge into one ssl record
630650
var buf: [2 + c.DNS_QMSG_MAXSIZE]u8 align(2) = undefined;
@@ -638,7 +658,7 @@ const TCP = struct {
638658
var err: c_int = undefined;
639659
cc.SSL_write(self.ssl(), data, &err) orelse switch (err) {
640660
c.WOLFSSL_ERROR_WANT_WRITE => {
641-
g.evloop.wait_writable(fdobj); // async
661+
g.evloop.wait_writable(fdobj) orelse return null;
642662
continue;
643663
},
644664
else => {
@@ -661,9 +681,9 @@ const TCP = struct {
661681
if (!self.fdobj_ok(fdobj)) return null;
662682

663683
if (self.upstream.proto != .tls) {
664-
g.evloop.recv_exactly(fdobj, buf, flags) catch |err| switch (err) {
684+
g.evloop.recv_full(fdobj, buf, flags) catch |err| switch (err) {
665685
error.eof => return null,
666-
error.other => break :e null,
686+
error.errno => break :e null,
667687
};
668688
} else if (has_tls) {
669689
var nread: usize = 0;
@@ -676,7 +696,7 @@ const TCP = struct {
676696
return null;
677697
},
678698
c.WOLFSSL_ERROR_WANT_READ => {
679-
g.evloop.wait_readable(fdobj); // async
699+
g.evloop.wait_readable(fdobj) orelse return null;
680700
continue;
681701
},
682702
else => {
@@ -696,9 +716,9 @@ const TCP = struct {
696716

697717
fn tcp_ctx(self: *Upstream) *TCP {
698718
assert(self.proto == .tcpi or self.proto == .tcp or self.proto == .tls);
699-
if (self.ctx == null)
700-
self.ctx = TCP.new(self);
701-
return cc.ptrcast(*TCP, self.ctx.?);
719+
if (self.data == null)
720+
self.data = TCP.new(self);
721+
return cc.ptrcast(*TCP, self.data.?);
702722
}
703723

704724
fn tcp_send(self: *Upstream, qmsg: *RcMsg) void {

‎src/cache.zig

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ const g = @import("g.zig");
33
const c = @import("c.zig");
44
const cc = @import("cc.zig");
55
const dns = @import("dns.zig");
6-
const ListNode = @import("ListNode.zig");
6+
const Node = @import("Node.zig");
77
const CacheMsg = @import("CacheMsg.zig");
88
const cache_ignore = @import("cache_ignore.zig");
99
const log = @import("log.zig");
1010
const assert = std.debug.assert;
1111
const Bytes = cc.Bytes;
1212

1313
/// LRU
14-
var _list: ListNode = undefined;
14+
var _list: Node = undefined;
1515

1616
pub fn module_init() void {
1717
_list.init();

‎src/co.zig

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@ const g = @import("g.zig");
44
const assert = std.debug.assert;
55

66
/// create and start a new coroutine
7-
pub fn create(comptime func: anytype, args: anytype) void {
7+
pub fn start(comptime func: anytype, args: anytype) void {
88
const buf = g.allocator.alignedAlloc(u8, std.Target.stack_align, @frameSize(func)) catch unreachable;
99
_ = @asyncCall(buf, {}, func, args);
10-
// @call(.{ .modifier = .async_kw, .stack = buf }, func, args);
1110
check_terminated();
1211
}
1312

14-
/// if the coroutine is at the last pause point, its memory will be freed after resume
13+
/// if the coroutine is at the last suspend point, its memory will be freed after resume
1514
pub fn do_resume(frame: anyframe) void {
1615
resume frame;
1716
check_terminated();

‎src/dns.c

+6-6
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,8 @@ static bool get_ttl(struct dns_record *noalias record, int rnamelen, void *ud, b
481481
if (ntohs(record->rtype) != DNS_TYPE_OPT) {
482482
/* it is hereby specified that a TTL value is an unsigned number,
483483
with a minimum value of 0, and a maximum value of 2147483647. */
484-
s32 ttl = ntohl(record->rttl);
485-
s32 *final_ttl = ud;
484+
i32 ttl = ntohl(record->rttl);
485+
i32 *final_ttl = ud;
486486
if (ttl < *final_ttl)
487487
*final_ttl = ttl;
488488
}
@@ -497,21 +497,21 @@ static bool update_ttl(struct dns_record *noalias record, int rnamelen, void *ud
497497
if (ntohs(record->rtype) != DNS_TYPE_OPT) {
498498
/* it is hereby specified that a TTL value is an unsigned number,
499499
with a minimum value of 0, and a maximum value of 2147483647. */
500-
s32 ttl = (s32)ntohl(record->rttl) + (intptr_t)ud;
500+
i32 ttl = (i32)ntohl(record->rttl) + (intptr_t)ud;
501501
record->rttl = htonl(max(ttl, 1));
502502
}
503503

504504
return true;
505505
}
506506

507-
s32 dns_get_ttl(const void *noalias msg, ssize_t len, int qnamelen, s32 nodata_ttl) {
507+
i32 dns_get_ttl(const void *noalias msg, ssize_t len, int qnamelen, i32 nodata_ttl) {
508508
if (!is_normal_msg(msg))
509509
return -1;
510510

511511
int count = get_records_count(msg);
512512
move_to_records(msg, len, qnamelen);
513513

514-
s32 ttl = INT32_MAX;
514+
i32 ttl = INT32_MAX;
515515

516516
unlikely_if (!foreach_record((void **)&msg, &len, count, get_ttl, &ttl))
517517
ttl = -1;
@@ -522,7 +522,7 @@ s32 dns_get_ttl(const void *noalias msg, ssize_t len, int qnamelen, s32 nodata_t
522522
return ttl;
523523
}
524524

525-
void dns_update_ttl(void *noalias msg, ssize_t len, int qnamelen, s32 ttl_change) {
525+
void dns_update_ttl(void *noalias msg, ssize_t len, int qnamelen, i32 ttl_change) {
526526
int count = get_records_count(msg);
527527
move_to_records(msg, len, qnamelen);
528528

‎src/dns.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ int dns_test_ip(const void *noalias msg, ssize_t len, int qnamelen, const struct
8989
void dns_add_ip(const void *noalias msg, ssize_t len, int qnamelen, struct ipset_addctx *noalias ctx);
9090

9191
/* return -1 if failed */
92-
s32 dns_get_ttl(const void *noalias msg, ssize_t len, int qnamelen, s32 nodata_ttl);
92+
i32 dns_get_ttl(const void *noalias msg, ssize_t len, int qnamelen, i32 nodata_ttl);
9393

9494
/* it should not fail because it has been checked by `get_ttl` */
95-
void dns_update_ttl(void *noalias msg, ssize_t len, int qnamelen, s32 ttl_change);
95+
void dns_update_ttl(void *noalias msg, ssize_t len, int qnamelen, i32 ttl_change);
9696

9797
/*
9898
* `levels`: the level of the domain to get (8 bools)

‎src/main.zig

+9-19
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ pub fn main() u8 {
135135
// ============================================================================
136136

137137
// used only for business-independent initialization, such as global variables
138-
init_all_module();
139-
defer if (_debug) deinit_all_module();
138+
call_module_fn("module_init", .{});
139+
defer if (_debug) call_module_fn("module_deinit", .{});
140140

141141
// ============================================================================
142142

@@ -214,33 +214,23 @@ pub fn main() u8 {
214214
server.start();
215215

216216
if (_debug)
217-
co.create(memleak_checker, .{});
217+
co.start(memleak_checker, .{});
218218

219219
g.evloop.run();
220220

221221
return 0;
222222
}
223223

224-
fn init_all_module() void {
224+
fn call_module_fn(comptime fn_name: [:0]const u8, args: anytype) void {
225225
comptime var i = 0;
226226
inline while (i < modules.module_list.len) : (i += 1) {
227227
const module = modules.module_list[i];
228228
const module_name: cc.ConstStr = modules.name_list[i];
229-
if (@hasDecl(module, "module_init")) {
230-
if (false) log.debug(@src(), "%s.module_init()", .{module_name});
231-
module.module_init();
232-
}
233-
}
234-
}
235-
236-
fn deinit_all_module() void {
237-
comptime var i = 0;
238-
inline while (i < modules.module_list.len) : (i += 1) {
239-
const module = modules.module_list[i];
240-
const module_name: cc.ConstStr = modules.name_list[i];
241-
if (@hasDecl(module, "module_deinit")) {
242-
if (false) log.debug(@src(), "%s.module_deinit()", .{module_name});
243-
module.module_deinit();
229+
if (@hasDecl(module, fn_name)) {
230+
if (false) log.debug(@src(), "%s.%s()", .{ module_name, fn_name.ptr });
231+
const options: std.builtin.CallOptions = .{};
232+
const func = @field(module, fn_name);
233+
@call(options, func, args);
244234
}
245235
}
246236
}

‎src/misc.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,22 @@ typedef uint16_t u16;
3434
typedef uint32_t u32;
3535
typedef uint64_t u64;
3636

37-
typedef int8_t s8;
38-
typedef int16_t s16;
39-
typedef int32_t s32;
40-
typedef int64_t s64;
37+
typedef int8_t i8;
38+
typedef int16_t i16;
39+
typedef int32_t i32;
40+
typedef int64_t i64;
4141

4242
#define U8C UINT8_C
4343
#define U16C UINT16_C
4444
#define U32C UINT32_C
4545
#define U64C UINT64_C
4646

47-
#define S8C INT8_C
48-
#define S16C INT16_C
49-
#define S32C INT32_C
50-
#define S64C INT64_C
47+
#define I8C INT8_C
48+
#define I16C INT16_C
49+
#define I32C INT32_C
50+
#define I64C INT64_C
5151

52-
// typedef signed char byte; /* >= 8 bits */
52+
typedef signed char ibyte; /* >= 8 bits */
5353
typedef unsigned char ubyte; /* >= 8 bits */
5454
typedef unsigned short ushort; /* >= 16 bits */
5555
typedef unsigned int uint; /* >= 16 bits */
@@ -118,7 +118,7 @@ typedef u8 bitvec_t;
118118

119119
/* align to `n` (struct, struct member) */
120120
#define struct_alignto(n) \
121-
__attribute__((packed,aligned(n)))
121+
__attribute__((packed, aligned(n)))
122122

123123
/* ======================================================== */
124124

‎src/modules.zig

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
pub const name_list = .{ "CacheMsg", "DynStr", "EvLoop", "ListNode", "NoAAAA", "Rc", "RcMsg", "StrList", "Upstream", "c", "cache", "cache_ignore", "cc", "co", "dnl", "dns", "flags_op", "fmtchk", "g", "groups", "ipset", "local_rr", "log", "main", "modules", "net", "opt", "sentinel_vector", "server", "str2int", "tag", "tests", "verdict_cache" };
2-
pub const module_list = .{ CacheMsg, DynStr, EvLoop, ListNode, NoAAAA, Rc, RcMsg, StrList, Upstream, c, cache, cache_ignore, cc, co, dnl, dns, flags_op, fmtchk, g, groups, ipset, local_rr, log, main, modules, net, opt, sentinel_vector, server, str2int, tag, tests, verdict_cache };
1+
pub const name_list = .{ "CacheMsg", "DynStr", "EvLoop", "NoAAAA", "Node", "Rc", "RcMsg", "StrList", "Upstream", "c", "cache", "cache_ignore", "cc", "co", "dnl", "dns", "flags_op", "fmtchk", "g", "groups", "ipset", "local_rr", "log", "main", "modules", "net", "opt", "sentinel_vector", "server", "str2int", "tag", "tests", "verdict_cache" };
2+
pub const module_list = .{ CacheMsg, DynStr, EvLoop, NoAAAA, Node, Rc, RcMsg, StrList, Upstream, c, cache, cache_ignore, cc, co, dnl, dns, flags_op, fmtchk, g, groups, ipset, local_rr, log, main, modules, net, opt, sentinel_vector, server, str2int, tag, tests, verdict_cache };
33

44
const CacheMsg = @import("CacheMsg.zig");
55
const DynStr = @import("DynStr.zig");
66
const EvLoop = @import("EvLoop.zig");
7-
const ListNode = @import("ListNode.zig");
87
const NoAAAA = @import("NoAAAA.zig");
8+
const Node = @import("Node.zig");
99
const Rc = @import("Rc.zig");
1010
const RcMsg = @import("RcMsg.zig");
1111
const StrList = @import("StrList.zig");

‎src/server.zig

+21-17
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const Upstream = @import("Upstream.zig");
1414
const NoAAAA = @import("NoAAAA.zig");
1515
const EvLoop = @import("EvLoop.zig");
1616
const RcMsg = @import("RcMsg.zig");
17-
const ListNode = @import("ListNode.zig");
17+
const Node = @import("Node.zig");
1818
const flags_op = @import("flags_op.zig");
1919
const verdict_cache = @import("verdict_cache.zig");
2020
const local_rr = @import("local_rr.zig");
@@ -28,7 +28,7 @@ comptime {
2828

2929
const QueryCtx = struct {
3030
// linked list
31-
list_node: ListNode = undefined,
31+
list_node: Node = undefined,
3232

3333
// alignment: 8/4
3434
fdobj: *EvLoop.Fd, // requester's fdobj
@@ -87,13 +87,13 @@ const QueryCtx = struct {
8787
g.allocator.destroy(self);
8888
}
8989

90-
pub fn from_list_node(node: *ListNode) *QueryCtx {
90+
pub fn from_list_node(node: *Node) *QueryCtx {
9191
return @fieldParentPtr(QueryCtx, "list_node", node);
9292
}
9393

9494
pub const List = struct {
9595
map: std.AutoHashMapUnmanaged(u16, *QueryCtx),
96-
list: ListNode,
96+
list: Node,
9797

9898
var _last_qid: u16 = 0;
9999

@@ -186,7 +186,7 @@ fn listen_tcp(fd: c_int, ip: cc.ConstStr, port: u16) void {
186186
continue;
187187
};
188188
net.setup_tcp_conn_sock(conn_fd);
189-
co.create(service_tcp, .{ conn_fd, &src_addr });
189+
co.start(service_tcp, .{ conn_fd, &src_addr });
190190
}
191191
}
192192

@@ -215,9 +215,9 @@ fn service_tcp(fd: c_int, p_src_addr: *const cc.SockAddr) void {
215215
while (true) {
216216
// read len (be16)
217217
var len: u16 = undefined;
218-
g.evloop.recv_exactly(fdobj, std.mem.asBytes(&len), 0) catch |err| switch (err) {
218+
g.evloop.recv_full(fdobj, std.mem.asBytes(&len), 0) catch |err| switch (err) {
219219
error.eof => return,
220-
error.other => break :e .{ .op = "read_len" },
220+
error.errno => break :e .{ .op = "read_len" },
221221
};
222222

223223
len = cc.ntohs(len);
@@ -238,9 +238,9 @@ fn service_tcp(fd: c_int, p_src_addr: *const cc.SockAddr) void {
238238

239239
// read msg
240240
qmsg.len = len;
241-
g.evloop.recv_exactly(fdobj, qmsg.msg(), 0) catch |err| switch (err) {
241+
g.evloop.recv_full(fdobj, qmsg.msg(), 0) catch |err| switch (err) {
242242
error.eof => break :e .{ .op = "read_msg", .msg = "connection closed" },
243-
error.other => break :e .{ .op = "read_msg" },
243+
error.errno => break :e .{ .op = "read_msg" },
244244
};
245245

246246
on_query(qmsg, fdobj, &src_addr, .from_tcp);
@@ -477,8 +477,10 @@ fn on_query(qmsg: *RcMsg, fdobj: *EvLoop.Fd, src_addr: *const cc.SockAddr, in_qf
477477

478478
dns.make_reply(rmsg, msg, qnamelen, answer, answer_n);
479479

480-
// [async func]
481-
if (g.verbose()) qlog.local_rr(answer_n, answer.len);
480+
if (g.verbose())
481+
qlog.local_rr(answer_n, answer.len);
482+
483+
// suspending
482484
return send_reply(rmsg, fdobj, src_addr, bufsz, id, qflags);
483485
}
484486

@@ -493,7 +495,7 @@ fn on_query(qmsg: *RcMsg, fdobj: *EvLoop.Fd, src_addr: *const cc.SockAddr, in_qf
493495
var ttl_r: i32 = undefined;
494496
var add_ip: bool = undefined;
495497
if (cache.get(msg, qnamelen, &ttl, &ttl_r, &add_ip)) |cache_msg| {
496-
// because send_reply is async func
498+
// because send_reply is a suspending func
497499
cache.ref(cache_msg);
498500
defer cache.unref(cache_msg);
499501

@@ -507,6 +509,7 @@ fn on_query(qmsg: *RcMsg, fdobj: *EvLoop.Fd, src_addr: *const cc.SockAddr, in_qf
507509
}
508510
}
509511

512+
// suspending
510513
send_reply(cache_msg, fdobj, src_addr, bufsz, id, qflags);
511514

512515
if (ttl > ttl_r)
@@ -768,7 +771,8 @@ pub fn on_reply(rmsg: *RcMsg, upstream: *const Upstream) void {
768771
dns.add_ip(msg, qnamelen, addctx);
769772
};
770773

771-
// [async] send reply to client
774+
// [suspending] send reply to client
775+
// TODO: change to nosuspend send_reply
772776
if (!qctx.flags.has(.from_local))
773777
send_reply(msg, qctx.fdobj, &qctx.src_addr, qctx.bufsz, qctx.id, qctx.flags);
774778

@@ -782,7 +786,7 @@ pub fn on_reply(rmsg: *RcMsg, upstream: *const Upstream) void {
782786
qctx.free();
783787
}
784788

785-
/// [async]
789+
/// [suspending]
786790
fn send_reply(msg: []const u8, fdobj: *EvLoop.Fd, src_addr: *const cc.SockAddr, bufsz: u16, id: c.be16, qflags: QueryCtx.Flags) void {
787791
var iov = [_]cc.iovec_t{
788792
undefined, // for tcp
@@ -839,7 +843,7 @@ fn send_reply(msg: []const u8, fdobj: *EvLoop.Fd, src_addr: *const cc.SockAddr,
839843
);
840844
}
841845

842-
/// [async] for bad query msg
846+
/// [suspending] for bad query msg
843847
fn send_reply_xxx(msg: []u8, fdobj: *EvLoop.Fd, src_addr: *const cc.SockAddr, qflags: QueryCtx.Flags) void {
844848
if (msg.len >= dns.header_len())
845849
_ = dns.empty_reply(msg, 0);
@@ -922,10 +926,10 @@ noinline fn do_start(ip: cc.ConstStr, port: u16, socktype: net.SockType) void {
922926
switch (socktype) {
923927
.tcp => {
924928
cc.listen(fd, 1024) orelse break :e "listen";
925-
co.create(listen_tcp, .{ fd, ip, port });
929+
co.start(listen_tcp, .{ fd, ip, port });
926930
},
927931
.udp => {
928-
co.create(listen_udp, .{ fd, ip, port });
932+
co.start(listen_udp, .{ fd, ip, port });
929933
},
930934
}
931935
return;

‎src/str2int.zig

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ const std = @import("std");
22

33
pub fn parse(comptime T: type, str: []const u8, radix: u8) ?T {
44
if (@bitSizeOf(T) < 1 or @bitSizeOf(T) > 64)
5-
@compileError("expected i1..i64 or s1..s64, found " ++ @typeName(T));
5+
@compileError("expected i1..i64 or u1..u64, found " ++ @typeName(T));
66

77
return @intCast(T, parse_internal(
88
if (comptime std.meta.trait.isSignedInt(T)) i64 else u64,

‎src/wolfssl_opt.h

+2
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
#define WOLFSSL_NO_ATOMICS
55
#define WOLFSSL_AEAD_ONLY
66
#define LARGE_STATIC_BUFFERS
7+
#define WOLFSSL_JNI /* ssl->data (void *) */
8+
#define HAVE_EXT_CACHE /* new session callback */

‎tool/dns_cache_mgr.c

+8-7
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
#include "../src/misc.h"
1010

1111
struct header {
12-
s64 update_time;
12+
i64 update_time;
1313
u32 hashv;
14-
s32 ttl;
15-
s32 ttl_r;
14+
i32 ttl;
15+
i32 ttl_r;
1616
u16 msg_len;
1717
u8 qnamelen;
1818
// msg: [msg_len]u8, // {header, question, answer, authority, additional}
@@ -55,13 +55,14 @@ static void list(FILE *file) {
5555
char buf[sizeof(*h)] alignto(__alignof__(*h));
5656
h = (void *)buf;
5757

58-
char name[DNS_NAME_MAXLEN + 1];
59-
6058
void *msg = malloc(DNS_MSG_MAXSIZE);
59+
char name[DNS_NAME_MAXLEN + 1];
6160

62-
s64 now = time(NULL);
61+
i64 now = time(NULL);
6362
while (next(file, h, msg, name))
64-
printf("%-60s qtype:%-5u ttl:%-10d size:%u\n", name, dns_get_qtype(msg, h->qnamelen), h->ttl - (s32)(now - h->update_time), h->msg_len);
63+
printf("%-60s qtype:%-5u ttl:%-10d size:%u\n",
64+
name, dns_get_qtype(msg, h->qnamelen),
65+
h->ttl - (i32)(now - h->update_time), h->msg_len);
6566

6667
free(msg);
6768
}

0 commit comments

Comments
 (0)
Please sign in to comment.