diff --git a/bee/net/event.cpp b/bee/net/event.cpp index 115734c6..c98c7a69 100644 --- a/bee/net/event.cpp +++ b/bee/net/event.cpp @@ -17,7 +17,7 @@ namespace bee::net { bool event::open() noexcept { if (pipe[0] != retired_fd) return false; - return socket::pair(pipe, socket::fd_flags::none); + return socket::pair(pipe, socket::fd_flags::nonblock); } void event::set() noexcept { @@ -30,7 +30,7 @@ namespace bee::net { socket::send(pipe[1], rc, tmp, sizeof(tmp)); } - void event::wait() noexcept { + void event::clear() noexcept { char tmp[128]; int rc = 0; for (;;) { @@ -48,7 +48,7 @@ namespace bee::net { e.clear(std::memory_order_seq_cst); } - fd_t event::fd() noexcept { + fd_t event::fd() const noexcept { return pipe[0]; } } diff --git a/bee/net/event.h b/bee/net/event.h index 3b65fe66..586b1fec 100644 --- a/bee/net/event.h +++ b/bee/net/event.h @@ -10,7 +10,7 @@ namespace bee::net { ~event() noexcept; bool open() noexcept; void set() noexcept; - void wait() noexcept; - fd_t fd() noexcept; + void clear() noexcept; + fd_t fd() const noexcept; }; } diff --git a/binding/lua_channel.cpp b/binding/lua_channel.cpp new file mode 100644 index 00000000..85998b59 --- /dev/null +++ b/binding/lua_channel.cpp @@ -0,0 +1,171 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +extern "C" { +#include <3rd/lua-seri/lua-seri.h> +} + +namespace bee::lua_channel { + class channel { + public: + using box = std::shared_ptr; + using value_type = void*; + + bool init() noexcept { + if (!ev.open()) { + return false; + } + return true; + } + net::fd_t fd() const noexcept { + return ev.fd(); + } + void push(value_type data) noexcept { + std::unique_lock lk(mutex); + queue.push(data); + ev.set(); + } + bool pop(value_type& data) noexcept { + std::unique_lock lk(mutex); + if (queue.empty()) { + return false; + } + data = queue.front(); + queue.pop(); + return true; + } + + private: + std::queue queue; + spinlock mutex; + net::event ev; + }; + + class channelmgr { + public: + channel::box create(zstring_view name) noexcept { + std::unique_lock lk(mutex); + channel* c = new channel; + if (!c->init()) { + return nullptr; + } + std::string namestr { name.data(), name.size() }; + auto [r, ok] = channels.emplace(namestr, channel::box { c }); + if (!ok) { + return nullptr; + } + return r->second; + } + void clear() noexcept { + std::unique_lock lk(mutex); + channels.clear(); + } + channel::box query(zstring_view name) noexcept { + std::unique_lock lk(mutex); + std::string namestr { name.data(), name.size() }; + auto it = channels.find(namestr); + if (it != channels.end()) { + return it->second; + } + return nullptr; + } + + private: + std::map channels; + spinlock mutex; + }; + + static channelmgr g_channel; + + static int lchannel_push(lua_State* L) { + auto& bc = lua::checkudata(L, 1); + void* buffer = seri_pack(L, 1, NULL); + bc->push(buffer); + return 0; + } + + static int lchannel_pop(lua_State* L) { + auto& bc = lua::checkudata(L, 1); + void* data; + if (!bc->pop(data)) { + lua_pushboolean(L, 0); + return 1; + } + lua_pushboolean(L, 1); + return 1 + seri_unpackptr(L, data); + } + + static int lchannel_fd(lua_State* L) { + auto& bc = lua::checkudata(L, 1); + lua_pushlightuserdata(L, (void*)(intptr_t)bc->fd()); + return 1; + } + + static void metatable(lua_State* L) { + luaL_Reg lib[] = { + { "push", lchannel_push }, + { "pop", lchannel_pop }, + { "fd", lchannel_fd }, + { NULL, NULL }, + }; + luaL_newlibtable(L, lib); + luaL_setfuncs(L, lib, 0); + lua_setfield(L, -2, "__index"); + } + + static int lcreate(lua_State* L) { + auto name = lua::checkstrview(L, 1); + channel::box c = g_channel.create(name); + if (!c) { + return luaL_error(L, "Duplicate channel '%s'", name.data()); + } + lua::newudata(L, c); + return 1; + } + + static int lquery(lua_State* L) { + auto name = lua::checkstrview(L, 1); + channel::box c = g_channel.query(name); + if (!c) { + return luaL_error(L, "Can't query channel '%s'", name.data()); + } + lua::newudata(L, c); + return 1; + } + + static int lreset(lua_State* L) { + g_channel.clear(); + return 0; + } + + static int luaopen(lua_State* L) { + luaL_Reg lib[] = { + { "create", lcreate }, + { "query", lquery }, + { "reset", lreset }, + { NULL, NULL }, + }; + luaL_newlibtable(L, lib); + luaL_setfuncs(L, lib, 0); + return 1; + } +} + +DEFINE_LUAOPEN(channel) + +namespace bee::lua { + template <> + struct udata { + static inline auto name = "bee::channel"; + static inline auto metatable = bee::lua_channel::metatable; + }; +} diff --git a/binding/lua_epoll.cpp b/binding/lua_epoll.cpp index 98a7dce8..9497cd25 100644 --- a/binding/lua_epoll.cpp +++ b/binding/lua_epoll.cpp @@ -33,8 +33,15 @@ namespace bee::lua_epoll { }; static net::fd_t ep_tofd(lua_State *L, int idx) { - luaL_checktype(L, idx, LUA_TUSERDATA); - return lua::toudata(L, idx); + switch (lua_type(L, idx)) { + case LUA_TLIGHTUSERDATA: + return lua::tolightud(L, idx); + case LUA_TUSERDATA: + return lua::toudata(L, idx); + default: + luaL_checktype(L, idx, LUA_TUSERDATA); + std::unreachable(); + } } static int ep_events(lua_State *L) { diff --git a/binding/lua_select.cpp b/binding/lua_select.cpp index c1acbedd..e233f565 100644 --- a/binding/lua_select.cpp +++ b/binding/lua_select.cpp @@ -218,9 +218,20 @@ namespace bee::lua_select { ctx.writeset.clear(); return 0; } + static net::fd_t tofd(lua_State* L, int idx) { + switch (lua_type(L, 1)) { + case LUA_TLIGHTUSERDATA: + return lua::tolightud(L, idx); + case LUA_TUSERDATA: + return lua::toudata(L, idx); + default: + luaL_checktype(L, idx, LUA_TUSERDATA); + std::unreachable(); + } + } static int event_add(lua_State* L) { auto& ctx = lua::checkudata(L, 1); - auto fd = lua::checkudata(L, 2); + auto fd = tofd(L, 2); auto events = luaL_checkinteger(L, 3); storeref(L, fd); if (events & SELECT_READ) { @@ -238,7 +249,7 @@ namespace bee::lua_select { } static int event_mod(lua_State* L) { auto& ctx = lua::checkudata(L, 1); - auto fd = lua::checkudata(L, 2); + auto fd = tofd(L, 2); auto events = luaL_checkinteger(L, 3); if (events & SELECT_READ) { ctx.readset.insert(fd); @@ -255,7 +266,7 @@ namespace bee::lua_select { } static int event_del(lua_State* L) { auto& ctx = lua::checkudata(L, 1); - auto fd = lua::checkudata(L, 2); + auto fd = tofd(L, 2); cleanref(L, fd); ctx.readset.erase(fd); ctx.writeset.erase(fd); diff --git a/binding/lua_thread.cpp b/binding/lua_thread.cpp index a05da43c..036c9a61 100644 --- a/binding/lua_thread.cpp +++ b/binding/lua_thread.cpp @@ -330,7 +330,7 @@ DEFINE_LUAOPEN(thread) namespace bee::lua { template <> struct udata { - static inline auto name = "bee::channel"; + static inline auto name = "bee::legacy_channel"; static inline auto metatable = bee::lua_thread::channel_metatable; }; } diff --git a/test/test.lua b/test/test.lua index a442b65d..2cc9cba9 100644 --- a/test/test.lua +++ b/test/test.lua @@ -37,6 +37,7 @@ require "test_socket" require "test_epoll" require "test_filewatch" require "test_time" +require "test_channel" do local fs = require "bee.filesystem" diff --git a/test/test_channel.lua b/test/test_channel.lua new file mode 100644 index 00000000..7c710bd3 --- /dev/null +++ b/test/test_channel.lua @@ -0,0 +1,209 @@ +local lt = require "ltest" + +local thread = require "bee.thread" +local channel = require "bee.channel" +local epoll = require "bee.epoll" + +local function assertNotThreadError() + lt.assertEquals(thread.errlog(), false) +end + +local test_channel = lt.test "channel" + +function test_channel:test_channel_create() + channel.reset() + lt.assertErrorMsgEquals("Can't query channel 'test'", channel.query, "test") + channel.create "test" + lt.assertIsUserdata(channel.query "test") + lt.assertIsUserdata(channel.query "test") + channel.reset() + channel.create "test" + lt.assertErrorMsgEquals("Duplicate channel 'test'", channel.create, "test") + channel.reset() +end + +function test_channel:test_reset_1() + channel.reset() + lt.assertErrorMsgEquals("Can't query channel 'test'", channel.query, "test") + channel.create "test" + lt.assertIsUserdata(channel.query "test") + channel.reset() + lt.assertErrorMsgEquals("Can't query channel 'test'", channel.query, "test") + channel.create "test" + lt.assertIsUserdata(channel.query "test") + channel.reset() +end + +local function TestSuit(f) + f(1) + f(0.0001) + f("TEST") + f(true) + f(false) + f({}) + f({ 1, 2 }) + f(1, { 1, 2 }) + f(1, 2, { A = { B = { C = "D" } } }) + f(1, nil, 2) +end + +function test_channel:test_pop_1() + channel.reset() + local chan = channel.create "test" + local function pack_pop(ok, ...) + lt.assertEquals(ok, true) + return table.pack(...) + end + local function test_ok(...) + chan:push(...) + lt.assertEquals(pack_pop(chan:pop()), table.pack(...)) + end + TestSuit(test_ok) + -- 基本和serialization的测试重复,所以failed就不测了 +end + +function test_channel:test_pop_2() + channel.reset() + local chan = channel.create "test" + + local function assertIs(expected) + local ok, v = chan:pop() + lt.assertEquals(ok, true) + lt.assertEquals(v, expected) + end + local function assertEmpty() + local ok, v = chan:pop() + lt.assertEquals(ok, false) + lt.assertEquals(v, nil) + end + + assertEmpty() + + chan:push(1024) + assertIs(1024) + assertEmpty() + + chan:push(1024) + chan:push(1025) + chan:push(1026) + assertIs(1024) + assertIs(1025) + assertIs(1026) + assertEmpty() + + chan:push(1024) + chan:push(1025) + assertIs(1024) + chan:push(1026) + assertIs(1025) + assertIs(1026) + assertEmpty() + + channel.reset() +end + +function test_channel:test_pop_3() + channel.reset() + assertNotThreadError() + thread.reset() + local req = channel.create "testReq" + local res = channel.create "testRes" + local thd = thread.thread [[ + local thread = require "bee.thread" + local channel = require "bee.channel" + local req = channel.query "testReq" + local res = channel.query "testRes" + local function dispatch(ok, what, ...) + if not ok then + return + end + if what == "exit" then + return true + end + res:push(what, ...) + end + while not dispatch(req:pop()) do + thread.sleep(0) + end + ]] + local function pack_pop(ok, ...) + if not ok then + return + end + return table.pack(...) + end + local function test_ok(...) + req:push(...) + local t + while true do + t = pack_pop(res:pop()) + if t then + break + end + thread.sleep(0) + end + lt.assertEquals(t, table.pack(...)) + end + TestSuit(test_ok) + req:push "exit" + thread.wait(thd) + assertNotThreadError() +end + +function test_channel:test_fd() + channel.reset() + assertNotThreadError() + thread.reset() + local req = channel.create "testReq" + local res = channel.create "testRes" + local thd = thread.thread [[ + local thread = require "bee.thread" + local channel = require "bee.channel" + local epoll = require "bee.epoll" + local req = channel.query "testReq" + local res = channel.query "testRes" + local epfd = epoll.create(16) + epfd:event_add(req:fd(), epoll.EPOLLIN) + local function dispatch(ok, what, ...) + if not ok then + return true + end + if what == "exit" then + os.exit() + return + end + res:push(what, ...) + end + for _, event in epfd:wait() do + if event & (epoll.EPOLLERR | epoll.EPOLLHUP) ~= 0 then + assert(false, "unknown error") + return + end + if event & epoll.EPOLLIN ~= 0 then + while not dispatch(req:pop()) do + end + end + end + ]] + local epfd = epoll.create(16) + epfd:event_add(res:fd(), epoll.EPOLLIN) + local function test_ok(...) + req:push(...) + for _, event in epfd:wait() do + if event & (epoll.EPOLLERR | epoll.EPOLLHUP) ~= 0 then + lt.failure("unknown error") + end + if event & epoll.EPOLLIN ~= 0 then + local r = table.pack(res:pop()) + if r[1] == true then + lt.assertEquals(r, table.pack(true, ...)) + break + end + end + end + end + TestSuit(test_ok) + req:push "exit" + thread.wait(thd) + assertNotThreadError() +end